Skip to content

Commit d80c46e

Browse files
authored
[ty] pass type context to sequence literals in binary operations (#24197)
Fixes astral-sh/ty#3002. This is a quick fix for this special case. A more general solution will be passing type context through generic method calls, with binary operations like these handled via their dunder methods.
1 parent 533da8f commit d80c46e

2 files changed

Lines changed: 85 additions & 13 deletions

File tree

crates/ty_python_semantic/resources/mdtest/bidirectional.md

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,56 @@ def f[T](x: T, cond: bool) -> T | list[T]:
4343

4444
l5: int | list[int] = f(1, True)
4545

46-
a: list[int] = [1, 2, *(3, 4, 5)]
47-
reveal_type(a) # revealed: list[int]
46+
x: list[int] = [1, 2, *(3, 4, 5)]
47+
reveal_type(x) # revealed: list[int]
4848

49-
b: list[list[int]] = [[1], [2], *([3], [4])]
50-
reveal_type(b) # revealed: list[list[int]]
49+
x: list[list[int]] = [[1], [2], *([3], [4])]
50+
reveal_type(x) # revealed: list[list[int]]
51+
52+
x: list[list[int | str]] = [[1], [2]] * 3
53+
reveal_type(x) # revealed: list[list[int | str]]
54+
55+
x: list[list[int | str]] = 3 * ([[1]] + [[2]])
56+
reveal_type(x) # revealed: list[list[int | str]]
57+
58+
x: list[int | str] = 3 * ["x" for _ in range(3)]
59+
reveal_type(x) # revealed: list[int | str]
60+
61+
# Tuple elements are inferred individually, but type context can prevent e.g. `int` widening.
62+
x: tuple[list[Literal[1]]] = (list1(1),)
63+
reveal_type(x) # revealed: tuple[list[Literal[1]]]
64+
65+
x: tuple[list[Literal[1]], ...] = (list1(1),) * 3
66+
reveal_type(x) # revealed: tuple[list[Literal[1]], ...]
67+
68+
x: tuple[list[Literal[1]], ...] = 3 * ((list1(1),) + (list1(1),))
69+
reveal_type(x) # revealed: tuple[list[Literal[1]], ...]
70+
71+
x: set[int | str] = {1, 2} | {3, 4}
72+
reveal_type(x) # revealed: set[int | str]
73+
74+
x: set[int | str] = {42 for _ in range(3)}
75+
reveal_type(x) # revealed: set[int | str]
76+
77+
x: dict[int | str, int | str] = {1: 2} | {3: 4}
78+
reveal_type(x) # revealed: dict[int | str, int | str]
79+
80+
x: dict[int | str, int | str] = {str(i): i for i in range(3)}
81+
reveal_type(x) # revealed: dict[int | str, int | str]
82+
83+
# TODO: We currently eagerly pass type context to collection literals on either side of a binary
84+
# operator. That makes the cases above work, but it's not generally sound. For example, it gives the
85+
# wrong result in this case.
86+
class X:
87+
def __add__(self, _: list[int]) -> list[int | str]:
88+
return []
89+
90+
# error: [unsupported-operator] "Operator `+` is not supported between objects of type `X` and `list[int | str]`"
91+
x: list[int | str] = X() + [1]
92+
93+
# TODO: We also don't yet support generic function calls like this.
94+
# error: [invalid-assignment] "Object of type `list[int]` is not assignable to `list[int | str]`"
95+
x: list[int | str] = list1(42) * 3
5196
```
5297

5398
`typed_dict.py`:
@@ -88,6 +133,8 @@ reveal_type(d4_invalid_dict) # revealed: TD
88133
d5_literal: dict[Hashable, Callable[..., object]] = {"x": lambda: 1}
89134
d5_dict: dict[Hashable, Callable[..., object]] = dict(x=lambda: 1)
90135

136+
d6_dict: TD = {"x": 1} | {"x": 2}
137+
91138
def return_literal() -> TD:
92139
return {"x": 1}
93140

crates/ty_python_semantic/src/types/infer/builder/binary_expressions.rs

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
4040
node_index: _,
4141
} = binary;
4242

43-
let (left_ty, right_ty) = match self.infer_binary_expression_operand_types(left, *op, right)
44-
{
45-
BinaryExpressionOperandTypes::TypedDictResult(ty) => return ty,
46-
BinaryExpressionOperandTypes::Inferred(left_ty, right_ty) => (left_ty, right_ty),
47-
};
43+
let (left_ty, right_ty) =
44+
match self.infer_binary_expression_operand_types(left, *op, right, tcx) {
45+
BinaryExpressionOperandTypes::TypedDictResult(ty) => return ty,
46+
BinaryExpressionOperandTypes::Inferred(left_ty, right_ty) => (left_ty, right_ty),
47+
};
4848

4949
self.infer_binary_expression_type(binary.into(), false, left_ty, right_ty, *op)
5050
.unwrap_or_else(|| {
@@ -108,12 +108,37 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
108108
left: &ast::Expr,
109109
op: ast::Operator,
110110
right: &ast::Expr,
111+
tcx: TypeContext<'db>,
111112
) -> BinaryExpressionOperandTypes<'db> {
113+
// As a special case, pass `tcx` to binary operands that are collection literals/displays.
114+
// Note that it's not correct to pass it to all binary operands, for example:
115+
// ```
116+
// x: list[str] = ["x"] * 3
117+
// ```
118+
// It doesn't make sense to pass the list type context to the `3` expression. It wouldn't
119+
// have any effect in this case, but it could in more complicated cases.
120+
// TODO: When we support passing `tcx` through generic method calls, we can remove this
121+
// special case and handle the relevant dunder method instead.
122+
let operand_tcx = |expr: &ast::Expr| -> TypeContext<'db> {
123+
match expr {
124+
ast::Expr::List(_)
125+
| ast::Expr::Tuple(_)
126+
| ast::Expr::Set(_)
127+
| ast::Expr::Dict(_)
128+
| ast::Expr::ListComp(_)
129+
| ast::Expr::SetComp(_)
130+
| ast::Expr::DictComp(_) => tcx,
131+
// Also pass `tcx` to nested binary expressions.
132+
ast::Expr::BinOp(_) => tcx,
133+
_ => TypeContext::default(),
134+
}
135+
};
136+
112137
// When a dict literal is `|`'d with a TypedDict, infer the non-literal side first
113138
// so we can use bidirectional inference on the literal before calling the synthesized
114139
// `__or__`/`__ror__` method on the TypedDict side.
115140
if op == ast::Operator::BitOr && matches!(left, ast::Expr::Dict(_)) {
116-
let right_ty = self.infer_expression(right, TypeContext::default());
141+
let right_ty = self.infer_expression(right, operand_tcx(right));
117142
if let Type::TypedDict(typed_dict) = right_ty
118143
&& let Some(ty) = self.try_typed_dict_pep_584_dunder(
119144
left,
@@ -128,12 +153,12 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
128153
// If the TypedDict update path rejects the literal, fall back to ordinary inference
129154
// even though that means re-inferring the literal without TypedDict context.
130155
return BinaryExpressionOperandTypes::Inferred(
131-
self.infer_expression(left, TypeContext::default()),
156+
self.infer_expression(left, operand_tcx(left)),
132157
right_ty,
133158
);
134159
}
135160

136-
let left_ty = self.infer_expression(left, TypeContext::default());
161+
let left_ty = self.infer_expression(left, operand_tcx(left));
137162
if op == ast::Operator::BitOr
138163
&& let Type::TypedDict(typed_dict) = left_ty
139164
&& matches!(right, ast::Expr::Dict(_))
@@ -149,7 +174,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
149174

150175
BinaryExpressionOperandTypes::Inferred(
151176
left_ty,
152-
self.infer_expression(right, TypeContext::default()),
177+
self.infer_expression(right, operand_tcx(right)),
153178
)
154179
}
155180

0 commit comments

Comments
 (0)