Skip to content

Commit 4cb68aa

Browse files
authored
feat(compat): set a2a-version header to 1.0.0 (#764)
# Description Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [X] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [X] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [X] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [X] Appropriate docs were updated (if necessary)
1 parent 81f3494 commit 4cb68aa

5 files changed

Lines changed: 78 additions & 24 deletions

File tree

src/a2a/client/client_factory.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44

55
from collections.abc import Callable
6-
from typing import Any
6+
from typing import Any, cast
77

88
import httpx
99

@@ -20,6 +20,8 @@
2020
AgentInterface,
2121
)
2222
from a2a.utils.constants import (
23+
PROTOCOL_VERSION_CURRENT,
24+
VERSION_HEADER,
2325
TransportProtocol,
2426
)
2527

@@ -65,18 +67,24 @@ def __init__(
6567
):
6668
if consumers is None:
6769
consumers = []
70+
71+
client = config.httpx_client or httpx.AsyncClient()
72+
client.headers.setdefault(VERSION_HEADER, PROTOCOL_VERSION_CURRENT)
73+
config.httpx_client = client
74+
6875
self._config = config
6976
self._consumers = consumers
7077
self._registry: dict[str, TransportProducer] = {}
7178
self._register_defaults(config.supported_protocol_bindings)
7279

7380
def _register_defaults(self, supported: list[str]) -> None:
7481
# Empty support list implies JSON-RPC only.
82+
7583
if TransportProtocol.JSONRPC in supported or not supported:
7684
self.register(
7785
TransportProtocol.JSONRPC,
7886
lambda card, url, config, interceptors: JsonRpcTransport(
79-
config.httpx_client or httpx.AsyncClient(),
87+
cast('httpx.AsyncClient', config.httpx_client),
8088
card,
8189
url,
8290
interceptors,
@@ -87,7 +95,7 @@ def _register_defaults(self, supported: list[str]) -> None:
8795
self.register(
8896
TransportProtocol.HTTP_JSON,
8997
lambda card, url, config, interceptors: RestTransport(
90-
config.httpx_client or httpx.AsyncClient(),
98+
cast('httpx.AsyncClient', config.httpx_client),
9199
card,
92100
url,
93101
interceptors,

src/a2a/client/transports/grpc.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
Task,
4444
TaskPushNotificationConfig,
4545
)
46+
from a2a.utils.constants import PROTOCOL_VERSION_CURRENT, VERSION_HEADER
4647
from a2a.utils.telemetry import SpanKind, trace_class
4748

4849

@@ -303,11 +304,14 @@ async def close(self) -> None:
303304
def _get_grpc_metadata(
304305
self,
305306
extensions: list[str] | None = None,
306-
) -> list[tuple[str, str]] | None:
307+
) -> list[tuple[str, str]]:
307308
"""Creates gRPC metadata for extensions."""
309+
metadata = [(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT)]
310+
308311
extensions_to_use = extensions or self.extensions
309312
if extensions_to_use:
310-
return [
313+
metadata.append(
311314
(HTTP_EXTENSION_HEADER.lower(), ','.join(extensions_to_use))
312-
]
313-
return None
315+
)
316+
317+
return metadata

src/a2a/utils/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,7 @@ class TransportProtocol(str, Enum):
2222

2323
DEFAULT_MAX_CONTENT_LENGTH = 10 * 1024 * 1024 # 10MB
2424
JSONRPC_PARSE_ERROR_CODE = -32700
25+
VERSION_HEADER = 'A2A-Version'
26+
27+
PROTOCOL_VERSION_1_0 = '1.0'
28+
PROTOCOL_VERSION_CURRENT = PROTOCOL_VERSION_1_0

tests/client/transports/test_grpc_client.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from a2a.client.transports.grpc import GrpcTransport
77
from a2a.extensions.common import HTTP_EXTENSION_HEADER
8+
from a2a.utils.constants import VERSION_HEADER, PROTOCOL_VERSION_CURRENT
89
from a2a.types import a2a_pb2
910
from a2a.types.a2a_pb2 import (
1011
AgentCapabilities,
@@ -217,10 +218,11 @@ async def test_send_message_task_response(
217218
mock_grpc_stub.SendMessage.assert_awaited_once()
218219
_, kwargs = mock_grpc_stub.SendMessage.call_args
219220
assert kwargs['metadata'] == [
221+
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
220222
(
221223
HTTP_EXTENSION_HEADER.lower(),
222224
'https://example.com/test-ext/v3',
223-
)
225+
),
224226
]
225227
assert response.HasField('task')
226228
assert response.task.id == sample_task.id
@@ -266,10 +268,11 @@ async def test_send_message_message_response(
266268
mock_grpc_stub.SendMessage.assert_awaited_once()
267269
_, kwargs = mock_grpc_stub.SendMessage.call_args
268270
assert kwargs['metadata'] == [
271+
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
269272
(
270273
HTTP_EXTENSION_HEADER.lower(),
271274
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
272-
)
275+
),
273276
]
274277
assert response.HasField('message')
275278
assert response.message.message_id == sample_message.message_id
@@ -315,10 +318,11 @@ async def test_send_message_streaming( # noqa: PLR0913
315318
mock_grpc_stub.SendStreamingMessage.assert_called_once()
316319
_, kwargs = mock_grpc_stub.SendStreamingMessage.call_args
317320
assert kwargs['metadata'] == [
321+
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
318322
(
319323
HTTP_EXTENSION_HEADER.lower(),
320324
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
321-
)
325+
),
322326
]
323327
# Responses are StreamResponse proto objects
324328
assert responses[0].HasField('message')
@@ -350,10 +354,11 @@ async def test_get_task(
350354
mock_grpc_stub.GetTask.assert_awaited_once_with(
351355
a2a_pb2.GetTaskRequest(id=f'{sample_task.id}', history_length=None),
352356
metadata=[
357+
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
353358
(
354359
HTTP_EXTENSION_HEADER.lower(),
355360
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
356-
)
361+
),
357362
],
358363
)
359364
assert response.id == sample_task.id
@@ -378,10 +383,11 @@ async def test_list_tasks(
378383
mock_grpc_stub.ListTasks.assert_awaited_once_with(
379384
params,
380385
metadata=[
386+
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
381387
(
382388
HTTP_EXTENSION_HEADER.lower(),
383389
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
384-
)
390+
),
385391
],
386392
)
387393
assert result.total_size == 2
@@ -405,10 +411,11 @@ async def test_get_task_with_history(
405411
id=f'{sample_task.id}', history_length=history_len
406412
),
407413
metadata=[
414+
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
408415
(
409416
HTTP_EXTENSION_HEADER.lower(),
410417
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
411-
)
418+
),
412419
],
413420
)
414421

@@ -433,7 +440,8 @@ async def test_cancel_task(
433440
mock_grpc_stub.CancelTask.assert_awaited_once_with(
434441
a2a_pb2.CancelTaskRequest(id=f'{sample_task.id}'),
435442
metadata=[
436-
(HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v3')
443+
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
444+
(HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v3'),
437445
],
438446
)
439447
assert response.status.state == TaskState.TASK_STATE_CANCELED
@@ -462,10 +470,11 @@ async def test_create_task_push_notification_config_with_valid_task(
462470
mock_grpc_stub.CreateTaskPushNotificationConfig.assert_awaited_once_with(
463471
request,
464472
metadata=[
473+
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
465474
(
466475
HTTP_EXTENSION_HEADER.lower(),
467476
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
468-
)
477+
),
469478
],
470479
)
471480
assert response.task_id == sample_task_push_notification_config.task_id
@@ -524,10 +533,11 @@ async def test_get_task_push_notification_config_with_valid_task(
524533
id=config_id,
525534
),
526535
metadata=[
536+
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
527537
(
528538
HTTP_EXTENSION_HEADER.lower(),
529539
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
530-
)
540+
),
531541
],
532542
)
533543
assert response.task_id == sample_task_push_notification_config.task_id
@@ -577,10 +587,11 @@ async def test_list_task_push_notification_configs(
577587
mock_grpc_stub.ListTaskPushNotificationConfigs.assert_awaited_once_with(
578588
a2a_pb2.ListTaskPushNotificationConfigsRequest(task_id='task-1'),
579589
metadata=[
590+
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
580591
(
581592
HTTP_EXTENSION_HEADER.lower(),
582593
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
583-
)
594+
),
584595
],
585596
)
586597
assert len(response.configs) == 1
@@ -609,10 +620,11 @@ async def test_delete_task_push_notification_config(
609620
id='config-1',
610621
),
611622
metadata=[
623+
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
612624
(
613625
HTTP_EXTENSION_HEADER.lower(),
614626
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
615-
)
627+
),
616628
],
617629
)
618630

