Skip to content

Commit c6869a1

Browse files
authored
Merge pull request #1568 from google/google_sync
Google sync
2 parents 2e30cee + bb0736a commit c6869a1

9 files changed

Lines changed: 124 additions & 46 deletions

File tree

pytype/matcher.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -456,12 +456,17 @@ def match_var_against_type(self, var, other_type, subst, view):
456456

457457
def _match_type_param_against_type_param(self, t1, t2, subst, view):
458458
"""Match a TypeVar against another TypeVar."""
459-
if t1.full_name == "typing.Self" and not t1.bound:
460-
# We're matching a Self instance before it's been bound to its containing
461-
# class. We know it should be bound but not to what, so `Any` is the best
462-
# we can do.
463-
t1 = t1.copy()
464-
t1.bound = self.ctx.convert.unsolvable
459+
if t1.full_name == "typing.Self":
460+
if t2.full_name == "typing.Self":
461+
# Self always matches itself. We check for this explicitly because Self
462+
# instances may have their bounds set to incompatible classes.
463+
return subst
464+
elif not t1.bound:
465+
# We're matching a Self instance before it's been bound to its
466+
# containing class. We know it should be bound but not to what, so `Any`
467+
# is the best we can do.
468+
t1 = t1.copy()
469+
t1.bound = self.ctx.convert.unsolvable
465470
if t2.constraints:
466471
assert not t2.bound # constraints and bounds are mutually exclusive
467472
# We only check the constraints for t1, not the bound. We wouldn't know
@@ -1461,7 +1466,7 @@ def _match_dict_against_typed_dict(
14611466
for k, v in left.pyval.items():
14621467
if k not in fields:
14631468
continue
1464-
typ = abstract_utils.get_atomic_value(fields[k])
1469+
typ = fields[k]
14651470
match_result = self.compute_one_match(v, typ)
14661471
if not match_result.success:
14671472
bad.append((k, match_result.bad_matches))

pytype/output.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,10 +1016,8 @@ def _typed_dict_to_def(self, node, v, name):
10161016
keywords.append(("total", pytd.Literal(False)))
10171017
bases = (pytd.NamedType("typing.TypedDict"),)
10181018
constants = []
1019-
for k, var in v.props.fields.items():
1020-
typ = pytd_utils.JoinTypes(
1021-
self.value_instance_to_pytd_type(node, p, None, set(), {})
1022-
for p in var.data)
1019+
for k, val in v.props.fields.items():
1020+
typ = self.value_instance_to_pytd_type(node, val, None, set(), {})
10231021
if v.props.total and k not in v.props.required:
10241022
typ = pytd.GenericType(pytd.NamedType("typing.NotRequired"), (typ,))
10251023
elif not v.props.total and k in v.props.required:

pytype/overlays/dataclass_overlay.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def decorate(self, node, cls):
9797
continue
9898
kind = ""
9999
init = True
100-
kw_only = False
100+
kw_only = sticky_kwonly
101101
assert typ
102102
if match_classvar(typ):
103103
continue
@@ -112,8 +112,8 @@ def decorate(self, node, cls):
112112
field = orig.data[0]
113113
orig = field.default
114114
init = field.init
115-
if self.ctx.python_version >= (3, 10):
116-
kw_only = sticky_kwonly if field.kw_only is None else field.kw_only
115+
if field.kw_only is not None:
116+
kw_only = field.kw_only
117117

118118
if orig and orig.data == [self.ctx.convert.none]:
119119
# vm._apply_annotation mostly takes care of checking that the default

pytype/overlays/typed_dict.py

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import dataclasses
44

5-
from typing import Any, Dict, Optional, Set
5+
from typing import Dict, Optional, Set
66

77
from pytype.abstract import abstract
88
from pytype.abstract import abstract_utils
@@ -34,7 +34,7 @@ class TypedDictProperties:
3434
"""Collection of typed dict properties passed between various stages."""
3535

3636
name: str
37-
fields: Dict[str, Any]
37+
fields: Dict[str, abstract.BaseValue]
3838
required: Set[str]
3939
total: bool
4040

@@ -48,27 +48,15 @@ def optional(self):
4848

4949
def add(self, k, v, total):
5050
"""Adds key and value."""
51-
values = []
52-
all_requiredness = set()
53-
for value in v.data:
54-
req = _is_required(value)
55-
if req is None:
56-
values.append(value)
57-
all_requiredness.add(None)
58-
elif isinstance(value, abstract.ParameterizedClass):
59-
values.append(value.formal_type_parameters[abstract_utils.T])
60-
all_requiredness.add(req)
61-
else:
62-
values.append(value.ctx.convert.unsolvable)
63-
all_requiredness.add(req)
64-
if (len(all_requiredness) == 1 and
65-
(requiredness := next(iter(all_requiredness))) is not None):
66-
final_v = v.program.NewVariable(values, [], v.program.entrypoint)
67-
required = requiredness
51+
req = _is_required(v)
52+
if req is None:
53+
value = v
54+
elif isinstance(v, abstract.ParameterizedClass):
55+
value = v.formal_type_parameters[abstract_utils.T]
6856
else:
69-
final_v = v
70-
required = total
71-
self.fields[k] = final_v # pylint: disable=unsupported-assignment-operation
57+
value = v.ctx.convert.unsolvable
58+
required = total if req is None else req
59+
self.fields[k] = value # pylint: disable=unsupported-assignment-operation
7260
if required:
7361
self.required.add(k)
7462

@@ -122,7 +110,12 @@ def _extract_args(self, args):
122110
name=name, fields={}, required=set(), total=total)
123111
# Force Required/NotRequired evaluation
124112
for k, v in fields.items():
125-
props.add(k, v, total)
113+
try:
114+
value = abstract_utils.get_atomic_value(v)
115+
except abstract_utils.ConversionError:
116+
self.ctx.errorlog.ambiguous_annotation(self.ctx.vm.frames, v.data, k)
117+
value = self.ctx.convert.unsolvable
118+
props.add(k, value, total)
126119
return props
127120

