Skip to content

Commit 5ba83c8

Browse files
committed
[ty] Support inheriting from functional TypedDict
1 parent df581ea commit 5ba83c8

4 files changed

Lines changed: 91 additions & 11 deletions

File tree

crates/ty_python_semantic/resources/mdtest/typed_dict.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2009,6 +2009,33 @@ emp_invalid1 = Employee(department="HR")
20092009
emp_invalid2 = Employee(id=3)
20102010
```
20112011

2012+
## Class-based inheritance from functional `TypedDict`
2013+
2014+
Class-based TypedDicts can inherit from functional TypedDicts:
2015+
2016+
```py
2017+
from typing import TypedDict
2018+
2019+
Base = TypedDict("Base", {"a": int}, total=False)
2020+
2021+
class Child(Base):
2022+
b: str
2023+
c: list[int]
2024+
2025+
child1 = Child(b="hello", c=[1, 2, 3])
2026+
child2 = Child(a=1, b="world", c=[])
2027+
2028+
reveal_type(child1["a"]) # revealed: int
2029+
reveal_type(child1["b"]) # revealed: str
2030+
reveal_type(child1["c"]) # revealed: list[int]
2031+
2032+
# error: [missing-typed-dict-key] "Missing required key 'b' in TypedDict `Child` constructor"
2033+
bad_child1 = Child(c=[1])
2034+
2035+
# error: [missing-typed-dict-key] "Missing required key 'c' in TypedDict `Child` constructor"
2036+
bad_child2 = Child(b="test")
2037+
```
2038+
20122039
## Generic `TypedDict`
20132040

20142041
`TypedDict`s can also be generic.
@@ -2611,6 +2638,9 @@ def f():
26112638

26122639
# fine
26132640
MyFunctionalTypedDict = TypedDict("MyFunctionalTypedDict", {"not-an-identifier": Required[int]})
2641+
2642+
class FunctionalTypedDictSubclass(MyFunctionalTypedDict):
2643+
y: NotRequired[int] # fine
26142644
```
26152645

26162646
### Nested `Required` and `NotRequired`
@@ -3650,6 +3680,18 @@ class Child(Base):
36503680
y: str
36513681
```
36523682

3683+
The functional `TypedDict` syntax also triggers this error:
3684+
3685+
```py
3686+
from dataclasses import dataclass
3687+
from typing import TypedDict
3688+
3689+
@dataclass
3690+
# error: [invalid-dataclass]
3691+
class Foo(TypedDict("Foo", {"x": int, "y": str})):
3692+
pass
3693+
```
3694+
36533695
## Class header validation
36543696

36553697
<!-- snapshot-diagnostics -->

crates/ty_python_semantic/src/types/class.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,11 @@ impl<'db> ClassType<'db> {
907907
self.is_known(db, KnownClass::Object)
908908
}
909909

910+
/// Return `true` if this class is a `TypedDict`.
911+
pub(crate) fn is_typed_dict(self, db: &'db dyn Db) -> bool {
912+
self.class_literal(db).is_typed_dict(db)
913+
}
914+
910915
pub(super) fn apply_type_mapping_impl<'a>(
911916
self,
912917
db: &'db dyn Db,

crates/ty_python_semantic/src/types/class/static_literal.rs

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ use crate::{
3434
call::{CallError, CallErrorKind},
3535
callable::CallableTypeKind,
3636
class::{
37-
ClassMemberResult, CodeGeneratorKind, DisjointBase, Field, FieldKind,
38-
InstanceMemberResult, MetaclassError, MetaclassErrorKind, MethodDecorator, MroLookup,
39-
NamedTupleField, SlotsKind, synthesize_namedtuple_class_member,
37+
ClassMemberResult, CodeGeneratorKind, DisjointBase, DynamicTypedDictLiteral, Field,
38+
FieldKind, InstanceMemberResult, MetaclassError, MetaclassErrorKind, MethodDecorator,
39+
MroLookup, NamedTupleField, SlotsKind, synthesize_namedtuple_class_member,
4040
},
4141
context::InferContext,
4242
declaration_type, definition_expression_type, determine_upper_bound,
@@ -1638,6 +1638,11 @@ impl<'db> StaticClassLiteral<'db> {
16381638
specialization: Option<Specialization<'db>>,
16391639
field_policy: CodeGeneratorKind<'db>,
16401640
) -> FxIndexMap<Name, Field<'db>> {
1641+
enum FieldSource<'db> {
1642+
Static(StaticClassLiteral<'db>, Option<Specialization<'db>>),
1643+
DynamicTypedDict(DynamicTypedDictLiteral<'db>),
1644+
}
1645+
16411646
if field_policy == CodeGeneratorKind::NamedTuple {
16421647
// NamedTuples do not allow multiple inheritance, so it is sufficient to enumerate the
16431648
// fields of this class only.
@@ -1648,15 +1653,43 @@ impl<'db> StaticClassLiteral<'db> {
16481653
.rev()
16491654
.filter_map(|superclass| {
16501655
let class = superclass.into_class()?;
1651-
// Dynamic classes don't have fields (no class body).
1652-
let (class_literal, specialization) = class.static_class_literal(db)?;
1653-
if field_policy.matches(db, class_literal.into(), specialization) {
1654-
Some((class_literal, specialization))
1655-
} else {
1656-
None
1656+
1657+
if let Some((class_literal, specialization)) = class.static_class_literal(db) {
1658+
if field_policy.matches(db, class_literal.into(), specialization) {
1659+
return Some(FieldSource::Static(class_literal, specialization));
1660+
}
1661+
}
1662+
1663+
if field_policy == CodeGeneratorKind::TypedDict
1664+
&& let ClassLiteral::DynamicTypedDict(typeddict) = class.class_literal(db)
1665+
{
1666+
return Some(FieldSource::DynamicTypedDict(typeddict));
16571667
}
1668+
1669+
None
1670+
})
1671+
.flat_map(|source| match source {
1672+
FieldSource::Static(class, specialization) => {
1673+
class.own_fields(db, specialization, field_policy)
1674+
}
1675+
FieldSource::DynamicTypedDict(typeddict) => typeddict
1676+
.items(db)
1677+
.iter()
1678+
.map(|(name, td_field)| {
1679+
(
1680+
name.clone(),
1681+
Field {
1682+
declared_ty: td_field.declared_ty,
1683+
kind: FieldKind::TypedDict {
1684+
is_required: td_field.is_required(),
1685+
is_read_only: td_field.is_read_only(),
1686+
},
1687+
first_declaration: td_field.first_declaration(),
1688+
},
1689+
)
1690+
})
1691+
.collect(),
16581692
})
1659-
.flat_map(|(class, specialization)| class.own_fields(db, specialization, field_policy))
16601693
// KW_ONLY sentinels are markers, not real fields. Exclude them so
16611694
// they cannot shadow an inherited field with the same name.
16621695
.filter(|(_, field)| !field.is_kw_only_sentinel(db))

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7374,7 +7374,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
73747374

73757375
// Validate `TypedDict` constructor calls after argument type inference.
73767376
if let Some(class) = class
7377-
&& class.class_literal(self.db()).is_typed_dict(self.db())
7377+
&& class.is_typed_dict(self.db())
73787378
{
73797379
validate_typed_dict_constructor(
73807380
&self.context,

0 commit comments

Comments
 (0)