@@ -40,11 +40,11 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
4040 node_index : _,
4141 } = binary;
4242
43- let ( left_ty, right_ty) = match self . infer_binary_expression_operand_types ( left , * op , right )
44- {
45- BinaryExpressionOperandTypes :: TypedDictResult ( ty) => return ty,
46- BinaryExpressionOperandTypes :: Inferred ( left_ty, right_ty) => ( left_ty, right_ty) ,
47- } ;
43+ let ( left_ty, right_ty) =
44+ match self . infer_binary_expression_operand_types ( left , * op , right , tcx ) {
45+ BinaryExpressionOperandTypes :: TypedDictResult ( ty) => return ty,
46+ BinaryExpressionOperandTypes :: Inferred ( left_ty, right_ty) => ( left_ty, right_ty) ,
47+ } ;
4848
4949 self . infer_binary_expression_type ( binary. into ( ) , false , left_ty, right_ty, * op)
5050 . unwrap_or_else ( || {
@@ -108,12 +108,37 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
108108 left : & ast:: Expr ,
109109 op : ast:: Operator ,
110110 right : & ast:: Expr ,
111+ tcx : TypeContext < ' db > ,
111112 ) -> BinaryExpressionOperandTypes < ' db > {
113+ // As a special case, pass `tcx` to binary operands that are collection literals/displays.
114+ // Note that it's not correct to pass it to all binary operands, for example:
115+ // ```
116+ // x: list[str] = ["x"] * 3
117+ // ```
118+ // It doesn't make sense to pass the list type context to the `3` expression. It wouldn't
119+ // have any effect in this case, but it could in more complicated cases.
120+ // TODO: When we support passing `tcx` through generic method calls, we can remove this
121+ // special case and handle the relevant dunder method instead.
122+ let operand_tcx = |expr : & ast:: Expr | -> TypeContext < ' db > {
123+ match expr {
124+ ast:: Expr :: List ( _)
125+ | ast:: Expr :: Tuple ( _)
126+ | ast:: Expr :: Set ( _)
127+ | ast:: Expr :: Dict ( _)
128+ | ast:: Expr :: ListComp ( _)
129+ | ast:: Expr :: SetComp ( _)
130+ | ast:: Expr :: DictComp ( _) => tcx,
131+ // Also pass `tcx` to nested binary expressions.
132+ ast:: Expr :: BinOp ( _) => tcx,
133+ _ => TypeContext :: default ( ) ,
134+ }
135+ } ;
136+
112137 // When a dict literal is `|`'d with a TypedDict, infer the non-literal side first
113138 // so we can use bidirectional inference on the literal before calling the synthesized
114139 // `__or__`/`__ror__` method on the TypedDict side.
115140 if op == ast:: Operator :: BitOr && matches ! ( left, ast:: Expr :: Dict ( _) ) {
116- let right_ty = self . infer_expression ( right, TypeContext :: default ( ) ) ;
141+ let right_ty = self . infer_expression ( right, operand_tcx ( right ) ) ;
117142 if let Type :: TypedDict ( typed_dict) = right_ty
118143 && let Some ( ty) = self . try_typed_dict_pep_584_dunder (
119144 left,
@@ -128,12 +153,12 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
128153 // If the TypedDict update path rejects the literal, fall back to ordinary inference
129154 // even though that means re-inferring the literal without TypedDict context.
130155 return BinaryExpressionOperandTypes :: Inferred (
131- self . infer_expression ( left, TypeContext :: default ( ) ) ,
156+ self . infer_expression ( left, operand_tcx ( left ) ) ,
132157 right_ty,
133158 ) ;
134159 }
135160
136- let left_ty = self . infer_expression ( left, TypeContext :: default ( ) ) ;
161+ let left_ty = self . infer_expression ( left, operand_tcx ( left ) ) ;
137162 if op == ast:: Operator :: BitOr
138163 && let Type :: TypedDict ( typed_dict) = left_ty
139164 && matches ! ( right, ast:: Expr :: Dict ( _) )
@@ -149,7 +174,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
149174
150175 BinaryExpressionOperandTypes :: Inferred (
151176 left_ty,
152- self . infer_expression ( right, TypeContext :: default ( ) ) ,
177+ self . infer_expression ( right, operand_tcx ( right ) ) ,
153178 )
154179 }
155180
0 commit comments