Skip to content

Commit 08c491e

Browse files
authored
feat(server): add v0.3 legacy compatibility for database models (#783)
## Description Implements a mechanism to handle legacy v0.3 data stored in the database. When a `Task` or a `TaskPushNotificationConfig` does not have a `protocol_version` set to `0.1`, it validates and converts the data using v0.3 Pydantic models and conversion utilities. This ensures backward compatibility for existing records containing string-based enums and old field structures. ## Changes - The `status`, `artifacts`, and `history` fields in `TaskMixin` now use standard SQLAlchemy JSON columns with explicit Python type hints (Mapped[Any], Mapped[list[Any]]). - Removed `PydanticType` and `PydanticListType`: These custom SQLAlchemy types are no longer needed as serialization is now handled at the Store level. - Updated `_to_orm` to use MessageToDict on the entire Task object - Updated `_from_orm`: - v1.0: Uses ParseDict - Legacy (v0.3): Uses Pydantic's model_validate to reconstruct the legacy Task tree before converting to core types. - Updated Tests: - Removed obsolete tests for the deleted Pydantic type - updated the Task Store integration tests to verify the new `0.3 type to 1.0 type` conversion logic ## Contributing Guide - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [x] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [x] Appropriate docs were updated (if necessary) Fixes #715 🦕
1 parent 45b3059 commit 08c491e

6 files changed

Lines changed: 296 additions & 207 deletions

File tree

src/a2a/server/models.py

Lines changed: 7 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime
2-
from typing import TYPE_CHECKING, Any, Generic, TypeVar
2+
from typing import TYPE_CHECKING, Any
33

44

55
if TYPE_CHECKING:
@@ -11,24 +11,14 @@ def override(func): # noqa: ANN001, ANN201
1111
return func
1212

1313

14-
from google.protobuf.json_format import MessageToDict, ParseDict
15-
from google.protobuf.message import Message as ProtoMessage
16-
from pydantic import BaseModel
17-
18-
from a2a.types.a2a_pb2 import Artifact, Message, TaskStatus
19-
20-
2114
try:
22-
from sqlalchemy import JSON, DateTime, Dialect, Index, LargeBinary, String
15+
from sqlalchemy import JSON, DateTime, Index, LargeBinary, String
2316
from sqlalchemy.orm import (
2417
DeclarativeBase,
2518
Mapped,
2619
declared_attr,
2720
mapped_column,
2821
)
29-
from sqlalchemy.types import (
30-
TypeDecorator,
31-
)
3222
except ImportError as e:
3323
raise ImportError(
3424
'Database models require SQLAlchemy. '
@@ -40,101 +30,6 @@ def override(func): # noqa: ANN001, ANN201
4030
) from e
4131

4232

43-
T = TypeVar('T')
44-
45-
46-
class PydanticType(TypeDecorator[T], Generic[T]):
47-
"""SQLAlchemy type that handles Pydantic model and Protobuf message serialization."""
48-
49-
impl = JSON
50-
cache_ok = True
51-
52-
def __init__(self, pydantic_type: type[T], **kwargs: dict[str, Any]):
53-
"""Initialize the PydanticType.
54-
55-
Args:
56-
pydantic_type: The Pydantic model or Protobuf message type to handle.
57-
**kwargs: Additional arguments for TypeDecorator.
58-
"""
59-
self.pydantic_type = pydantic_type
60-
super().__init__(**kwargs)
61-
62-
def process_bind_param(
63-
self, value: T | None, dialect: Dialect
64-
) -> dict[str, Any] | None:
65-
"""Convert Pydantic model or Protobuf message to a JSON-serializable dictionary for the database."""
66-
if value is None:
67-
return None
68-
if isinstance(value, ProtoMessage):
69-
return MessageToDict(value, preserving_proto_field_name=False)
70-
if isinstance(value, BaseModel):
71-
return value.model_dump(mode='json')
72-
return value # type: ignore[return-value]
73-
74-
def process_result_value(
75-
self, value: dict[str, Any] | None, dialect: Dialect
76-
) -> T | None:
77-
"""Convert a JSON-like dictionary from the database back to a Pydantic model or Protobuf message."""
78-
if value is None:
79-
return None
80-
# Check if it's a protobuf message class
81-
if isinstance(self.pydantic_type, type) and issubclass(
82-
self.pydantic_type, ProtoMessage
83-
):
84-
return ParseDict(value, self.pydantic_type()) # type: ignore[return-value]
85-
# Assume it's a Pydantic model
86-
return self.pydantic_type.model_validate(value) # type: ignore[attr-defined]
87-
88-
89-
class PydanticListType(TypeDecorator, Generic[T]):
90-
"""SQLAlchemy type that handles lists of Pydantic models or Protobuf messages."""
91-
92-
impl = JSON
93-
cache_ok = True
94-
95-
def __init__(self, pydantic_type: type[T], **kwargs: dict[str, Any]):
96-
"""Initialize the PydanticListType.
97-
98-
Args:
99-
pydantic_type: The Pydantic model or Protobuf message type for items in the list.
100-
**kwargs: Additional arguments for TypeDecorator.
101-
"""
102-
self.pydantic_type = pydantic_type
103-
super().__init__(**kwargs)
104-
105-
def process_bind_param(
106-
self, value: list[T] | None, dialect: Dialect
107-
) -> list[dict[str, Any]] | None:
108-
"""Convert a list of Pydantic models or Protobuf messages to a JSON-serializable list for the DB."""
109-
if value is None:
110-
return None
111-
result: list[dict[str, Any]] = []
112-
for item in value:
113-
if isinstance(item, ProtoMessage):
114-
result.append(
115-
MessageToDict(item, preserving_proto_field_name=False)
116-
)
117-
elif isinstance(item, BaseModel):
118-
result.append(item.model_dump(mode='json'))
119-
else:
120-
result.append(item) # type: ignore[arg-type]
121-
return result
122-
123-
def process_result_value(
124-
self, value: list[dict[str, Any]] | None, dialect: Dialect
125-
) -> list[T] | None:
126-
"""Convert a JSON-like list from the DB back to a list of Pydantic models or Protobuf messages."""
127-
if value is None:
128-
return None
129-
# Check if it's a protobuf message class
130-
if isinstance(self.pydantic_type, type) and issubclass(
131-
self.pydantic_type, ProtoMessage
132-
):
133-
return [ParseDict(item, self.pydantic_type()) for item in value] # type: ignore[misc]
134-
# Assume it's a Pydantic model
135-
return [self.pydantic_type.model_validate(item) for item in value] # type: ignore[attr-defined]
136-
137-
13833
# Base class for all database models
13934
class Base(DeclarativeBase):
14035
"""Base class for declarative models in A2A SDK."""
@@ -153,14 +48,12 @@ class TaskMixin:
15348
last_updated: Mapped[datetime | None] = mapped_column(
15449
DateTime, nullable=True
15550
)
156-
157-
# Properly typed Pydantic fields with automatic serialization
158-
status: Mapped[TaskStatus] = mapped_column(PydanticType(TaskStatus))
159-
artifacts: Mapped[list[Artifact] | None] = mapped_column(
160-
PydanticListType(Artifact), nullable=True
51+
status: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
52+
artifacts: Mapped[list[dict[str, Any]] | None] = mapped_column(
53+
JSON, nullable=True
16154
)
162-
history: Mapped[list[Message] | None] = mapped_column(
163-
PydanticListType(Message), nullable=True
55+
history: Mapped[list[dict[str, Any]] | None] = mapped_column(
56+
JSON, nullable=True
16457
)
16558
protocol_version: Mapped[str | None] = mapped_column(
16659
String(16), nullable=True

src/a2a/server/tasks/database_push_notification_config_store.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# ruff: noqa: PLC0415
2-
import json
32
import logging
43

54
from typing import TYPE_CHECKING
@@ -27,6 +26,8 @@
2726
"or 'pip install a2a-sdk[sql]'"
2827
) from e
2928

29+
from a2a.compat.v0_3 import conversions
30+
from a2a.compat.v0_3 import types as types_v03
3031
from a2a.server.context import ServerCallContext
3132
from a2a.server.models import (
3233
Base,
@@ -163,6 +164,7 @@ def _to_orm(
163164
config_id=config.id,
164165
owner=owner,
165166
config_data=data_to_store,
167+
protocol_version='1.0',
166168
)
167169

168170
def _from_orm(
@@ -181,11 +183,11 @@ def _from_orm(
181183

182184
try:
183185
decrypted_payload = self._fernet.decrypt(payload)
184-
return Parse(
186+
return self._parse_config(
185187
decrypted_payload.decode('utf-8'),
186-
TaskPushNotificationConfig(),
188+
model_instance.protocol_version,
187189
)
188-
except (json.JSONDecodeError, Exception) as e:
190+
except Exception as e:
189191
if isinstance(e, InvalidToken):
190192
# Decryption failed. This could be because the data is not encrypted.
191193
# We'll log a warning and try to parse it as plain JSON as a fallback.
@@ -215,7 +217,10 @@ def _from_orm(
215217
if isinstance(payload, bytes)
216218
else payload
217219
)
218-
return Parse(payload_str, TaskPushNotificationConfig())
220+
return self._parse_config(
221+
payload_str, model_instance.protocol_version
222+
)
223+
219224
except Exception as e:
220225
if self._fernet:
221226
logger.exception(
@@ -334,3 +339,22 @@ async def delete_info(
334339
owner,
335340
config_id,
336341
)
342+
343+
def _parse_config(
344+
self, json_payload: str, protocol_version: str | None = None
345+
) -> TaskPushNotificationConfig:
346+
"""Parses a JSON payload into a TaskPushNotificationConfig proto.
347+
348+
Uses protocol_version to decide between modern parsing and legacy conversion.
349+
"""
350+
if protocol_version == '1.0':
351+
return Parse(json_payload, TaskPushNotificationConfig())
352+
353+
legacy_instance = (
354+
types_v03.TaskPushNotificationConfig.model_validate_json(
355+
json_payload
356+
)
357+
)
358+
return conversions.to_core_task_push_notification_config(
359+
legacy_instance
360+
)

src/a2a/server/tasks/database_task_store.py

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@
3131
"or 'pip install a2a-sdk[sql]'"
3232
) from e
3333

34-
from google.protobuf.json_format import MessageToDict
34+
from google.protobuf.json_format import MessageToDict, ParseDict
3535

36+
from a2a.compat.v0_3 import conversions
37+
from a2a.compat.v0_3 import types as types_v03
3638
from a2a.server.context import ServerCallContext
3739
from a2a.server.models import Base, TaskModel, create_task_model
3840
from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope
@@ -117,45 +119,78 @@ async def _ensure_initialized(self) -> None:
117119

118120
def _to_orm(self, task: Task, owner: str) -> TaskModel:
119121
"""Maps a Proto Task to a SQLAlchemy TaskModel instance."""
120-
# Pass proto objects directly - PydanticType/PydanticListType
121-
# handle serialization via process_bind_param
122122
return self.task_model(
123123
id=task.id,
124124
context_id=task.context_id,
125125
kind='task', # Default kind for tasks
126126
owner=owner,
127127
last_updated=(
128128
task.status.timestamp.ToDatetime()
129-
if task.HasField('status') and task.status.HasField('timestamp')
129+
if task.status.HasField('timestamp')
130130
else None
131131
),
132-
status=task.status if task.HasField('status') else None,
133-
artifacts=list(task.artifacts) if task.artifacts else [],
134-
history=list(task.history) if task.history else [],
132+
status=MessageToDict(task.status),
133+
artifacts=[MessageToDict(artifact) for artifact in task.artifacts],
134+
history=[MessageToDict(history) for history in task.history],
135135
task_metadata=(
136136
MessageToDict(task.metadata) if task.metadata.fields else None
137137
),
138+
protocol_version='1.0',
138139
)
139140

140141
def _from_orm(self, task_model: TaskModel) -> Task:
141142
"""Maps a SQLAlchemy TaskModel to a Proto Task instance."""
142-
# PydanticType/PydanticListType already deserialize to proto objects
143-
# via process_result_value, so we can construct the Task directly
144-
task = Task(
143+
if task_model.protocol_version == '1.0':
144+
task = Task(
145+
id=task_model.id,
146+
context_id=task_model.context_id,
147+
)
148+
if task_model.status:
149+
ParseDict(
150+
cast('dict[str, Any]', task_model.status), task.status
151+
)
152+
if task_model.artifacts:
153+
for art_dict in cast(
154+
'list[dict[str, Any]]', task_model.artifacts
155+
):
156+
art = task.artifacts.add()
157+
ParseDict(art_dict, art)
158+
if task_model.history:
159+
for msg_dict in cast(
160+
'list[dict[str, Any]]', task_model.history
161+
):
162+
msg = task.history.add()
163+
ParseDict(msg_dict, msg)
164+
if task_model.task_metadata:
165+
task.metadata.update(
166+
cast('dict[str, Any]', task_model.task_metadata)
167+
)
168+
return task
169+
170+
# Legacy conversion
171+
legacy_task = types_v03.Task(
145172
id=task_model.id,
146173
context_id=task_model.context_id,
174+
status=types_v03.TaskStatus.model_validate(task_model.status),
175+
artifacts=(
176+
[
177+
types_v03.Artifact.model_validate(a)
178+
for a in task_model.artifacts
179+
]
180+
if task_model.artifacts
181+
else []
182+
),
183+
history=(
184+
[
185+
types_v03.Message.model_validate(m)
186+
for m in task_model.history
187+
]
188+
if task_model.history
189+
else []
190+
),
191+
metadata=task_model.task_metadata or {},
147192
)
148-
if task_model.status:
149-
task.status.CopyFrom(task_model.status)
150-
if task_model.artifacts:
151-
task.artifacts.extend(task_model.artifacts)
152-
if task_model.history:
153-
task.history.extend(task_model.history)
154-
if task_model.task_metadata:
155-
task.metadata.update(
156-
cast('dict[str, Any]', task_model.task_metadata)
157-
)
158-
return task
193+
return conversions.to_core_task(legacy_task)
159194

160195
async def save(
161196
self, task: Task, context: ServerCallContext | None = None

0 commit comments

Comments
 (0)