128121
def _validate_bases(self, cls_name, bases):
@@ -182,8 +175,14 @@ def make_class(self, node, bases, f_locals, total):
182175
ordering=classgen.Ordering.FIRST_ANNOTATE,
183176
ctx=self.ctx)
184177
for k, local in cls_locals.items():
185-
assert local.typ
186-
props.add(k, local.typ, total)
178+
var = local.typ
179+
assert var
180+
try:
181+
typ = abstract_utils.get_atomic_value(var)
182+
except abstract_utils.ConversionError:
183+
self.ctx.errorlog.ambiguous_annotation(self.ctx.vm.frames, var.data, k)
184+
typ = self.ctx.convert.unsolvable
185+
props.add(k, typ, total)
187186

188187
# Process base classes and generate the __init__ signature.
189188
self._validate_bases(cls_name, bases)
@@ -207,7 +206,7 @@ def make_class_from_pyi(self, cls_name, pytd_cls):
207206
name=name, fields={}, required=set(), total=total)
208207

209208
for c in pytd_cls.constants:
210-
typ = self.ctx.convert.constant_to_var(c.type)
209+
typ = self.ctx.convert.constant_to_value(c.type)
211210
props.add(c.name, typ, total)
212211

213212
# Process base classes and generate the __init__ signature.
@@ -239,8 +238,7 @@ def _make_init(self, props):
239238
sig = function.Signature.from_param_names(
240239
f"{props.name}.__init__", props.fields.keys(),
241240
kind=pytd.ParameterKind.KWONLY)
242-
sig.annotations = {k: abstract_utils.get_atomic_value(v)
243-
for k, v in props.fields.items()}
241+
sig.annotations = dict(props.fields)
244242
sig.defaults = {k: self.ctx.new_unsolvable(self.ctx.root_node)
245243
for k in props.optional}
246244
return abstract.SimpleFunction(sig, self.ctx)
@@ -256,8 +254,7 @@ def _new_instance(self, container, node, args):
256254
def instantiate_value(self, node, container):
257255
args = function.Args(())
258256
for name, typ in self.props.fields.items():
259-
args.namedargs[name] = self.ctx.join_variables(
260-
node, [t.instantiate(node) for t in typ.data])
257+
args.namedargs[name] = typ.instantiate(node)
261258
return self._new_instance(container, node, args)
262259

263260
def instantiate(self, node, container=None):
@@ -301,7 +298,7 @@ def _check_str_key(self, name):
301298

