Skip to content

Commit 65962f6

Browse files
authored
[ty] implement cycle normalization for more types to prevent too-many-cycle panics (#24061)
## Summary Fixes astral-sh/ty#3080 ## Test Plan new corpus test
1 parent 677f6c9 commit 65962f6

4 files changed

Lines changed: 220 additions & 2 deletions

File tree

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Regression test for https://github.com/astral-sh/ty/issues/3080
2+
3+
# To reproduce the bug, deferred evaluation of type annotations must be applied.
4+
from __future__ import annotations
5+
6+
from typing import Generic, Protocol, Self, TypeVar, overload
7+
8+
S = TypeVar("S")
9+
T = TypeVar("T")
10+
11+
12+
class Unit(Protocol):
13+
def __mul__(self, other: S | Quantity[S]): ...
14+
15+
16+
class Vector(Protocol): ...
17+
18+
19+
class Quantity(Generic[T], Protocol):
20+
@overload
21+
def __mul__(self, other: Unit | Quantity[S]): ...
22+
23+
@overload
24+
def __mul__(self, other: Vector) -> Vector: ...

crates/ty_python_semantic/src/types/function.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1136,7 +1136,12 @@ impl<'db> FunctionType<'db> {
11361136
///
11371137
/// Were this not a salsa query, then the calling query
11381138
/// would depend on the function's AST and rerun for every change in that file.
1139-
#[salsa::tracked(returns(ref), cycle_initial=|_, _, _| CallableSignature::single(Signature::bottom()), heap_size=ruff_memory_usage::heap_size)]
1139+
#[salsa::tracked(
1140+
returns(ref),
1141+
cycle_initial=|_, _, _| CallableSignature::single(Signature::bottom()),
1142+
cycle_fn=|db, cycle, previous, value: CallableSignature<'db>, _| value.cycle_normalized(db, previous, cycle),
1143+
heap_size=ruff_memory_usage::heap_size,
1144+
)]
11401145
pub(crate) fn signature(self, db: &'db dyn Db) -> CallableSignature<'db> {
11411146
self.updated_signature(db)
11421147
.cloned()

crates/ty_python_semantic/src/types/protocol_class.rs

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,24 @@ impl<'db> ProtocolInterface<'db> {
256256
Self::new(db, BTreeMap::default())
257257
}
258258

259+
fn cycle_normalized(self, db: &'db dyn Db, previous: Self, cycle: &salsa::Cycle) -> Self {
260+
let prev_inner = previous.inner(db);
261+
let curr_inner = self.inner(db);
262+
263+
let members: BTreeMap<_, _> = curr_inner
264+
.iter()
265+
.map(|(name, curr_data)| {
266+
let normalized = if let Some(prev_data) = prev_inner.get(name) {
267+
curr_data.cycle_normalized(db, prev_data, cycle)
268+
} else {
269+
curr_data.clone()
270+
};
271+
(name.clone(), normalized)
272+
})
273+
.collect();
274+
Self::new(db, members)
275+
}
276+
259277
pub(super) fn members<'a>(
260278
self,
261279
db: &'db dyn Db,
@@ -404,6 +422,14 @@ pub(super) struct ProtocolMemberData<'db> {
404422
}
405423

406424
impl<'db> ProtocolMemberData<'db> {
425+
fn cycle_normalized(&self, db: &'db dyn Db, previous: &Self, cycle: &salsa::Cycle) -> Self {
426+
Self {
427+
kind: self.kind.cycle_normalized(db, &previous.kind, cycle),
428+
qualifiers: self.qualifiers,
429+
definition: self.definition,
430+
}
431+
}
432+
407433
fn recursive_type_normalized_impl(
408434
&self,
409435
db: &'db dyn Db,
@@ -509,6 +535,38 @@ enum ProtocolMemberKind<'db> {
509535
}
510536

511537
impl<'db> ProtocolMemberKind<'db> {
538+
fn cycle_normalized(&self, db: &'db dyn Db, previous: &Self, cycle: &salsa::Cycle) -> Self {
539+
match (self, previous) {
540+
(Self::Method(curr), Self::Method(prev)) => {
541+
debug_assert_eq!(curr.kind(db), prev.kind(db));
542+
let normalized =
543+
curr.signatures(db)
544+
.cycle_normalized(db, prev.signatures(db), cycle);
545+
Self::Method(CallableType::new(db, normalized, curr.kind(db)))
546+
}
547+
(Self::Property(curr), Self::Property(prev)) => {
548+
let getter = match (curr.getter(db), prev.getter(db)) {
549+
(Some(curr), Some(prev)) => Some(curr.cycle_normalized(db, prev, cycle)),
550+
(Some(curr), None) => Some(curr.recursive_type_normalized(db, cycle)),
551+
(None, _) => None,
552+
};
553+
let setter = match (curr.setter(db), prev.setter(db)) {
554+
(Some(curr), Some(prev)) => Some(curr.cycle_normalized(db, prev, cycle)),
555+
(Some(curr), None) => Some(curr.recursive_type_normalized(db, cycle)),
556+
(None, _) => None,
557+
};
558+
Self::Property(PropertyInstanceType::new(db, getter, setter))
559+
}
560+
(Self::Other(curr), Self::Other(prev)) => {
561+
Self::Other(curr.cycle_normalized(db, *prev, cycle))
562+
}
563+
_ => {
564+
debug_assert!(matches!(previous, Self::Other(ty) if ty.is_divergent()));
565+
*self
566+
}
567+
}
568+
}
569+
512570
fn apply_type_mapping_impl<'a>(
513571
&self,
514572
db: &'db dyn Db,
@@ -850,7 +908,11 @@ impl BoundOnClass {
850908
}
851909

852910
/// Inner Salsa query for [`ProtocolClass::interface`].
853-
#[salsa::tracked(cycle_initial=proto_interface_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
911+
#[salsa::tracked(
912+
cycle_initial=proto_interface_cycle_initial,
913+
cycle_fn=proto_interface_cycle_recover,
914+
heap_size=ruff_memory_usage::heap_size,
915+
)]
854916
fn cached_protocol_interface<'db>(
855917
db: &'db dyn Db,
856918
class: ClassType<'db>,
@@ -971,6 +1033,17 @@ fn proto_interface_cycle_initial<'db>(
9711033
ProtocolInterface::empty(db)
9721034
}
9731035

