Skip to content

Commit 7e472f6

Browse files
Thibault-Pelletierjourdain
authored andcommitted
fix(typed_state): fix support for data class in collections
- Fix encoding / decoding of dataclasses nested in lists or dicts
1 parent e35da90 commit 7e472f6

2 files changed

Lines changed: 104 additions & 0 deletions

File tree

tests/test_typed_state.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,3 +437,89 @@ def decode(self, _obj, _obj_type: type):
437437

438438
with pytest.raises(TypeError):
439439
print(typed_state.data.my_enum)
440+
441+
442+
@dataclass
443+
class SimpleTypes:
444+
my_int: int
445+
my_enum: MyEnum
446+
my_path: Path
447+
448+
449+
@dataclass
450+
class TypedComposite:
451+
simple_types: SimpleTypes = field(default_factory=SimpleTypes)
452+
453+
454+
@dataclass
455+
class DataclassCollections:
456+
nested_list: list[TypedComposite] = field(default_factory=list)
457+
nested_dict: dict[str, TypedComposite] = field(default_factory=dict)
458+
459+
460+
def test_encode_decode_supports_collections_of_nested_dataclass(state):
461+
typed_state = TypedState(state, DataclassCollections)
462+
typed_state.data.nested_list = [
463+
TypedComposite(
464+
SimpleTypes(my_int=1, my_enum=MyEnum.A, my_path=Path("/path/to/1"))
465+
),
466+
TypedComposite(
467+
SimpleTypes(my_int=2, my_enum=MyEnum.B, my_path=Path("/path/to/2"))
468+
),
469+
]
470+
471+
typed_state.data.nested_dict = {
472+
"3": TypedComposite(
473+
SimpleTypes(my_int=3, my_enum=MyEnum.C, my_path=Path("/path/to/3"))
474+
),
475+
"4": TypedComposite(
476+
SimpleTypes(my_int=4, my_enum=MyEnum.A, my_path=Path("/path/to/4"))
477+
),
478+
}
479+
480+
assert typed_state.data.nested_list[0] == TypedComposite(
481+
SimpleTypes(my_int=1, my_enum=MyEnum.A, my_path=Path("/path/to/1"))
482+
)
483+
assert typed_state.data.nested_list[1] == TypedComposite(
484+
SimpleTypes(my_int=2, my_enum=MyEnum.B, my_path=Path("/path/to/2"))
485+
)
486+
assert typed_state.data.nested_dict["3"] == TypedComposite(
487+
SimpleTypes(my_int=3, my_enum=MyEnum.C, my_path=Path("/path/to/3"))
488+
)
489+
assert typed_state.data.nested_dict["4"] == TypedComposite(
490+
SimpleTypes(my_int=4, my_enum=MyEnum.A, my_path=Path("/path/to/4"))
491+
)
492+
493+
assert state[typed_state.name.nested_list] == [
494+
{
495+
"simple_types": {
496+
"my_int": 1,
497+
"my_enum": typed_state.encode(MyEnum.A),
498+
"my_path": typed_state.encode("/path/to/1"),
499+
}
500+
},
501+
{
502+
"simple_types": {
503+
"my_int": 2,
504+
"my_enum": typed_state.encode(MyEnum.B),
505+
"my_path": typed_state.encode("/path/to/2"),
506+
}
507+
},
508+
]
509+
510+
assert state[typed_state.name.nested_dict] == {
511+
"3": {
512+
"simple_types": {
513+
"my_int": 3,
514+
"my_enum": typed_state.encode(MyEnum.C),
515+
"my_path": typed_state.encode("/path/to/3"),
516+
}
517+
},
518+
"4": {
519+
"simple_types": {
520+
"my_int": 4,
521+
"my_enum": typed_state.encode(MyEnum.A),
522+
"my_path": typed_state.encode("/path/to/4"),
523+
}
524+
},
525+
}

trame_server/utils/typed_state.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,12 @@ def __init__(self, encoders: list[IStateEncoderDecoder] | None = None):
115115
self._encoders = encoders or [DefaultEncoderDecoder()]
116116

117117
def encode(self, obj):
118+
if is_dataclass(obj):
119+
return {
120+
field.name: self.encode(getattr(obj, field.name))
121+
for field in fields(obj)
122+
}
123+
118124
if isinstance(obj, dict):
119125
return {self.encode(key): self.encode(value) for key, value in obj.items()}
120126

@@ -156,6 +162,7 @@ def _try_decode(self, obj, obj_type: type):
156162

157163
def _decode_strategies(self) -> list[Callable[[Any, type], Any]]:
158164
return [
165+
self._decode_dataclass,
159166
self._decode_union,
160167
self._decode_dict,
161168
self._decode_iterable,
@@ -196,6 +203,17 @@ def _decode_union(self, obj, obj_type: type):
196203
return val
197204
return self.failed_serialization()
198205

206+
def _decode_dataclass(self, obj, obj_type: type):
207+
if not is_dataclass(obj_type):
208+
return self.failed_serialization()
209+
210+
field_types = get_type_hints(obj_type)
211+
decoded_dict = {
212+
field.name: self._try_decode(obj.get(field.name), field_types[field.name])
213+
for field in fields(obj_type)
214+
}
215+
return obj_type(**decoded_dict)
216+
199217
@classmethod
200218
def _is_union_type(cls, obj_type: type):
201219
return get_origin(obj_type) is Union or isinstance(obj_type, UnionType)

0 commit comments

Comments
 (0)