|
3 | 3 |
|
4 | 4 | import httpx |
5 | 5 | import pytest |
| 6 | +import respx |
6 | 7 |
|
| 8 | +from google.protobuf.json_format import MessageToJson |
7 | 9 | from httpx_sse import EventSource, ServerSentEvent |
8 | 10 |
|
9 | 11 | from a2a.client import create_text_message_object |
10 | 12 | from a2a.client.errors import A2AClientHTTPError |
11 | 13 | from a2a.client.transports.rest import RestTransport |
12 | 14 | from a2a.extensions.common import HTTP_EXTENSION_HEADER |
| 15 | +from a2a.grpc import a2a_pb2 |
13 | 16 | from a2a.types import ( |
14 | 17 | AgentCapabilities, |
15 | 18 | AgentCard, |
16 | 19 | MessageSendParams, |
| 20 | + Role, |
17 | 21 | ) |
| 22 | +from a2a.utils import proto_utils |
18 | 23 |
|
19 | 24 |
|
20 | 25 | @pytest.fixture |
@@ -88,6 +93,64 @@ async def test_send_message_with_default_extensions( |
88 | 93 | }, |
89 | 94 | ) |
90 | 95 |
|
| 96 | + # Repro of https://github.com/a2aproject/a2a-python/issues/540 |
| 97 | + @pytest.mark.asyncio |
| 98 | + @respx.mock |
| 99 | + async def test_send_message_streaming_comment_success( |
| 100 | + self, |
| 101 | + mock_agent_card: MagicMock, |
| 102 | + ): |
| 103 | + """Test that SSE comments are ignored.""" |
| 104 | + async with httpx.AsyncClient() as client: |
| 105 | + transport = RestTransport( |
| 106 | + httpx_client=client, agent_card=mock_agent_card |
| 107 | + ) |
| 108 | + params = MessageSendParams( |
| 109 | + message=create_text_message_object(content='Hello stream') |
| 110 | + ) |
| 111 | + |
| 112 | + mock_stream_response_1 = a2a_pb2.StreamResponse( |
| 113 | + msg=proto_utils.ToProto.message( |
| 114 | + create_text_message_object( |
| 115 | + content='First part', role=Role.agent |
| 116 | + ) |
| 117 | + ) |
| 118 | + ) |
| 119 | + mock_stream_response_2 = a2a_pb2.StreamResponse( |
| 120 | + msg=proto_utils.ToProto.message( |
| 121 | + create_text_message_object( |
| 122 | + content='Second part', role=Role.agent |
| 123 | + ) |
| 124 | + ) |
| 125 | + ) |
| 126 | + |
| 127 | + sse_content = ( |
| 128 | + 'id: stream_id_1\n' |
| 129 | + f'data: {MessageToJson(mock_stream_response_1, indent=None)}\n\n' |
| 130 | + ': keep-alive\n\n' |
| 131 | + 'id: stream_id_2\n' |
| 132 | + f'data: {MessageToJson(mock_stream_response_2, indent=None)}\n\n' |
| 133 | + ': keep-alive\n\n' |
| 134 | + ) |
| 135 | + |
| 136 | + respx.post( |
| 137 | + f'{mock_agent_card.url.rstrip("/")}/v1/message:stream' |
| 138 | + ).mock( |
| 139 | + return_value=httpx.Response( |
| 140 | + 200, |
| 141 | + headers={'Content-Type': 'text/event-stream'}, |
| 142 | + content=sse_content, |
| 143 | + ) |
| 144 | + ) |
| 145 | + |
| 146 | + results = [] |
| 147 | + async for item in transport.send_message_streaming(request=params): |
| 148 | + results.append(item) |
| 149 | + |
| 150 | + assert len(results) == 2 |
| 151 | + assert results[0].parts[0].root.text == 'First part' |
| 152 | + assert results[1].parts[0].root.text == 'Second part' |
| 153 | + |
91 | 154 | @pytest.mark.asyncio |
92 | 155 | @patch('a2a.client.transports.rest.aconnect_sse') |
93 | 156 | async def test_send_message_streaming_with_new_extensions( |
|
0 commit comments