Skip to content

Commit 0fc9ffe

Browse files
committed
vm: cache __init__ lookup for class-call specialization
Cache __init__ PyFunction in HeapTypeExt to avoid repeated MRO lookup in CALL_ALLOC_AND_ENTER_INIT. Invalidated on type.modified(). Add eval_frame_active guard and improve __init__ return error message.
1 parent e16c3c2 commit 0fc9ffe

File tree

2 files changed

+73
-41
lines changed

2 files changed

+73
-41
lines changed

crates/vm/src/builtins/type.rs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::{
1111
MemberGetter, MemberKind, MemberSetter, PyDescriptorOwned, PyMemberDef,
1212
PyMemberDescriptor,
1313
},
14-
function::PyCellRef,
14+
function::{PyCellRef, PyFunction},
1515
tuple::{IntoPyTuple, PyTuple},
1616
},
1717
class::{PyClassImpl, StaticType},
@@ -269,6 +269,7 @@ pub struct HeapTypeExt {
269269
pub qualname: PyRwLock<PyStrRef>,
270270
pub slots: Option<PyRef<PyTuple<PyStrRef>>>,
271271
pub type_data: PyRwLock<Option<TypeDataSlot>>,
272+
pub specialization_init: PyRwLock<Option<PyRef<PyFunction>>>,
272273
}
273274

