Skip to content

Commit a0827d0

Browse files
authored
fix: Use POST method for REST endpoint /tasks/{id}:subscribe (#843)
POST should be always use for /tasks/{id}:subscribe. Decisions for backward compatibility with invalid protocol implementations: 1.0 server: Accept both POST and GET 1.0 client: Always use POST 0.3 server: Accept both POST and GET 0.3 client: Try POST first, on HTTP 405 error retry with GET. Cache the retry state to ensure that there is at most one retry attempt per transport instance. Fixes #840
1 parent ea7d3ad commit a0827d0

8 files changed

Lines changed: 376 additions & 16 deletions

File tree

src/a2a/client/transports/rest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ async def subscribe(
258258
) -> AsyncGenerator[StreamResponse]:
259259
"""Reconnects to get task updates."""
260260
async for event in self._send_stream_request(
261-
'GET',
261+
'POST',
262262
f'/tasks/{request.id}:subscribe',
263263
request.tenant,
264264
context=context,

src/a2a/compat/v0_3/rest_adapter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
163163
self._handle_streaming_request,
164164
self.handler.on_subscribe_to_task,
165165
),
166+
('/v1/tasks/{id}:subscribe', 'POST'): functools.partial(
167+
self._handle_streaming_request,
168+
self.handler.on_subscribe_to_task,
169+
),
166170
('/v1/tasks/{id}', 'GET'): functools.partial(
167171
self._handle_request, self.handler.on_get_task
168172
),

src/a2a/compat/v0_3/rest_transport.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
import json
23
import logging
34

@@ -63,11 +64,14 @@ def __init__(
6364
httpx_client: httpx.AsyncClient,
6465
agent_card: AgentCard | None,
6566
url: str,
67+
subscribe_method_override: str | None = None,
6668
):
6769
"""Initializes the CompatRestTransport."""
6870
self.url = url.removesuffix('/')
6971
self.httpx_client = httpx_client
7072
self.agent_card = agent_card
73+
self._subscribe_method_override = subscribe_method_override
74+
self._subscribe_auto_method_override = subscribe_method_override is None
7175

7276
async def send_message(
7377
self,
@@ -273,13 +277,41 @@ async def subscribe(
273277
*,
274278
context: ClientCallContext | None = None,
275279
) -> AsyncGenerator[StreamResponse]:
276-
"""Reconnects to get task updates."""
277-
async for event in self._send_stream_request(
278-
'GET',
279-
f'/v1/tasks/{request.id}:subscribe',
280-
context=context,
281-
):
282-
yield event
280+
"""Reconnects to get task updates.
281+
282+
This method implements backward compatibility logic for the subscribe
283+
endpoint. It first attempts to use POST, which is the official method
284+
for A2A subscribe endpoint. If the server returns 405 Method Not Allowed,
285+
it falls back to GET and remembers this preference for future calls
286+
on this transport instance. If both fail with 405, it will default back
287+
to POST for next calls but will not retry again.
288+
"""
289+
subscribe_method = self._subscribe_method_override or 'POST'
290+
try:
291+
async for event in self._send_stream_request(
292+
subscribe_method,
293+
f'/v1/tasks/{request.id}:subscribe',
294+
context=context,
295+
):
296+
yield event
297+
except A2AClientError as e:
298+
# Check for 405 Method Not Allowed in the cause (httpx.HTTPStatusError)
299+
cause = e.__cause__
300+
if (
301+
isinstance(cause, httpx.HTTPStatusError)
302+
and cause.response.status_code == httpx.codes.METHOD_NOT_ALLOWED
303+
):
304+
if self._subscribe_method_override:
305+
if self._subscribe_auto_method_override:
306+
self._subscribe_auto_method_override = False
307+
self._subscribe_method_override = 'POST'
308+
raise
309+
else:
310+
self._subscribe_method_override = 'GET'
311+
async for event in self.subscribe(request, context=context):
312+
yield event
313+
else:
314+
raise
283315

284316
async def get_extended_agent_card(
285317
self,
@@ -311,7 +343,14 @@ async def close(self) -> None:
311343
def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
312344
"""Handles HTTP status errors and raises the appropriate A2AError."""
313345
try:
314-
error_data = e.response.json()
346+
with contextlib.suppress(httpx.StreamClosed):
347+
e.response.read()
348+
349+
try:
350+
error_data = e.response.json()
351+
except (json.JSONDecodeError, ValueError, httpx.ResponseNotRead):
352+
error_data = {}
353+
315354
error_type = error_data.get('type')
316355
message = error_data.get('message', str(e))
317356

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
237237
self._handle_streaming_request,
238238
self.handler.on_subscribe_to_task,
239239
),
240+
('/tasks/{id}:subscribe', 'POST'): functools.partial(
241+
self._handle_streaming_request,
242+
self.handler.on_subscribe_to_task,
243+
),
240244
('/tasks/{id}', 'GET'): functools.partial(
241245
self._handle_request, self.handler.on_get_task
242246
),

tests/client/transports/test_rest_client.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -730,8 +730,15 @@ async def empty_aiter():
730730
async for _ in method(request=request_obj):
731731
pass
732732

733-
# 4. Verify the URL
733+
# 4. Verify the URL and method
734734
mock_aconnect_sse.assert_called_once()
735-
args, _ = mock_aconnect_sse.call_args
735+
args, kwargs = mock_aconnect_sse.call_args
736+
# method is 2nd positional argument
737+
assert args[1] == 'POST'
738+
if method_name == 'subscribe':
739+
assert kwargs.get('json') is None
740+
else:
741+
assert kwargs.get('json') == json_format.MessageToDict(request_obj)
742+
736743
# url is 3rd positional argument in aconnect_sse(client, method, url, ...)
737744
assert args[2] == f'http://agent.example.com/api{expected_path}'

tests/compat/v0_3/test_rest_handler.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,44 @@ async def mock_stream(*args, **kwargs):
186186
]
187187

188188

189+
@pytest.mark.anyio
190+
async def test_on_subscribe_to_task_post(
191+
rest_handler, mock_request, mock_context
192+
):
193+
mock_request.path_params = {'id': 'task-1'}
194+
mock_request.method = 'POST'
195+
request_body = {'name': 'tasks/task-1'}
196+
mock_request.body = AsyncMock(
197+
return_value=json.dumps(request_body).encode('utf-8')
198+
)
199+
200+
async def mock_stream(*args, **kwargs):
201+
yield types_v03.SendStreamingMessageSuccessResponse(
202+
id='req-1',
203+
result=types_v03.Message(
204+
message_id='msg-2',
205+
role='agent',
206+
parts=[types_v03.TextPart(text='Update')],
207+
),
208+
)
209+
210+
rest_handler.handler03.on_subscribe_to_task = MagicMock(
211+
side_effect=mock_stream
212+
)
213+
214+
results = [
215+
chunk
216+
async for chunk in rest_handler.on_subscribe_to_task(
217+
mock_request, mock_context
218+
)
219+
]
220+
221+
assert len(results) == 1
222+
rest_handler.handler03.on_subscribe_to_task.assert_called_once()
223+
called_req = rest_handler.handler03.on_subscribe_to_task.call_args[0][0]
224+
assert called_req.params.id == 'task-1'
225+
226+
189227
@pytest.mark.anyio
190228
async def test_get_push_notification(rest_handler, mock_request, mock_context):
191229
mock_request.path_params = {'id': 'task-1', 'push_id': 'push-1'}

0 commit comments

Comments
 (0)