|
| 1 | +import contextlib |
1 | 2 | import json |
2 | 3 | import logging |
3 | 4 |
|
@@ -63,11 +64,14 @@ def __init__( |
63 | 64 | httpx_client: httpx.AsyncClient, |
64 | 65 | agent_card: AgentCard | None, |
65 | 66 | url: str, |
| 67 | + subscribe_method_override: str | None = None, |
66 | 68 | ): |
67 | 69 | """Initializes the CompatRestTransport.""" |
68 | 70 | self.url = url.removesuffix('/') |
69 | 71 | self.httpx_client = httpx_client |
70 | 72 | self.agent_card = agent_card |
| 73 | + self._subscribe_method_override = subscribe_method_override |
| 74 | + self._subscribe_auto_method_override = subscribe_method_override is None |
71 | 75 |
|
72 | 76 | async def send_message( |
73 | 77 | self, |
@@ -273,13 +277,41 @@ async def subscribe( |
273 | 277 | *, |
274 | 278 | context: ClientCallContext | None = None, |
275 | 279 | ) -> AsyncGenerator[StreamResponse]: |
276 | | - """Reconnects to get task updates.""" |
277 | | - async for event in self._send_stream_request( |
278 | | - 'GET', |
279 | | - f'/v1/tasks/{request.id}:subscribe', |
280 | | - context=context, |
281 | | - ): |
282 | | - yield event |
| 280 | + """Reconnects to get task updates. |
| 281 | +
|
| 282 | + This method implements backward compatibility logic for the subscribe |
| 283 | + endpoint. It first attempts to use POST, which is the official method |
| 284 | + for A2A subscribe endpoint. If the server returns 405 Method Not Allowed, |
| 285 | + it falls back to GET and remembers this preference for future calls |
| 286 | + on this transport instance. If both fail with 405, it will default back |
| 287 | + to POST for next calls but will not retry again. |
| 288 | + """ |
| 289 | + subscribe_method = self._subscribe_method_override or 'POST' |
| 290 | + try: |
| 291 | + async for event in self._send_stream_request( |
| 292 | + subscribe_method, |
| 293 | + f'/v1/tasks/{request.id}:subscribe', |
| 294 | + context=context, |
| 295 | + ): |
| 296 | + yield event |
| 297 | + except A2AClientError as e: |
| 298 | + # Check for 405 Method Not Allowed in the cause (httpx.HTTPStatusError) |
| 299 | + cause = e.__cause__ |
| 300 | + if ( |
| 301 | + isinstance(cause, httpx.HTTPStatusError) |
| 302 | + and cause.response.status_code == httpx.codes.METHOD_NOT_ALLOWED |
| 303 | + ): |
| 304 | + if self._subscribe_method_override: |
| 305 | + if self._subscribe_auto_method_override: |
| 306 | + self._subscribe_auto_method_override = False |
| 307 | + self._subscribe_method_override = 'POST' |
| 308 | + raise |
| 309 | + else: |
| 310 | + self._subscribe_method_override = 'GET' |
| 311 | + async for event in self.subscribe(request, context=context): |
| 312 | + yield event |
| 313 | + else: |
| 314 | + raise |
283 | 315 |
|
284 | 316 | async def get_extended_agent_card( |
285 | 317 | self, |
@@ -311,7 +343,14 @@ async def close(self) -> None: |
311 | 343 | def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn: |
312 | 344 | """Handles HTTP status errors and raises the appropriate A2AError.""" |
313 | 345 | try: |
314 | | - error_data = e.response.json() |
| 346 | + with contextlib.suppress(httpx.StreamClosed): |
| 347 | + e.response.read() |
| 348 | + |
| 349 | + try: |
| 350 | + error_data = e.response.json() |
| 351 | + except (json.JSONDecodeError, ValueError, httpx.ResponseNotRead): |
| 352 | + error_data = {} |
| 353 | + |
315 | 354 | error_type = error_data.get('type') |
316 | 355 | message = error_data.get('message', str(e)) |
317 | 356 |
|
|
0 commit comments