274275
pub struct PointerSlot<T>(NonNull<T>);
@@ -396,6 +397,9 @@ impl PyType {
396397

397398
/// Invalidate this type's version tag and cascade to all subclasses.
398399
pub fn modified(&self) {
400+
if let Some(ext) = self.heaptype_ext.as_ref() {
401+
*ext.specialization_init.write() = None;
402+
}
399403
// If already invalidated, all subclasses must also be invalidated
400404
// (guaranteed by the MRO invariant in assign_version_tag).
401405
let old_version = self.tp_version_tag.load(Ordering::Acquire);
@@ -450,6 +454,7 @@ impl PyType {
450454
qualname: PyRwLock::new(name),
451455
slots: None,
452456
type_data: PyRwLock::new(None),
457+
specialization_init: PyRwLock::new(None),
453458
};
454459
let base = bases[0].clone();
455460

@@ -769,6 +774,38 @@ impl PyType {
769774
self.find_name_in_mro(attr_name)
770775
}
771776

777+
/// Cache __init__ for CALL_ALLOC_AND_ENTER_INIT specialization.
778+
/// The cache is valid only when guarded by the type version check.
779+
pub(crate) fn cache_init_for_specialization(
780+
&self,
781+
init: PyRef<PyFunction>,
782+
tp_version: u32,
783+
) -> bool {
784+
let Some(ext) = self.heaptype_ext.as_ref() else {
785+
return false;
786+
};
787+
if tp_version == 0 || self.tp_version_tag.load(Ordering::Acquire) != tp_version {
788+
return false;
789+
}
790+
*ext.specialization_init.write() = Some(init);
791+
true
792+
}
793+
794+
/// Read cached __init__ for CALL_ALLOC_AND_ENTER_INIT specialization.
795+
pub(crate) fn get_cached_init_for_specialization(
796+
&self,
797+
tp_version: u32,
798+
) -> Option<PyRef<PyFunction>> {
799+
let ext = self.heaptype_ext.as_ref()?;
800+
if tp_version == 0 || self.tp_version_tag.load(Ordering::Acquire) != tp_version {
801+
return None;
802+
}
803+
ext.specialization_init
804+
.read()
805+
.as_ref()
806+
.map(|init| init.to_owned())
807+
}
808+
772809
pub fn get_direct_attr(&self, attr_name: &'static PyStrInterned) -> Option<PyObjectRef> {
773810
self.attributes.read().get(attr_name).cloned()
774811
}
@@ -1879,6 +1916,7 @@ impl Constructor for PyType {
18791916
qualname: PyRwLock::new(qualname),
18801917
slots: heaptype_slots.clone(),
18811918
type_data: PyRwLock::new(None),
1919+
specialization_init: PyRwLock::new(None),
18821920
};
18831921
(slots, heaptype_ext)
18841922
};

crates/vm/src/frame.rs

Lines changed: 34 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4603,56 +4603,48 @@ impl ExecutingFrame<'_> {
46034603
.localsplus
46044604
.stack_index(stack_len - nargs as usize - 1)
46054605
.is_some();
4606-
if !self_or_null_is_some
4606+
if !self.specialization_eval_frame_active(vm)
4607+
&& !self_or_null_is_some
46074608
&& cached_version != 0
46084609
&& let Some(cls) = callable.downcast_ref::<PyType>()
46094610
&& cls.tp_version_tag.load(Acquire) == cached_version
4611+
&& let Some(init_func) = cls.get_cached_init_for_specialization(cached_version)
46104612
{
4611-
// Look up __init__ (guarded by type_version)
4612-
if let Some(init) = cls.get_attr(identifier!(vm, __init__))
4613-
&& let Some(init_func) = init.downcast_ref_if_exact::<PyFunction>(vm)
4614-
&& init_func.can_specialize_call(nargs + 1)
4613+
// Allocate object directly (tp_new == object.__new__)
4614+
let dict = if cls
4615+
.slots
4616+
.flags
4617+
.has_feature(crate::types::PyTypeFlags::HAS_DICT)
46154618
{
4616-
// Allocate object directly (tp_new == object.__new__)
4617-
let dict = if cls
4618-
.slots
4619-
.flags
4620-
.has_feature(crate::types::PyTypeFlags::HAS_DICT)
4621-
{
4622-
Some(vm.ctx.new_dict())
4623-
} else {
4624-
None
4625-
};
4626-
let cls_ref = cls.to_owned();
4627-
let new_obj: PyObjectRef =
4628-
PyRef::new_ref(PyBaseObject, cls_ref, dict).into();
4629-
4630-
// Build args: [new_obj, arg1, ..., argN]
4631-
let pos_args: Vec<PyObjectRef> =
4632-
self.pop_multiple(nargs as usize).collect();
4633-
let _null = self.pop_value_opt(); // self_or_null (None)
4634-
let _callable = self.pop_value(); // callable (type)
4619+
Some(vm.ctx.new_dict())
4620+
} else {
4621+
None
4622+
};
4623+
let cls_ref = cls.to_owned();
4624+
let new_obj: PyObjectRef = PyRef::new_ref(PyBaseObject, cls_ref, dict).into();
46354625

4636-
let mut all_args = Vec::with_capacity(pos_args.len() + 1);
4637-
all_args.push(new_obj.clone());
4638-
all_args.extend(pos_args);
4626+
// Build args: [new_obj, arg1, ..., argN]
4627+
let pos_args: Vec<PyObjectRef> = self.pop_multiple(nargs as usize).collect();
4628+
let _null = self.pop_value_opt(); // self_or_null (None)
4629+
let _callable = self.pop_value(); // callable (type)
46394630

4640-
let init_result = init_func.invoke_exact_args(all_args, vm)?;
4631+
let mut all_args = Vec::with_capacity(pos_args.len() + 1);
4632+
all_args.push(new_obj.clone());
4633+
all_args.extend(pos_args);
46414634

4642-
// EXIT_INIT_CHECK: __init__ must return None
4643-
if !vm.is_none(&init_result) {
4644-
return Err(
4645-
vm.new_type_error("__init__() should return None".to_owned())
4646-
);
4647-
}
4635+
let init_result = init_func.invoke_exact_args(all_args, vm)?;
46484636

4649-
self.push_value(new_obj);
4650-
return Ok(None);
4637+
// EXIT_INIT_CHECK: __init__ must return None
4638+
if !vm.is_none(&init_result) {
4639+
return Err(vm.new_type_error(format!(
4640+
"__init__() should return None, not '{}'",
4641+
init_result.class().name()
4642+
)));
46514643
}
4644+
4645+
self.push_value(new_obj);
4646+
return Ok(None);
46524647
}
4653-
self.deoptimize(Instruction::Call {
4654-
argc: Arg::marker(),
4655-
});
46564648
self.execute_call_vectorcall(nargs, vm)
46574649
}
46584650
Instruction::CallMethodDescriptorFastWithKeywords => {
@@ -7997,7 +7989,9 @@ impl ExecutingFrame<'_> {
79977989
&& init_func.can_specialize_call(nargs + 1)
79987990
{
79997991
let version = cls.tp_version_tag.load(Acquire);
8000-
if version != 0 {
7992+
if version != 0
7993+
&& cls.cache_init_for_specialization(init_func.to_owned(), version)
7994+
{
80017995
unsafe {
80027996
self.code
80037997
.instructions

0 commit comments

Comments
 (0)