Skip to content

Commit 5defc42

Browse files
committed
Replace PyFunction.code PyMutex with PyAtomicRef for lock-free reads
Change the code field from PyMutex<PyRef<PyCode>> to PyAtomicRef<PyCode>, eliminating mutex lock/unlock on every function call. The setter uses swap_to_temporary_refs for safe deferred drop of the old code object.
1 parent 0a6a6f8 commit 5defc42

File tree

2 files changed

+15
-18
lines changed

2 files changed

+15
-18
lines changed

crates/vm/src/builtins/function.rs

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use super::{
99
use crate::common::lock::OnceCell;
1010
use crate::common::lock::PyMutex;
1111
use crate::function::ArgMapping;
12-
use crate::object::{Traverse, TraverseFn};
12+
use crate::object::{PyAtomicRef, Traverse, TraverseFn};
1313
use crate::{
1414
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
1515
bytecode,
@@ -61,7 +61,7 @@ fn format_missing_args(
6161
#[pyclass(module = false, name = "function", traverse = "manual")]
6262
#[derive(Debug)]
6363
pub struct PyFunction {
64-
code: PyMutex<PyRef<PyCode>>,
64+
code: PyAtomicRef<PyCode>,
6565
globals: PyDictRef,
6666
builtins: PyObjectRef,
6767
closure: Option<PyRef<PyTuple<PyCellRef>>>,
@@ -192,7 +192,7 @@ impl PyFunction {
192192

193193
let qualname = vm.ctx.new_str(code.qualname.as_str());
194194
let func = Self {
195-
code: PyMutex::new(code.clone()),
195+
code: PyAtomicRef::from(code.clone()),
196196
globals,
197197
builtins,
198198
closure: None,
@@ -217,7 +217,7 @@ impl PyFunction {
217217
func_args: FuncArgs,
218218
vm: &VirtualMachine,
219219
) -> PyResult<()> {
220-
let code = &*self.code.lock();
220+
let code: &Py<PyCode> = &self.code;
221221
let nargs = func_args.args.len();
222222
let n_expected_args = code.arg_count as usize;
223223
let total_args = code.arg_count as usize + code.kwonlyarg_count as usize;
@@ -539,13 +539,12 @@ impl Py<PyFunction> {
539539
Err(err) => info!(
540540
"jit: function `{}` is falling back to being interpreted because of the \
541541
error: {}",
542-
self.code.lock().obj_name,
543-
err
542+
self.code.obj_name, err
544543
),
545544
}
546545
}
547546

548-
let code = self.code.lock().clone();
547+
let code: PyRef<PyCode> = (*self.code).to_owned();
549548

550549
let locals = if code.flags.contains(bytecode::CodeFlags::NEWLOCALS) {
551550
ArgMapping::from_dict_exact(vm.ctx.new_dict())
@@ -609,7 +608,7 @@ impl Py<PyFunction> {
609608
/// Returns true if: no VARARGS, no VARKEYWORDS, no kwonly args, not generator/coroutine,
610609
/// and effective_nargs matches co_argcount.
611610
pub(crate) fn can_specialize_call(&self, effective_nargs: u32) -> bool {
612-
let code = self.code.lock();
611+
let code: &Py<PyCode> = &self.code;
613612
let flags = code.flags;
614613
flags.contains(bytecode::CodeFlags::NEWLOCALS)
615614
&& !flags.intersects(
@@ -627,7 +626,7 @@ impl Py<PyFunction> {
627626
/// Only valid when: no VARARGS, no VARKEYWORDS, no kwonlyargs, not generator/coroutine,
628627
/// and nargs == co_argcount.
629628
pub fn invoke_exact_args(&self, args: &[PyObjectRef], vm: &VirtualMachine) -> PyResult {
630-
let code = self.code.lock().clone();
629+
let code: PyRef<PyCode> = (*self.code).to_owned();
631630

632631
let locals = ArgMapping::from_dict_exact(vm.ctx.new_dict());
633632

@@ -676,12 +675,12 @@ impl PyPayload for PyFunction {
676675
impl PyFunction {
677676
#[pygetset]
678677
fn __code__(&self) -> PyRef<PyCode> {
679-
self.code.lock().clone()
678+
(*self.code).to_owned()
680679
}
681680

682681
#[pygetset(setter)]
683-
fn set___code__(&self, code: PyRef<PyCode>) {
684-
*self.code.lock() = code;
682+
fn set___code__(&self, code: PyRef<PyCode>, vm: &VirtualMachine) {
683+
self.code.swap_to_temporary_refs(code, vm);
685684
self.func_version.store(0, Relaxed);
686685
}
687686

@@ -923,7 +922,7 @@ impl PyFunction {
923922
}
924923
let arg_types = jit::get_jit_arg_types(&zelf, vm)?;
925924
let ret_type = jit::jit_ret_type(&zelf, vm)?;
926-
let code = zelf.code.lock();
925+
let code: &Py<PyCode> = &zelf.code;
927926
let compiled = rustpython_jit::compile(&code.code, &arg_types, ret_type)
928927
.map_err(|err| jit::new_jit_error(err.to_string(), vm))?;
929928
let _ = zelf.jitted_code.set(compiled);

crates/vm/src/builtins/function/jit.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::{
22
AsObject, Py, PyObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine,
33
builtins::{
4-
PyBaseExceptionRef, PyDict, PyDictRef, PyFunction, PyStrInterned, bool_, float, int,
4+
PyBaseExceptionRef, PyCode, PyDict, PyDictRef, PyFunction, PyStrInterned, bool_, float, int,
55
},
66
bytecode::CodeFlags,
77
convert::ToPyObject,
@@ -67,7 +67,7 @@ fn get_jit_arg_type(dict: &Py<PyDict>, name: &str, vm: &VirtualMachine) -> PyRes
6767
}
6868

6969
pub fn get_jit_arg_types(func: &Py<PyFunction>, vm: &VirtualMachine) -> PyResult<Vec<JitType>> {
70-
let code = func.code.lock();
70+
let code: &Py<PyCode> = &func.code;
7171
let arg_names = code.arg_names();
7272

7373
if code
@@ -160,7 +160,7 @@ pub(crate) fn get_jit_args<'a>(
160160
let mut jit_args = jitted_code.args_builder();
161161
let nargs = func_args.args.len();
162162

163-
let code = func.code.lock();
163+
let code: &Py<PyCode> = &func.code;
164164
let arg_names = code.arg_names();
165165
let arg_count = code.arg_count;
166166
let posonlyarg_count = code.posonlyarg_count;
@@ -220,7 +220,5 @@ pub(crate) fn get_jit_args<'a>(
220220
}
221221
}
222222

223-
drop(code);
224-
225223
jit_args.into_args().ok_or(ArgsError::NotAllArgsPassed)
226224
}

0 commit comments

Comments
 (0)