Skip to content

Commit 6e76b4c

Browse files
authored
[ty] Improve support for Callable type context (#23888)
Improves literal promotion and generic call inference that involve `Callable` type context. Resolves astral-sh/ty#3016.
1 parent 30e8e1c commit 6e76b4c

7 files changed

Lines changed: 164 additions & 112 deletions

File tree

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ def _():
580580
reveal_type(x4) # revealed: X
581581
```
582582

583-
## Prefer the declared type of generic classes
583+
## Prefer the declared type of generic classes and callables
584584

585585
```toml
586586
[environment]
@@ -682,6 +682,38 @@ x1: X[int | None] = X()
682682
reveal_type(x1) # revealed: X[None]
683683
```
684684

685+
We also prefer the declared type of `Callable` parameters, which are in contravariant position:
686+
687+
```py
688+
from typing import Callable
689+
690+
type AnyToBool = Callable[[Any], bool]
691+
692+
def wrap[**P, T](f: Callable[P, T]) -> Callable[P, T]:
693+
return f
694+
695+
def make_callable[T](x: T) -> Callable[[T], bool]:
696+
raise NotImplementedError
697+
698+
def maybe_make_callable[T](x: T) -> Callable[[T], bool] | None:
699+
raise NotImplementedError
700+
701+
x1: Callable[[Any], bool] = make_callable(0)
702+
reveal_type(x1) # revealed: (Any, /) -> bool
703+
704+
x2: AnyToBool = make_callable(0)
705+
reveal_type(x2) # revealed: (Any, /) -> bool
706+
707+
x3: Callable[[list[Any]], bool] = make_callable([0])
708+
reveal_type(x3) # revealed: (list[Any], /) -> bool
709+
710+
x4: Callable[[Any], bool] = wrap(make_callable(0))
711+
reveal_type(x4) # revealed: (Any, /) -> bool
712+
713+
x5: Callable[[Any], bool] | None = maybe_make_callable(0)
714+
reveal_type(x5) # revealed: ((Any, /) -> bool) | None
715+
```
716+
685717
## Declared type preference sees through subtyping
686718

687719
```toml
@@ -775,33 +807,48 @@ python-version = "3.12"
775807
```
776808

777809
```py
778-
from typing import reveal_type, TypedDict
810+
from typing import reveal_type, Any, Callable, TypedDict
779811

780812
def identity[T](x: T) -> T:
781813
return x
782814

783-
def _(narrow: dict[str, str], target: list[str] | dict[str, str] | None):
815+
type Target = Any | list[str] | dict[str, str] | Callable[[str], None] | None
816+
817+
def _(narrow: dict[str, str], target: Target):
784818
target = identity(narrow)
785819
reveal_type(target) # revealed: dict[str, str]
786820

787-
def _(narrow: list[str], target: list[str] | dict[str, str] | None):
821+
def _(narrow: list[str], target: Target):
788822
target = identity(narrow)
789823
reveal_type(target) # revealed: list[str]
790824

791-
def _(narrow: list[str] | dict[str, str], target: list[str] | dict[str, str] | None):
825+
def _(narrow: Callable[[str], None], target: Target):
826+
target = identity(narrow)
827+
reveal_type(target) # revealed: (str, /) -> None
828+
829+
def _(narrow: list[str] | dict[str, str], target: Target):
792830
target = identity(narrow)
793831
reveal_type(target) # revealed: list[str] | dict[str, str]
794832

795833
class TD(TypedDict):
796834
x: int
797835

798-
def _(target: list[TD] | dict[str, TD] | None):
836+
type TargetWithTD = Any | list[TD] | dict[str, TD] | Callable[[TD], None] | None
837+
838+
def _(target: TargetWithTD):
799839
target = identity([{"x": 1}])
800840
reveal_type(target) # revealed: list[TD]
801841

802-
def _(target: list[TD] | dict[str, TD] | None):
842+
def _(target: TargetWithTD):
803843
target = identity({"x": {"x": 1}})
804844
reveal_type(target) # revealed: dict[str, TD]
845+
846+
def _(target: TargetWithTD):
847+
def make_callable[T](x: T) -> Callable[[T], None]:
848+
raise NotImplementedError
849+
850+
target = identity(make_callable({"x": 1}))
851+
reveal_type(target) # revealed: (TD, /) -> None
805852
```
806853

807854
## Prefer the inferred type of non-generic classes
@@ -886,7 +933,7 @@ def _(a: int, b: str, c: int | str):
886933
reveal_type(x10) # revealed: int | str | None
887934
```
888935

889-
## Assignability diagnostics ignore declared type of generic classes
936+
## Assignability diagnostics ignore declared type
890937

891938
```toml
892939
[environment]
@@ -912,19 +959,27 @@ class A(TypedDict):
912959
x2: list[A | bool] = [{"bar": 1}, 1]
913960
```
914961

915-
However, the declared type of generic classes should be ignored if the specialization is not
916-
solvable:
962+
However, the declared type should be ignored if the specialization is not solvable:
917963

918964
```py
965+
from typing import Any, Callable
966+
919967
def g[T](x: list[T]) -> T:
920968
return x[0]
921969

922970
def _(a: int | None):
923971
# error: [invalid-assignment] "Object of type `list[int | None]` is not assignable to `list[str]`"
924-
y1: list[str] = f(a)
972+
x1: list[str] = f(a)
925973

926974
# error: [invalid-assignment] "Object of type `int | None` is not assignable to `str`"
927-
y2: str = g(f(a))
975+
x2: str = g(f(a))
976+
977+
def make_callable[T](x: T) -> Callable[[T], bool]:
978+
raise NotImplementedError
979+
980+
def _(a: int | None):
981+
# error: [invalid-assignment] "Object of type `(int | None, /) -> bool` is not assignable to `(str, /) -> bool`"
982+
x1: Callable[[str], bool] = make_callable(a)
928983
```
929984

930985
## Forward annotation with unclosed string literal

crates/ty_python_semantic/resources/mdtest/generics/legacy/callables.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def outside_callable(t: T) -> Callable[[T], T]:
181181
# revealed: ty_extensions.GenericContext[T@outside_callable]
182182
reveal_type(generic_context(outside_callable))
183183

184-
# revealed: (Literal[1], /) -> Literal[1]
184+
# revealed: (int, /) -> int
185185
reveal_type(outside_callable(1))
186186
# revealed: None
187187
reveal_type(generic_context(outside_callable(1)))

crates/ty_python_semantic/resources/mdtest/generics/pep695/callables.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def outside_callable[T](t: T) -> Callable[[T], T]:
181181
# revealed: ty_extensions.GenericContext[T@outside_callable]
182182
reveal_type(generic_context(outside_callable))
183183

184-
# revealed: (Literal[1], /) -> Literal[1]
184+
# revealed: (int, /) -> int
185185
reveal_type(outside_callable(1))
186186
# revealed: None
187187
reveal_type(generic_context(outside_callable(1)))

crates/ty_python_semantic/resources/mdtest/promotion.md

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ We promote in non-covariant position in the return type of a generic function, o
9696
generic class:
9797

9898
```py
99+
from typing import Callable, Literal
100+
99101
class Bivariant[T]:
100102
def __init__(self, value: T): ...
101103

@@ -124,6 +126,8 @@ def f8[T](x: T) -> Invariant[T] | Covariant[T] | None: ...
124126
def f9[T](x: T) -> tuple[Invariant[T], Invariant[T]] | None: ...
125127
def f10[T, U](x: T, y: U) -> tuple[Invariant[T], Covariant[U]] | None: ...
126128
def f11[T, U](x: T, y: U) -> tuple[Invariant[Covariant[T] | None], Covariant[U]] | None: ...
129+
def f12[T](x: T) -> Callable[[T], bool] | None: ...
130+
def f13[T](x: T) -> Callable[[bool], Invariant[T]] | None: ...
127131

128132
reveal_type(Bivariant(1)) # revealed: Bivariant[Literal[1]]
129133
reveal_type(Covariant(1)) # revealed: Covariant[Literal[1]]
@@ -144,6 +148,9 @@ reveal_type(f9(1)) # revealed: tuple[Invariant[int], Invariant[int]] | None
144148

145149
reveal_type(f10(1, 1)) # revealed: tuple[Invariant[int], Covariant[Literal[1]]] | None
146150
reveal_type(f11(1, 1)) # revealed: tuple[Invariant[Covariant[int] | None], Covariant[Literal[1]]] | None
151+
152+
reveal_type(f12(1)) # revealed: ((int, /) -> bool) | None
153+
reveal_type(f13(1)) # revealed: ((bool, /) -> Invariant[int]) | None
147154
```
148155

149156
## Promotion is recursive
@@ -190,6 +197,7 @@ declared in a promotable position:
190197
```py
191198
from enum import Enum
192199
from typing import Sequence, Literal, LiteralString
200+
from typing import Callable
193201

194202
class Color(Enum):
195203
RED = "red"
@@ -274,6 +282,18 @@ reveal_type(x21) # revealed: X[Literal[1]]
274282

275283
x22: X[Literal[1]] | None = x([1])
276284
reveal_type(x22) # revealed: X[Literal[1]]
285+
286+
def make_callable[T](x: T) -> Callable[[T], bool]:
287+
raise NotImplementedError
288+
289+
def maybe_make_callable[T](x: T) -> Callable[[T], bool] | None:
290+
raise NotImplementedError
291+
292+
x23: Callable[[Literal[1]], bool] = make_callable(1)
293+
reveal_type(x23) # revealed: (Literal[1], /) -> bool
294+
295+
x24: Callable[[Literal[1]], bool] | None = maybe_make_callable(1)
296+
reveal_type(x24) # revealed: ((Literal[1], /) -> bool) | None
277297
```
278298

279299
## Literal annotations see through subtyping
@@ -403,7 +423,7 @@ later used in a promotable position:
403423

404424
```py
405425
from enum import Enum
406-
from typing import Literal
426+
from typing import Callable, Literal
407427

408428
def promote[T](x: T) -> list[T]:
409429
return [x]
@@ -449,6 +469,16 @@ class MyEnum(Enum):
449469
def _(x: Literal[MyEnum.A, MyEnum.B]):
450470
reveal_type(x) # revealed: Literal[MyEnum.A, MyEnum.B]
451471
reveal_type([x]) # revealed: list[Literal[MyEnum.A, MyEnum.B]]
472+
473+
def make_callable[T](x: T) -> Callable[[T], bool]:
474+
raise NotImplementedError
475+
476+
def maybe_make_callable[T](x: T) -> Callable[[T], bool] | None:
477+
raise NotImplementedError
478+
479+
def _(x: Literal[1]):
480+
reveal_type(make_callable(x)) # revealed: (Literal[1], /) -> bool
481+
reveal_type(maybe_make_callable(x)) # revealed: ((Literal[1], /) -> bool) | None
452482
```
453483

454484
Literal promotability is respected by unions:

0 commit comments

Comments
 (0)