@@ -623,32 +635,47 @@ async def test_delete_task_push_notification_config(
623635
(
624636
None,
625637
None,
626-
None,
638+
[(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT)],
627639
), # Case 1: No initial, No input
628640
(
629641
['ext1'],
630642
None,
631-
[(HTTP_EXTENSION_HEADER.lower(), 'ext1')],
643+
[
644+
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
645+
(HTTP_EXTENSION_HEADER.lower(), 'ext1'),
646+
],
632647
), # Case 2: Initial, No input
633648
(
634649
None,
635650
['ext2'],
636-
[(HTTP_EXTENSION_HEADER.lower(), 'ext2')],
651+
[
652+
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
653+
(HTTP_EXTENSION_HEADER.lower(), 'ext2'),
654+
],
637655
), # Case 3: No initial, Input
638656
(
639657
['ext1'],
640658
['ext2'],
641-
[(HTTP_EXTENSION_HEADER.lower(), 'ext2')],
659+
[
660+
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
661+
(HTTP_EXTENSION_HEADER.lower(), 'ext2'),
662+
],
642663
), # Case 4: Initial, Input (override)
643664
(
644665
['ext1'],
645666
['ext2', 'ext3'],
646-
[(HTTP_EXTENSION_HEADER.lower(), 'ext2,ext3')],
667+
[
668+
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
669+
(HTTP_EXTENSION_HEADER.lower(), 'ext2,ext3'),
670+
],
647671
), # Case 5: Initial, Multiple inputs (override)
648672
(
649673
['ext1', 'ext2'],
650674
['ext3'],
651-
[(HTTP_EXTENSION_HEADER.lower(), 'ext3')],
675+
[
676+
(VERSION_HEADER.lower(), PROTOCOL_VERSION_CURRENT),
677+
(HTTP_EXTENSION_HEADER.lower(), 'ext3'),
678+
],
652679
), # Case 6: Multiple initial, Single input (override)
653680
],
654681
)

tests/utils/test_constants.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,14 @@ def test_agent_card_constants():
1313
def test_default_rpc_url():
1414
"""Test default RPC URL constant."""
1515
assert constants.DEFAULT_RPC_URL == '/'
16+
17+
18+
def test_version_header():
19+
"""Test version header constant."""
20+
assert constants.VERSION_HEADER == 'A2A-Version'
21+
22+
23+
def test_protocol_versions():
24+
"""Test protocol version constants."""
25+
assert constants.PROTOCOL_VERSION_1_0 == '1.0'
26+
assert constants.PROTOCOL_VERSION_CURRENT == '1.0'

0 commit comments

Comments
 (0)