Skip to content

Commit a910cbc

Browse files
feat(client): transport agnostic interceptors (#796)
# Description This PR refactors the client interceptors architecture, centralizing their execution within the BaseClient rather than delegating them to the underlying transport implementations. These interceptors allow to modify request before being sent to the server, and server responses before are sent back to the caller, with an early return mechanism. The Authentication interceptor is updated as well to store authentication values in the ServiceParameters class of the ClientCallContext. Fix: #757
1 parent 709b1ff commit a910cbc

25 files changed

+663
-300
lines changed

src/a2a/client/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,21 @@
99
)
1010
from a2a.client.base_client import BaseClient
1111
from a2a.client.card_resolver import A2ACardResolver
12-
from a2a.client.client import Client, ClientConfig, ClientEvent, Consumer
12+
from a2a.client.client import (
13+
Client,
14+
ClientCallContext,
15+
ClientConfig,
16+
ClientEvent,
17+
Consumer,
18+
)
1319
from a2a.client.client_factory import ClientFactory, minimal_agent_card
1420
from a2a.client.errors import (
1521
A2AClientError,
1622
A2AClientTimeoutError,
1723
AgentCardResolutionError,
1824
)
1925
from a2a.client.helpers import create_text_message_object
20-
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
26+
from a2a.client.interceptors import ClientCallInterceptor
2127

2228

2329
logger = logging.getLogger(__name__)

src/a2a/client/auth/credentials.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
22

3-
from a2a.client.middleware import ClientCallContext
3+
from a2a.client.client import ClientCallContext
44

55

66
class CredentialService(ABC):

src/a2a/client/auth/interceptor.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import logging # noqa: I001
2-
from typing import Any
32

43
from a2a.client.auth.credentials import CredentialService
5-
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
6-
from a2a.types.a2a_pb2 import AgentCard
4+
from a2a.client.client import ClientCallContext
5+
from a2a.client.interceptors import (
6+
AfterArgs,
7+
BeforeArgs,
8+
ClientCallInterceptor,
9+
)
710

811
logger = logging.getLogger(__name__)
912

@@ -17,79 +20,79 @@ class AuthInterceptor(ClientCallInterceptor):
1720
def __init__(self, credential_service: CredentialService):
1821
self._credential_service = credential_service
1922

20-
async def intercept(
21-
self,
22-
method_name: str,
23-
request_payload: dict[str, Any],
24-
http_kwargs: dict[str, Any],
25-
agent_card: AgentCard | None,
26-
context: ClientCallContext | None,
27-
) -> tuple[dict[str, Any], dict[str, Any]]:
23+
async def before(self, args: BeforeArgs) -> None:
2824
"""Applies authentication headers to the request if credentials are available."""
25+
agent_card = args.agent_card
26+
2927
# Proto3 repeated fields (security) and maps (security_schemes) do not track presence.
3028
# HasField() raises ValueError for them.
3129
# We check for truthiness to see if they are non-empty.
3230
if (
33-
agent_card is None
34-
or not agent_card.security_requirements
31+
not agent_card.security_requirements
3532
or not agent_card.security_schemes
3633
):
37-
return request_payload, http_kwargs
34+
return
3835

3936
for requirement in agent_card.security_requirements:
4037
for scheme_name in requirement.schemes:
4138
credential = await self._credential_service.get_credentials(
42-
scheme_name, context
39+
scheme_name, args.context
4340
)
4441
if credential and scheme_name in agent_card.security_schemes:
4542
scheme = agent_card.security_schemes.get(scheme_name)
4643
if not scheme:
4744
continue
4845

49-
headers = http_kwargs.get('headers', {})
46+
if args.context is None:
47+
args.context = ClientCallContext()
48+
49+
if args.context.service_parameters is None:
50+
args.context.service_parameters = {}
5051

5152
# HTTP Bearer authentication
5253
if (
5354
scheme.HasField('http_auth_security_scheme')
5455
and scheme.http_auth_security_scheme.scheme.lower()
5556
== 'bearer'
5657
):
57-
headers['Authorization'] = f'Bearer {credential}'
58+
args.context.service_parameters['Authorization'] = (
59+
f'Bearer {credential}'
60+
)
5861
logger.debug(
5962
"Added Bearer token for scheme '%s'.",
6063
scheme_name,
6164
)
62-
http_kwargs['headers'] = headers
63-
return request_payload, http_kwargs
65+
return
6466

6567
# OAuth2 and OIDC schemes are implicitly Bearer
6668
if scheme.HasField(
6769
'oauth2_security_scheme'
6870
) or scheme.HasField('open_id_connect_security_scheme'):
69-
headers['Authorization'] = f'Bearer {credential}'
71+
args.context.service_parameters['Authorization'] = (
72+
f'Bearer {credential}'
73+
)
7074
logger.debug(
7175
"Added Bearer token for scheme '%s'.",
7276
scheme_name,
7377
)
74-
http_kwargs['headers'] = headers
75-
return request_payload, http_kwargs
78+
return
7679

7780
# API Key in Header
7881
if (
7982
scheme.HasField('api_key_security_scheme')
8083
and scheme.api_key_security_scheme.location.lower()
8184
== 'header'
8285
):
83-
headers[scheme.api_key_security_scheme.name] = (
84-
credential
85-
)
86+
args.context.service_parameters[
87+
scheme.api_key_security_scheme.name
88+
] = credential
8689
logger.debug(
8790
"Added API Key Header for scheme '%s'.",
8891
scheme_name,
8992
)
90-
http_kwargs['headers'] = headers
91-
return request_payload, http_kwargs
93+
return
9294

9395
# Note: Other cases like API keys in query/cookie are not handled and will be skipped.
9496

95-
return request_payload, http_kwargs
97+
async def after(self, args: AfterArgs) -> None:
98+
"""Invoked after the method is executed."""

0 commit comments

Comments
 (0)