302299
def _check_str_key_value(self, node, name, value_var):
303300
self._check_str_key(name)
304-
typ = abstract_utils.get_atomic_value(self.fields[name])
301+
typ = self.fields[name]
305302
bad = self.ctx.matcher(node).compute_one_match(value_var, typ).bad_matches
306303
for match in bad:
307304
self.ctx.errorlog.annotation_type_mismatch(

pytype/overriding_checks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,8 @@ def is_subtype(this_type, that_type):
414414
"""Return True iff this_type is a subclass of that_type."""
415415
if this_type == ctx.convert.never:
416416
return True # Never is the bottom type, so it matches everything
417-
this_type_instance = this_type.instantiate(ctx.root_node, None)
417+
this_type_instance = this_type.instantiate(
418+
ctx.root_node, container=abstract_utils.DUMMY_CONTAINER)
418419
return matcher.compute_one_match(this_type_instance, that_type).success
419420

420421
check_result = (

pytype/tests/test_dataclasses.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,26 @@ class A:
761761
def __init__(self, a1: int, a3: int, *, a2: int = ...) -> None: ...
762762
""")
763763

764+
@test_utils.skipBeforePy((3, 10), "KW_ONLY is new in 3.10")
765+
def test_kwonly_and_nonfield_default(self):
766+
ty = self.Infer("""
767+
import dataclasses
768+
@dataclasses.dataclass
769+
class C:
770+
_: dataclasses.KW_ONLY
771+
x: int = 0
772+
y: str
773+
""")
774+
self.assertTypesMatchPytd(ty, """
775+
import dataclasses
776+
@dataclasses.dataclass
777+
class C:
778+
x: int = ...
779+
y: str
780+
_: dataclasses.KW_ONLY
781+
def __init__(self, *, x: int = ..., y: str) -> None: ...
782+
""")
783+
764784
def test_star_import(self):
765785
with self.DepTree([("foo.pyi", """
766786
import dataclasses

pytype/tests/test_flax_overlay.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,34 @@ def __init__(
266266
def replace(self: _TBaz, **kwargs) -> _TBaz: ...
267267
""")
268268

269+
@test_utils.skipBeforePy((3, 10), "KW_ONLY is new in 3.10")
270+
def test_kwonly(self):
271+
with test_utils.Tempdir() as d:
272+
self._setup_linen_pyi(d)
273+
ty = self.Infer("""
274+
import dataclasses
275+
from flax import linen as nn
276+
class C(nn.Module):
277+
_: dataclasses.KW_ONLY
278+
x: int = 0
279+
y: str
280+
""", pythonpath=[d.path])
281+
self.assertTypesMatchPytd(ty, """
282+
import dataclasses
283+
from flax import linen as nn
284+
from typing import Any, TypeVar
285+
286+
_TC = TypeVar('_TC', bound=C)
287+
288+
@dataclasses.dataclass
289+
class C(nn.module.Module):
290+
x: int = ...
291+
y: str
292+
_: dataclasses.KW_ONLY
293+
def __init__(self, *, x: int = ..., y: str, name: str = ..., parent: Any = ...) -> None: ...
294+
def replace(self: _TC, **kwargs) -> _TC: ...
295+
""")
296+
269297

270298
if __name__ == "__main__":
271299
test_base.main()

pytype/tests/test_typed_dict.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,14 @@ def f() -> TD:
385385
return __any_object__
386386
""")
387387

388+
def test_duplicate_key(self):
389+
self.CheckWithErrors("""
390+
from typing_extensions import TypedDict
391+
class TD(TypedDict): # invalid-annotation
392+
x: int
393+
x: str
394+
""")
395+
388396

389397
class TypedDictFunctionalTest(test_base.BaseTest):
390398
"""Tests for typing.TypedDict functional constructor."""
@@ -458,6 +466,16 @@ class X(TypedDict, total=False):
458466
name: str
459467
""")
460468

469+
def test_ambiguous_field_type(self):
470+
self.CheckWithErrors("""
471+
from typing_extensions import TypedDict
472+
if __random__:
473+
v = str
474+
else:
475+
v = int
476+
X = TypedDict('X', {'k': v}) # invalid-annotation
477+
""")
478+
461479

462480
_SINGLE = """
463481
from typing import TypedDict

pytype/tests/test_typing_self.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,17 @@ def f(self) -> Self:
281281
return self
282282
""")
283283

284+
def test_signature_compatibility(self):
285+
self.Check("""
286+
from typing_extensions import Self
287+
class Parent:
288+
def add(self, other: Self) -> Self:
289+
return self
290+
class Child(Parent):
291+
def add(self, other: Self) -> Self:
292+
return self
293+
""")
294+
284295

285296
class SelfPyiTest(test_base.BaseTest):
286297
"""Tests for typing.Self usage in type stubs."""

0 commit comments

Comments
 (0)