11import logging # noqa: I001
2- from typing import Any
32
43from a2a .client .auth .credentials import CredentialService
5- from a2a .client .middleware import ClientCallContext , ClientCallInterceptor
6- from a2a .types .a2a_pb2 import AgentCard
4+ from a2a .client .client import ClientCallContext
5+ from a2a .client .interceptors import (
6+ AfterArgs ,
7+ BeforeArgs ,
8+ ClientCallInterceptor ,
9+ )
710
811logger = logging .getLogger (__name__ )
912
@@ -17,79 +20,79 @@ class AuthInterceptor(ClientCallInterceptor):
1720 def __init__ (self , credential_service : CredentialService ):
1821 self ._credential_service = credential_service
1922
20- async def intercept (
21- self ,
22- method_name : str ,
23- request_payload : dict [str , Any ],
24- http_kwargs : dict [str , Any ],
25- agent_card : AgentCard | None ,
26- context : ClientCallContext | None ,
27- ) -> tuple [dict [str , Any ], dict [str , Any ]]:
23+ async def before (self , args : BeforeArgs ) -> None :
2824 """Applies authentication headers to the request if credentials are available."""
25+ agent_card = args .agent_card
26+
2927 # Proto3 repeated fields (security) and maps (security_schemes) do not track presence.
3028 # HasField() raises ValueError for them.
3129 # We check for truthiness to see if they are non-empty.
3230 if (
33- agent_card is None
34- or not agent_card .security_requirements
31+ not agent_card .security_requirements
3532 or not agent_card .security_schemes
3633 ):
37- return request_payload , http_kwargs
34+ return
3835
3936 for requirement in agent_card .security_requirements :
4037 for scheme_name in requirement .schemes :
4138 credential = await self ._credential_service .get_credentials (
42- scheme_name , context
39+ scheme_name , args . context
4340 )
4441 if credential and scheme_name in agent_card .security_schemes :
4542 scheme = agent_card .security_schemes .get (scheme_name )
4643 if not scheme :
4744 continue
4845
49- headers = http_kwargs .get ('headers' , {})
46+ if args .context is None :
47+ args .context = ClientCallContext ()
48+
49+ if args .context .service_parameters is None :
50+ args .context .service_parameters = {}
5051
5152 # HTTP Bearer authentication
5253 if (
5354 scheme .HasField ('http_auth_security_scheme' )
5455 and scheme .http_auth_security_scheme .scheme .lower ()
5556 == 'bearer'
5657 ):
57- headers ['Authorization' ] = f'Bearer { credential } '
58+ args .context .service_parameters ['Authorization' ] = (
59+ f'Bearer { credential } '
60+ )
5861 logger .debug (
5962 "Added Bearer token for scheme '%s'." ,
6063 scheme_name ,
6164 )
62- http_kwargs ['headers' ] = headers
63- return request_payload , http_kwargs
65+ return
6466
6567 # OAuth2 and OIDC schemes are implicitly Bearer
6668 if scheme .HasField (
6769 'oauth2_security_scheme'
6870 ) or scheme .HasField ('open_id_connect_security_scheme' ):
69- headers ['Authorization' ] = f'Bearer { credential } '
71+ args .context .service_parameters ['Authorization' ] = (
72+ f'Bearer { credential } '
73+ )
7074 logger .debug (
7175 "Added Bearer token for scheme '%s'." ,
7276 scheme_name ,
7377 )
74- http_kwargs ['headers' ] = headers
75- return request_payload , http_kwargs
78+ return
7679
7780 # API Key in Header
7881 if (
7982 scheme .HasField ('api_key_security_scheme' )
8083 and scheme .api_key_security_scheme .location .lower ()
8184 == 'header'
8285 ):
83- headers [ scheme . api_key_security_scheme . name ] = (
84- credential
85- )
86+ args . context . service_parameters [
87+ scheme . api_key_security_scheme . name
88+ ] = credential
8689 logger .debug (
8790 "Added API Key Header for scheme '%s'." ,
8891 scheme_name ,
8992 )
90- http_kwargs ['headers' ] = headers
91- return request_payload , http_kwargs
93+ return
9294
9395 # Note: Other cases like API keys in query/cookie are not handled and will be skipped.
9496
95- return request_payload , http_kwargs
97+ async def after (self , args : AfterArgs ) -> None :
98+ """Invoked after the method is executed."""
0 commit comments