Skip to content

Commit 46abff6

Browse files
authored
Harden CALL specialization guards and cache callables (#7360)
* vm: align CALL/CALL_KW specialization core guards with CPython
* vm: keep specialization hot on misses and add heaptype getitem parity
* vm: align call-alloc/getitem cache guards and call fastpath ordering
* vm: align BINARY_OP, STORE_SUBSCR, UNPACK_SEQUENCE specialization guards
* vm: finalize unicode/subscr specialization parity and regressions
* vm: finalize specialization GC safety, tests, and cleanup
1 parent 45d8129 commit 46abff6

File tree

19 files changed

+1252
-581
lines changed

19 files changed

+1252
-581
lines changed

crates/derive-impl/src/pyclass.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1021,7 +1021,7 @@ where
10211021
.iter()
10221022
.any(|arg| matches!(arg, syn::FnArg::Receiver(_)));
10231023
let drop_first_typed = match self.inner.attr_name {
1024-
AttrName::Method | AttrName::ClassMethod if !has_receiver => 1,
1024+
AttrName::Method | AttrName::ClassMethod if !has_receiver && !raw => 1,
10251025
_ => 0,
10261026
};
10271027
let call_flags = infer_native_call_flags(func.sig(), drop_first_typed);

crates/vm/src/builtins/function.rs

Lines changed: 68 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -529,6 +529,10 @@ impl PyFunction {
529529
}
530530

531531
impl Py<PyFunction> {
532+
pub(crate) fn is_optimized_for_call_specialization(&self) -> bool {
533+
self.code.flags.contains(bytecode::CodeFlags::OPTIMIZED)
534+
}
535+
532536
pub fn invoke_with_locals(
533537
&self,
534538
func_args: FuncArgs,
@@ -636,43 +640,90 @@ impl Py<PyFunction> {
636640
new_v
637641
}
638642

643+
/// function_kind(SIMPLE_FUNCTION) equivalent for CALL specialization.
644+
/// Returns true if: CO_OPTIMIZED, no VARARGS, no VARKEYWORDS, no kwonly args.
645+
pub(crate) fn is_simple_for_call_specialization(&self) -> bool {
646+
let code: &Py<PyCode> = &self.code;
647+
let flags = code.flags;
648+
flags.contains(bytecode::CodeFlags::OPTIMIZED)
649+
&& !flags.intersects(bytecode::CodeFlags::VARARGS | bytecode::CodeFlags::VARKEYWORDS)
650+
&& code.kwonlyarg_count == 0
651+
}
652+
639653
/// Check if this function is eligible for exact-args call specialization.
640-
/// Returns true if: no VARARGS, no VARKEYWORDS, no kwonly args, not generator/coroutine,
654+
/// Returns true if: CO_OPTIMIZED, no VARARGS, no VARKEYWORDS, no kwonly args,
641655
/// and effective_nargs matches co_argcount.
642656
pub(crate) fn can_specialize_call(&self, effective_nargs: u32) -> bool {
643657
let code: &Py<PyCode> = &self.code;
644658
let flags = code.flags;
645-
flags.contains(bytecode::CodeFlags::NEWLOCALS)
646-
&& !flags.intersects(
647-
bytecode::CodeFlags::VARARGS
648-
| bytecode::CodeFlags::VARKEYWORDS
649-
| bytecode::CodeFlags::GENERATOR
650-
| bytecode::CodeFlags::COROUTINE,
651-
)
659+
flags.contains(bytecode::CodeFlags::OPTIMIZED)
660+
&& !flags.intersects(bytecode::CodeFlags::VARARGS | bytecode::CodeFlags::VARKEYWORDS)
652661
&& code.kwonlyarg_count == 0
653662
&& code.arg_count == effective_nargs
654663
}
655664

665+
/// Runtime guard for CALL_*_EXACT_ARGS specialization: check only argcount.
666+
/// Other invariants are guaranteed by function versioning and specialization-time checks.
667+
#[inline]
668+
pub(crate) fn has_exact_argcount(&self, effective_nargs: u32) -> bool {
669+
self.code.arg_count == effective_nargs
670+
}
671+
672+
/// Bytes required for this function's frame on RustPython's thread datastack.
673+
/// Returns `None` for generator/coroutine code paths that do not push a
674+
/// regular datastack-backed frame in the fast call path.
675+
pub(crate) fn datastack_frame_size_bytes(&self) -> Option<usize> {
676+
let code: &Py<PyCode> = &self.code;
677+
if code
678+
.flags
679+
.intersects(bytecode::CodeFlags::GENERATOR | bytecode::CodeFlags::COROUTINE)
680+
{
681+
return None;
682+
}
683+
let nlocalsplus = code
684+
.varnames
685+
.len()
686+
.checked_add(code.cellvars.len())?
687+
.checked_add(code.freevars.len())?;
688+
let capacity = nlocalsplus.checked_add(code.max_stackdepth as usize)?;
689+
capacity.checked_mul(core::mem::size_of::<usize>())
690+
}
691+
656692
/// Fast path for calling a simple function with exact positional args.
657693
/// Skips FuncArgs allocation, prepend_arg, and fill_locals_from_args.
658-
/// Only valid when: no VARARGS, no VARKEYWORDS, no kwonlyargs, not generator/coroutine,
694+
/// Only valid when: CO_OPTIMIZED, no VARARGS, no VARKEYWORDS, no kwonlyargs,
659695
/// and nargs == co_argcount.
660696
pub fn invoke_exact_args(&self, mut args: Vec<PyObjectRef>, vm: &VirtualMachine) -> PyResult {
661697
let code: PyRef<PyCode> = (*self.code).to_owned();
662698

663699
debug_assert_eq!(args.len(), code.arg_count as usize);
664-
debug_assert!(code.flags.contains(bytecode::CodeFlags::NEWLOCALS));
665-
debug_assert!(!code.flags.intersects(
666-
bytecode::CodeFlags::VARARGS
667-
| bytecode::CodeFlags::VARKEYWORDS
668-
| bytecode::CodeFlags::GENERATOR
669-
| bytecode::CodeFlags::COROUTINE
670-
));
700+
debug_assert!(code.flags.contains(bytecode::CodeFlags::OPTIMIZED));
701+
debug_assert!(
702+
!code
703+
.flags
704+
.intersects(bytecode::CodeFlags::VARARGS | bytecode::CodeFlags::VARKEYWORDS)
705+
);
671706
debug_assert_eq!(code.kwonlyarg_count, 0);
672707

708+
// Generator/coroutine code objects are SIMPLE_FUNCTION in call
709+
// specialization classification, but their call path must still
710+
// go through invoke() to produce generator/coroutine objects.
711+
if code
712+
.flags
713+
.intersects(bytecode::CodeFlags::GENERATOR | bytecode::CodeFlags::COROUTINE)
714+
{
715+
return self.invoke(FuncArgs::from(args), vm);
716+
}
717+
718+
let locals = if code.flags.contains(bytecode::CodeFlags::NEWLOCALS) {
719+
None
720+
} else {
721+
Some(ArgMapping::from_dict_exact(self.globals.clone()))
722+
};
723+
673724
let frame = Frame::new(
674725
code.clone(),
675-
Scope::new(None, self.globals.clone()),
726+
Scope::new(locals, self.globals.clone()),
676727
self.builtins.clone(),
677728
self.closure.as_ref().map_or(&[], |c| c.as_slice()),
678729
Some(self.to_owned().into()),

crates/vm/src/builtins/list.rs

Lines changed: 17 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -286,7 +286,16 @@ impl PyList {
286286

287287
fn _setitem(&self, needle: &PyObject, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
288288
match SequenceIndex::try_from_borrowed_object(vm, needle, "list")? {
289-
SequenceIndex::Int(index) => self.borrow_vec_mut().setitem_by_index(vm, index, value),
289+
SequenceIndex::Int(index) => self
290+
.borrow_vec_mut()
291+
.setitem_by_index(vm, index, value)
292+
.map_err(|e| {
293+
if e.class().is(vm.ctx.exceptions.index_error) {
294+
vm.new_index_error("list assignment index out of range".to_owned())
295+
} else {
296+
e
297+
}
298+
}),
290299
SequenceIndex::Slice(slice) => {
291300
let sec = extract_cloned(&value, Ok, vm)?;
292301
self.borrow_vec_mut().setitem_by_slice(vm, slice, &sec)
@@ -509,6 +518,13 @@ impl AsSequence for PyList {
509518
} else {
510519
zelf.borrow_vec_mut().delitem_by_index(vm, i)
511520
}
521+
.map_err(|e| {
522+
if e.class().is(vm.ctx.exceptions.index_error) {
523+
vm.new_index_error("list assignment index out of range".to_owned())
524+
} else {
525+
e
526+
}
527+
})
512528
}),
513529
contains: atomic_func!(|seq, target, vm| {
514530
let zelf = PyList::sequence_downcast(seq);

crates/vm/src/builtins/object.rs

Lines changed: 18 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -64,19 +64,6 @@ impl Constructor for PyBaseObject {
6464
}
6565
}
6666

67-
// more or less __new__ operator
68-
// Only create dict if the class has HAS_DICT flag (i.e., __slots__ was not defined
69-
// or __dict__ is in __slots__)
70-
let dict = if cls
71-
.slots
72-
.flags
73-
.has_feature(crate::types::PyTypeFlags::HAS_DICT)
74-
{
75-
Some(vm.ctx.new_dict())
76-
} else {
77-
None
78-
};
79-
8067
// Ensure that all abstract methods are implemented before instantiating instance.
8168
if let Some(abs_methods) = cls.get_attr(identifier!(vm, __abstractmethods__))
8269
&& let Some(unimplemented_abstract_method_count) = abs_methods.length_opt(vm)
@@ -109,14 +96,29 @@ impl Constructor for PyBaseObject {
10996
}
11097
}
11198

112-
Ok(crate::PyRef::new_ref(Self, cls, dict).into())
99+
generic_alloc(cls, 0, vm)
113100
}
114101

115102
fn py_new(_cls: &Py<PyType>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<Self> {
116103
unimplemented!("use slot_new")
117104
}
118105
}
119106

107+
pub(crate) fn generic_alloc(cls: PyTypeRef, _nitems: usize, vm: &VirtualMachine) -> PyResult {
108+
// Only create dict if the class has HAS_DICT flag (i.e., __slots__ was not defined
109+
// or __dict__ is in __slots__)
110+
let dict = if cls
111+
.slots
112+
.flags
113+
.has_feature(crate::types::PyTypeFlags::HAS_DICT)
114+
{
115+
Some(vm.ctx.new_dict())
116+
} else {
117+
None
118+
};
119+
Ok(crate::PyRef::new_ref(PyBaseObject, cls, dict).into())
120+
}
121+
120122
impl Initializer for PyBaseObject {
121123
type Args = FuncArgs;
122124

@@ -561,8 +563,9 @@ pub fn object_set_dict(obj: PyObjectRef, dict: PyDictRef, vm: &VirtualMachine) -
561563
}
562564

563565
pub fn init(ctx: &'static Context) {
564-
// Manually set init slot - derive macro doesn't generate extend_slots
566+
// Manually set alloc/init slots - derive macro doesn't generate extend_slots
565567
// for trait impl that overrides #[pyslot] method
568+
ctx.types.object_type.slots.alloc.store(Some(generic_alloc));
566569
ctx.types
567570
.object_type
568571
.slots

crates/vm/src/builtins/str.rs

Lines changed: 25 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1500,14 +1500,25 @@ impl PyRef<PyStr> {
15001500
}
15011501

15021502
pub fn concat_in_place(&mut self, other: &Wtf8, vm: &VirtualMachine) {
1503-
// TODO: call [A]Rc::get_mut on the str to try to mutate the data in place
15041503
if other.is_empty() {
15051504
return;
15061505
}
15071506
let mut s = Wtf8Buf::with_capacity(self.byte_len() + other.len());
15081507
s.push_wtf8(self.as_ref());
15091508
s.push_wtf8(other);
1510-
*self = PyStr::from(s).into_ref(&vm.ctx);
1509+
if self.as_object().strong_count() == 1 {
1510+
// SAFETY: strong_count()==1 guarantees unique ownership of this PyStr.
1511+
// Mutating payload in place preserves semantics while avoiding PyObject reallocation.
1512+
unsafe {
1513+
let payload = self.payload() as *const PyStr as *mut PyStr;
1514+
(*payload).data = PyStr::from(s).data;
1515+
(*payload)
1516+
.hash
1517+
.store(hash::SENTINEL, atomic::Ordering::Relaxed);
1518+
}
1519+
} else {
1520+
*self = PyStr::from(s).into_ref(&vm.ctx);
1521+
}
15111522
}
15121523

15131524
pub fn try_into_utf8(self, vm: &VirtualMachine) -> PyResult<PyRef<PyUtf8Str>> {
@@ -1678,13 +1689,23 @@ impl ToPyObject for Wtf8Buf {
16781689

16791690
impl ToPyObject for char {
16801691
fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef {
1681-
vm.ctx.new_str(self).into()
1692+
let cp = self as u32;
1693+
if cp <= u8::MAX as u32 {
1694+
vm.ctx.latin1_char_cache[cp as usize].clone().into()
1695+
} else {
1696+
vm.ctx.new_str(self).into()
1697+
}
16821698
}
16831699
}
16841700

16851701
impl ToPyObject for CodePoint {
16861702
fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef {
1687-
vm.ctx.new_str(self).into()
1703+
let cp = self.to_u32();
1704+
if cp <= u8::MAX as u32 {
1705+
vm.ctx.latin1_char_cache[cp as usize].clone().into()
1706+
} else {
1707+
vm.ctx.new_str(self).into()
1708+
}
16881709
}
16891710
}
16901711

0 commit comments

Comments (0)