1036+
#[allow(clippy::trivially_copy_pass_by_ref)]
1037+
fn proto_interface_cycle_recover<'db>(
1038+
db: &'db dyn Db,
1039+
cycle: &salsa::Cycle,
1040+
previous: &ProtocolInterface<'db>,
1041+
value: ProtocolInterface<'db>,
1042+
_class: ClassType<'db>,
1043+
) -> ProtocolInterface<'db> {
1044+
value.cycle_normalized(db, *previous, cycle)
1045+
}
1046+
9741047
/// Bind `self`, and *also* discard the functionlike-ness of the callable.
9751048
///
9761049
/// This additional upcasting is required in order for protocols with `__call__` method

crates/ty_python_semantic/src/types/signatures.rs

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,27 @@ impl<'db> CallableSignature<'db> {
109109
}))
110110
}
111111

112+
pub(crate) fn cycle_normalized(
113+
&self,
114+
db: &'db dyn Db,
115+
previous: &Self,
116+
cycle: &salsa::Cycle,
117+
) -> Self {
118+
if previous.overloads.len() == self.overloads.len() {
119+
Self {
120+
overloads: self
121+
.overloads
122+
.iter()
123+
.zip(previous.overloads.iter())
124+
.map(|(curr, prev)| curr.cycle_normalized(db, prev, cycle))
125+
.collect(),
126+
}
127+
} else {
128+
debug_assert_eq!(previous, &Self::bottom());
129+
self.clone()
130+
}
131+
}
132+
112133
pub(super) fn recursive_type_normalized_impl(
113134
&self,
114135
db: &'db dyn Db,
@@ -525,6 +546,32 @@ impl<'db> Signature<'db> {
525546
self
526547
}
527548

549+
fn cycle_normalized(&self, db: &'db dyn Db, previous: &Self, cycle: &salsa::Cycle) -> Self {
550+
let return_ty = self
551+
.return_ty
552+
.cycle_normalized(db, previous.return_ty, cycle);
553+
554+
let parameters = if self.parameters.len() == previous.parameters.len() {
555+
Parameters::new(
556+
db,
557+
self.parameters
558+
.iter()
559+
.zip(previous.parameters.iter())
560+
.map(|(curr, prev)| curr.cycle_normalized(db, prev, cycle)),
561+
)
562+
} else {
563+
debug_assert_eq!(previous.parameters, Parameters::bottom());
564+
self.parameters.clone()
565+
};
566+
567+
Self {
568+
generic_context: self.generic_context,
569+
definition: self.definition,
570+
parameters,
571+
return_ty,
572+
}
573+
}
574+
528575
pub(super) fn recursive_type_normalized_impl(
529576
&self,
530577
db: &'db dyn Db,
@@ -2980,6 +3027,22 @@ impl<'db> Parameter<'db> {
29803027
}
29813028
}
29823029

3030+
fn cycle_normalized(&self, db: &'db dyn Db, previous: &Self, cycle: &salsa::Cycle) -> Self {
3031+
let annotated_type =
3032+
self.annotated_type
3033+
.cycle_normalized(db, previous.annotated_type, cycle);
3034+
3035+
let kind = self.kind.cycle_normalized(db, &previous.kind, cycle);
3036+
3037+
Self {
3038+
annotated_type,
3039+
inferred_annotation: self.inferred_annotation,
3040+
has_starred_annotation: self.has_starred_annotation,
3041+
kind,
3042+
form: self.form,
3043+
}
3044+
}
3045+
29833046
pub(super) fn recursive_type_normalized_impl(
29843047
&self,
29853048
db: &'db dyn Db,
@@ -3222,6 +3285,59 @@ pub enum ParameterKind<'db> {
32223285
}
32233286

32243287
impl<'db> ParameterKind<'db> {
3288+
#[expect(clippy::ref_option)]
3289+
fn cycle_normalized_default(
3290+
db: &'db dyn Db,
3291+
current: &Option<Type<'db>>,
3292+
previous: &Option<Type<'db>>,
3293+
cycle: &salsa::Cycle,
3294+
) -> Option<Type<'db>> {
3295+
match (current, previous) {
3296+
(Some(curr), Some(prev)) => Some(curr.cycle_normalized(db, *prev, cycle)),
3297+
(Some(curr), None) => Some(curr.recursive_type_normalized(db, cycle)),
3298+
(None, _) => *current,
3299+
}
3300+
}
3301+
3302+
fn cycle_normalized(&self, db: &'db dyn Db, previous: &Self, cycle: &salsa::Cycle) -> Self {
3303+
match (self, previous) {
3304+
(
3305+
ParameterKind::PositionalOnly { name, default_type },
3306+
ParameterKind::PositionalOnly {
3307+
default_type: prev_default,
3308+
..
3309+
},
3310+
) => ParameterKind::PositionalOnly {
3311+
name: name.clone(),
3312+
default_type: Self::cycle_normalized_default(db, default_type, prev_default, cycle),
3313+
},
3314+
(
3315+
ParameterKind::PositionalOrKeyword { name, default_type },
3316+
ParameterKind::PositionalOrKeyword {
3317+
default_type: prev_default,
3318+
..
3319+
},
3320+
) => ParameterKind::PositionalOrKeyword {
3321+
name: name.clone(),
3322+
default_type: Self::cycle_normalized_default(db, default_type, prev_default, cycle),
3323+
},
3324+
(
3325+
ParameterKind::KeywordOnly { name, default_type },
3326+
ParameterKind::KeywordOnly {
3327+
default_type: prev_default,
3328+
..
3329+
},
3330+
) => ParameterKind::KeywordOnly {
3331+
name: name.clone(),
3332+
default_type: Self::cycle_normalized_default(db, default_type, prev_default, cycle),
3333+
},
3334+
// Variadic / KeywordVariadic have no types to normalize.
3335+
// Also, if the current `ParameterKind` is different from `previous`, it means that `previous` is the cycle initial value,
3336+
// and the current value should take precedence.
3337+
_ => self.clone(),
3338+
}
3339+
}
3340+
32253341
fn apply_type_mapping_impl<'a>(
32263342
&self,
32273343
db: &'db dyn Db,

0 commit comments

Comments
 (0)