Skip to content

Commit f56a0b3

Browse files
committed
Fix GC use-after-free via atomic try_to_owned
The GC's creation of strong references had a TOCTOU race: between checking strong_count() > 0 and calling to_owned() (which calls inc()), another thread could dec() the count to 0 and proceed with deallocation. For objects without __del__, no resurrection check occurs, so the memory is freed while the GC still holds a dangling reference. The subsequent drop then accesses freed memory, corrupting malloc metadata (malloc(): unaligned tcache chunk detected). Fix this by replacing the check-then-act pattern with the CAS-based try_to_owned()/safe_inc(), which atomically verifies that the count is > 0 and increments it. Apply the same fix to WeakRefList callback collection. Also add an atomic dict snapshot for list(dict) to prevent a RuntimeError during concurrent dict iteration. Remove the expectedFailure marker from test_thread_safety, which now passes reliably.
1 parent 62766fd commit f56a0b3

File tree

5 files changed

+40
-28
lines changed

5 files changed

+40
-28
lines changed

Lib/test/_test_multiprocessing.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4815,8 +4815,6 @@ def test_finalize(self):
48154815
result = [obj for obj in iter(conn.recv, 'STOP')]
48164816
self.assertEqual(result, ['a', 'b', 'd10', 'd03', 'd02', 'd01', 'e'])
48174817

4818-
# TODO: RUSTPYTHON - gc.get_threshold() and gc.set_threshold() not implemented
4819-
@unittest.expectedFailure
48204818
@support.requires_resource('cpu')
48214819
def test_thread_safety(self):
48224820
# bpo-24484: _run_finalizers() should be thread-safe

crates/common/src/refcount.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ impl RefCount {
123123
pub fn safe_inc(&self) -> bool {
124124
let mut old = State::from_raw(self.state.load(Ordering::Relaxed));
125125
loop {
126-
if old.destructed() {
126+
if old.destructed() || old.strong() == 0 {
127127
return false;
128128
}
129129
if (old.strong() as usize) >= STRONG {

crates/vm/src/gc_state.rs

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -299,13 +299,7 @@ impl GcState {
299299
fn collect_from_list(
300300
list: &LinkedList<GcLink, PyObject>,
301301
) -> impl Iterator<Item = PyObjectRef> + '_ {
302-
list.iter().filter_map(|obj| {
303-
if obj.strong_count() > 0 {
304-
Some(obj.to_owned())
305-
} else {
306-
None
307-
}
308-
})
302+
list.iter().filter_map(|obj| obj.try_to_owned())
309303
}
310304

311305
match generation {
@@ -468,27 +462,24 @@ impl GcState {
468462
// After dropping gen_locks, other threads can untrack+free objects,
469463
// making the raw pointers in `reachable`/`unreachable` dangling.
470464
// Strong refs keep objects alive for later phases.
465+
//
466+
// Use try_to_owned() (CAS-based) instead of strong_count()+to_owned()
467+
// to prevent a TOCTOU race: another thread can dec() the count to 0
468+
// between the check and the increment, causing a use-after-free when
469+
// the destroying thread eventually frees the memory.
471470
let survivor_refs: Vec<PyObjectRef> = reachable
472471
.iter()
473472
.filter_map(|ptr| {
474473
let obj = unsafe { ptr.0.as_ref() };
475-
if obj.strong_count() > 0 {
476-
Some(obj.to_owned())
477-
} else {
478-
None
479-
}
474+
obj.try_to_owned()
480475
})
481476
.collect();
482477

483478
let unreachable_refs: Vec<crate::PyObjectRef> = unreachable
484479
.iter()
485480
.filter_map(|ptr| {
486481
let obj = unsafe { ptr.0.as_ref() };
487-
if obj.strong_count() > 0 {
488-
Some(obj.to_owned())
489-
} else {
490-
None
491-
}
482+
obj.try_to_owned()
492483
})
493484
.collect();
494485

crates/vm/src/object/core.rs

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -576,11 +576,12 @@ impl WeakRefList {
576576
ptrs.as_mut().set_next(None);
577577
}
578578

579-
// Collect callback if present and weakref is still alive
580-
if wr.0.ref_count.get() > 0 {
579+
// Collect callback only if we can still acquire a strong ref.
580+
if wr.0.ref_count.safe_inc() {
581+
let wr_ref = unsafe { PyRef::from_raw(wr as *const Py<PyWeak>) };
581582
let cb = unsafe { wr.0.payload.callback.get().replace(None) };
582583
if let Some(cb) = cb {
583-
callbacks.push((wr.to_owned(), cb));
584+
callbacks.push((wr_ref, cb));
584585
}
585586
}
586587

@@ -626,11 +627,12 @@ impl WeakRefList {
626627
ptrs.as_mut().set_next(None);
627628
}
628629

629-
// Collect callback without invoking
630-
if wr.0.ref_count.get() > 0 {
630+
// Collect callback without invoking only if we can keep weakref alive.
631+
if wr.0.ref_count.safe_inc() {
632+
let wr_ref = unsafe { PyRef::from_raw(wr as *const Py<PyWeak>) };
631633
let cb = unsafe { wr.0.payload.callback.get().replace(None) };
632634
if let Some(cb) = cb {
633-
callbacks.push((wr.to_owned(), cb));
635+
callbacks.push((wr_ref, cb));
634636
}
635637
}
636638

@@ -660,8 +662,8 @@ impl WeakRefList {
660662
let mut current = NonNull::new(self.head.load(Ordering::Relaxed));
661663
while let Some(node) = current {
662664
let wr = unsafe { node.as_ref() };
663-
if wr.0.ref_count.get() > 0 {
664-
v.push(wr.to_owned());
665+
if wr.0.ref_count.safe_inc() {
666+
v.push(unsafe { PyRef::from_raw(wr as *const Py<PyWeak>) });
665667
}
666668
current = unsafe { WeakLink::pointers(node).as_ref().get_next() };
667669
}
@@ -952,6 +954,23 @@ impl ToOwned for PyObject {
952954
}
953955
}
954956

957+
impl PyObject {
958+
/// Atomically try to create a strong reference.
959+
/// Returns `None` if the strong count is already 0 (object being destroyed).
960+
/// Uses CAS to prevent the TOCTOU race between checking strong_count and
961+
/// incrementing it.
962+
#[inline]
963+
pub fn try_to_owned(&self) -> Option<PyObjectRef> {
964+
if self.0.ref_count.safe_inc() {
965+
Some(PyObjectRef {
966+
ptr: NonNull::from(self),
967+
})
968+
} else {
969+
None
970+
}
971+
}
972+
}
973+
955974
impl PyObjectRef {
956975
#[inline(always)]
957976
pub const fn into_raw(self) -> NonNull<PyObject> {

crates/vm/src/vm/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,6 +1319,10 @@ impl VirtualMachine {
13191319
} else if cls.is(self.ctx.types.list_type) {
13201320
list_borrow = value.downcast_ref::<PyList>().unwrap().borrow_vec();
13211321
&list_borrow
1322+
} else if cls.is(self.ctx.types.dict_type) {
1323+
// Atomic snapshot of dict keys for thread-safe iteration.
1324+
let keys = value.downcast_ref::<PyDict>().unwrap().keys_vec();
1325+
return keys.into_iter().map(func).collect();
13221326
} else if cls.is(self.ctx.types.dict_keys_type) {
13231327
// Atomic snapshot of dict keys - prevents race condition during iteration
13241328
let keys = value.downcast_ref::<PyDictKeys>().unwrap().dict.keys_vec();

0 commit comments

Comments
 (0)