Skip to content

Commit 70169e8

Browse files
committed
Fix thread-safety in GC, type cache, and instruction cache
GC / refcount:
- Add safe_inc() check for strong()==0 in RefCount
- Add try_to_owned() to PyObject for atomic refcount acquire
- Replace strong_count()+to_owned() with try_to_owned() in GC collection and weakref callback paths to prevent TOCTOU races

Type cache:
- Add proper SeqLock (sequence counter) to TypeCacheEntry
- Readers spin-wait on odd sequence, validate after read
- Writers bracket updates with begin_write/end_write
- Use try_to_owned + pointer revalidation on read path
- Call modified() BEFORE attribute modification in set_attr

Instruction cache:
- Add pointer_cache (AtomicUsize array) to CodeUnits for single atomic pointer load/store (prevents torn reads)
- Add try_read_cached_descriptor with try_to_owned + pointer and version revalidation after increment
- Add write_cached_descriptor with version-bracketed writes

RLock:
- Fix release() to check is_owned_by_current_thread
- Add _release_save/_acquire_restore methods
1 parent 5ab631d commit 70169e8

File tree

9 files changed

+309
-169
lines changed

9 files changed

+309
-169
lines changed

Lib/test/_test_multiprocessing.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4815,8 +4815,6 @@ def test_finalize(self):
48154815
result = [obj for obj in iter(conn.recv, 'STOP')]
48164816
self.assertEqual(result, ['a', 'b', 'd10', 'd03', 'd02', 'd01', 'e'])
48174817

4818-
# TODO: RUSTPYTHON - gc.get_threshold() and gc.set_threshold() not implemented
4819-
@unittest.expectedFailure
48204818
@support.requires_resource('cpu')
48214819
def test_thread_safety(self):
48224820
# bpo-24484: _run_finalizers() should be thread-safe

crates/common/src/refcount.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ impl RefCount {
123123
pub fn safe_inc(&self) -> bool {
124124
let mut old = State::from_raw(self.state.load(Ordering::Relaxed));
125125
loop {
126-
if old.destructed() {
126+
if old.destructed() || old.strong() == 0 {
127127
return false;
128128
}
129129
if (old.strong() as usize) >= STRONG {

crates/compiler-core/src/bytecode.rs

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use core::{
1212
cell::UnsafeCell,
1313
hash, mem,
1414
ops::Deref,
15-
sync::atomic::{AtomicU8, AtomicU16, Ordering},
15+
sync::atomic::{AtomicU8, AtomicU16, AtomicUsize, Ordering},
1616
};
1717
use itertools::Itertools;
1818
use malachite_bigint::BigInt;
@@ -411,6 +411,10 @@ impl TryFrom<&[u8]> for CodeUnit {
411411
pub struct CodeUnits {
412412
units: UnsafeCell<Box<[CodeUnit]>>,
413413
adaptive_counters: Box<[AtomicU16]>,
414+
/// Pointer-sized cache entries for descriptor pointers.
415+
/// Single atomic load/store prevents torn reads when multiple threads
416+
/// specialize the same instruction concurrently.
417+
pointer_cache: Box<[AtomicUsize]>,
414418
}
415419

416420
// SAFETY: All cache operations use atomic read/write instructions.
@@ -432,9 +436,15 @@ impl Clone for CodeUnits {
432436
.iter()
433437
.map(|c| AtomicU16::new(c.load(Ordering::Relaxed)))
434438
.collect();
439+
let pointer_cache = self
440+
.pointer_cache
441+
.iter()
442+
.map(|c| AtomicUsize::new(c.load(Ordering::Relaxed)))
443+
.collect();
435444
Self {
436445
units: UnsafeCell::new(units),
437446
adaptive_counters,
447+
pointer_cache,
438448
}
439449
}
440450
}
@@ -472,13 +482,19 @@ impl<const N: usize> From<[CodeUnit; N]> for CodeUnits {
472482
impl From<Vec<CodeUnit>> for CodeUnits {
473483
fn from(value: Vec<CodeUnit>) -> Self {
474484
let units = value.into_boxed_slice();
475-
let adaptive_counters = (0..units.len())
485+
let len = units.len();
486+
let adaptive_counters = (0..len)
476487
.map(|_| AtomicU16::new(0))
477488
.collect::<Vec<_>>()
478489
.into_boxed_slice();
490+
let pointer_cache = (0..len)
491+
.map(|_| AtomicUsize::new(0))
492+
.collect::<Vec<_>>()
493+
.into_boxed_slice();
479494
Self {
480495
units: UnsafeCell::new(units),
481496
adaptive_counters,
497+
pointer_cache,
482498
}
483499
}
484500
}
@@ -600,25 +616,25 @@ impl CodeUnits {
600616
lo | (hi << 16)
601617
}
602618

603-
/// Write a u64 value across four consecutive CACHE code units starting at `index`.
619+
/// Store a pointer-sized value atomically in the pointer cache at `index`.
620+
///
621+
/// Uses a single `AtomicUsize` store to prevent torn writes when
622+
/// multiple threads specialize the same instruction concurrently.
604623
///
605624
/// # Safety
606-
/// Same requirements as `write_cache_u16`.
607-
pub unsafe fn write_cache_u64(&self, index: usize, value: u64) {
608-
unsafe {
609-
self.write_cache_u32(index, value as u32);
610-
self.write_cache_u32(index + 2, (value >> 32) as u32);
611-
}
625+
/// - `index` must be in bounds.
626+
pub unsafe fn write_cache_ptr(&self, index: usize, value: usize) {
627+
self.pointer_cache[index].store(value, Ordering::Relaxed);
612628
}
613629

614-
/// Read a u64 value from four consecutive CACHE code units starting at `index`.
630+
/// Load a pointer-sized value atomically from the pointer cache at `index`.
631+
///
632+
/// Uses a single `AtomicUsize` load to prevent torn reads.
615633
///
616634
/// # Panics
617-
/// Panics if `index + 3` is out of bounds.
618-
pub fn read_cache_u64(&self, index: usize) -> u64 {
619-
let lo = self.read_cache_u32(index) as u64;
620-
let hi = self.read_cache_u32(index + 2) as u64;
621-
lo | (hi << 32)
635+
/// Panics if `index` is out of bounds.
636+
pub fn read_cache_ptr(&self, index: usize) -> usize {
637+
self.pointer_cache[index].load(Ordering::Relaxed)
622638
}
623639

624640
/// Read adaptive counter bits for instruction at `index`.

crates/vm/src/builtins/type.rs

Lines changed: 93 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,18 @@ static NEXT_TYPE_VERSION: AtomicU32 = AtomicU32::new(1);
6464
// Method cache (type_cache / MCACHE): direct-mapped cache keyed by
6565
// (tp_version_tag, interned_name_ptr).
6666
//
67-
// Uses a lock-free SeqLock pattern:
68-
// - version acts as both cache key AND sequence counter
69-
// - Read: load version (Acquire), read value ptr, re-check version
70-
// - Write: set version=0 (invalidate), store value, store version (Release)
67+
// Uses a lock-free SeqLock pattern for the read/write protocol:
68+
// - Readers validate sequence/version/name before and after the value read.
69+
// - Writers bracket updates with sequence odd/even transitions.
7170
// No mutex needed on the hot path (cache hit).
7271

7372
const TYPE_CACHE_SIZE_EXP: u32 = 12;
7473
const TYPE_CACHE_SIZE: usize = 1 << TYPE_CACHE_SIZE_EXP;
7574
const TYPE_CACHE_MASK: usize = TYPE_CACHE_SIZE - 1;
7675

7776
struct TypeCacheEntry {
77+
/// Sequence lock (odd = write in progress, even = quiescent).
78+
sequence: AtomicU32,
7879
/// tp_version_tag at cache time. 0 = empty/invalid.
7980
version: AtomicU32,
8081
/// Interned attribute name pointer (pointer equality check).
@@ -94,12 +95,39 @@ unsafe impl Sync for TypeCacheEntry {}
9495
impl TypeCacheEntry {
9596
fn new() -> Self {
9697
Self {
98+
sequence: AtomicU32::new(0),
9799
version: AtomicU32::new(0),
98100
name: AtomicPtr::new(core::ptr::null_mut()),
99101
value: AtomicPtr::new(core::ptr::null_mut()),
100102
}
101103
}
102104

105+
#[inline]
106+
fn begin_write(&self) {
107+
self.sequence.fetch_add(1, Ordering::AcqRel);
108+
}
109+
110+
#[inline]
111+
fn end_write(&self) {
112+
self.sequence.fetch_add(1, Ordering::Release);
113+
}
114+
115+
#[inline]
116+
fn begin_read(&self) -> u32 {
117+
let mut sequence = self.sequence.load(Ordering::Acquire);
118+
while (sequence & 1) != 0 {
119+
core::hint::spin_loop();
120+
sequence = self.sequence.load(Ordering::Acquire);
121+
}
122+
sequence
123+
}
124+
125+
#[inline]
126+
fn end_read(&self, previous: u32) -> bool {
127+
core::sync::atomic::fence(Ordering::Acquire);
128+
self.sequence.load(Ordering::Relaxed) == previous
129+
}
130+
103131
/// Take the value out of this entry, returning the owned PyObjectRef.
104132
/// Caller must ensure no concurrent reads can observe this entry
105133
/// (version should be set to 0 first).
@@ -137,10 +165,14 @@ fn type_cache_clear_version(version: u32) {
137165
let mut to_drop = Vec::new();
138166
for entry in TYPE_CACHE.iter() {
139167
if entry.version.load(Ordering::Relaxed) == version {
140-
entry.version.store(0, Ordering::Release);
141-
if let Some(v) = entry.take_value() {
142-
to_drop.push(v);
168+
entry.begin_write();
169+
if entry.version.load(Ordering::Relaxed) == version {
170+
entry.version.store(0, Ordering::Release);
171+
if let Some(v) = entry.take_value() {
172+
to_drop.push(v);
173+
}
143174
}
175+
entry.end_write();
144176
}
145177
}
146178
drop(to_drop);
@@ -158,10 +190,12 @@ pub fn type_cache_clear() {
158190
// Invalidate all entries and collect values.
159191
let mut to_drop = Vec::new();
160192
for entry in TYPE_CACHE.iter() {
193+
entry.begin_write();
161194
entry.version.store(0, Ordering::Release);
162195
if let Some(v) = entry.take_value() {
163196
to_drop.push(v);
164197
}
198+
entry.end_write();
165199
}
166200
drop(to_drop);
167201
TYPE_CACHE_CLEARING.store(false, Ordering::Release);
@@ -701,8 +735,11 @@ impl PyType {
701735
}
702736

703737
pub fn set_attr(&self, attr_name: &'static PyStrInterned, value: PyObjectRef) {
704-
self.attributes.write().insert(attr_name, value);
738+
// Invalidate caches BEFORE modifying attributes so that cached
739+
// descriptor pointers are still alive when type_cache_clear_version
740+
// drops the cache's strong references.
705741
self.modified();
742+
self.attributes.write().insert(attr_name, value);
706743
}
707744

708745
/// Internal get_attr implementation for fast lookup on a class.
@@ -718,41 +755,44 @@ impl PyType {
718755
/// find_name_in_mro with method cache (MCACHE).
719756
/// Looks in tp_dict of types in MRO, bypasses descriptors.
720757
///
721-
/// Uses a lock-free SeqLock pattern keyed by version:
722-
/// Read: load version → check name → load value → clone → re-check version
723-
/// Write: version=0 → swap value → set name → version=assigned
758+
/// Uses a lock-free SeqLock-style pattern:
759+
/// Read: load sequence/version/name → load value + try_to_owned →
760+
/// validate value pointer + sequence
761+
/// Write: sequence(begin) → version=0 → swap value/name → version=assigned → sequence(end)
724762
fn find_name_in_mro(&self, name: &'static PyStrInterned) -> Option<PyObjectRef> {
725763
let version = self.tp_version_tag.load(Ordering::Acquire);
726764
if version != 0 {
727765
let idx = type_cache_hash(version, name);
728766
let entry = &TYPE_CACHE[idx];
729-
let v1 = entry.version.load(Ordering::Acquire);
730-
if v1 == version
731-
&& core::ptr::eq(
732-
entry.name.load(Ordering::Relaxed),
733-
name as *const _ as *mut _,
734-
)
735-
{
767+
let name_ptr = name as *const _ as *mut _;
768+
loop {
769+
let seq1 = entry.begin_read();
770+
let v1 = entry.version.load(Ordering::Acquire);
771+
let type_version = self.tp_version_tag.load(Ordering::Acquire);
772+
if v1 != type_version
773+
|| !core::ptr::eq(entry.name.load(Ordering::Relaxed), name_ptr)
774+
{
775+
break;
776+
}
736777
let ptr = entry.value.load(Ordering::Acquire);
737-
if !ptr.is_null() {
738-
// SAFETY: The value pointer was stored via PyObjectRef::into_raw
739-
// and is valid as long as the version hasn't changed. We create
740-
// a temporary reference (ManuallyDrop prevents decrement), clone
741-
// it to get our own strong reference, then re-check the version
742-
// to confirm the entry wasn't invalidated during our read.
743-
let cloned = unsafe {
744-
let tmp = core::mem::ManuallyDrop::new(PyObjectRef::from_raw(
745-
NonNull::new_unchecked(ptr),
746-
));
747-
(*tmp).clone()
748-
};
749-
// SeqLock validation: if version changed, discard our clone
750-
let v2 = entry.version.load(Ordering::Acquire);
751-
if v2 == v1 {
778+
if ptr.is_null() {
779+
if entry.end_read(seq1) {
780+
break;
781+
}
782+
continue;
783+
}
784+
// _Py_TryIncrefCompare-style validation:
785+
// safe_inc, then ensure the source pointer is unchanged.
786+
let obj: &PyObject = unsafe { &*ptr };
787+
if let Some(cloned) = obj.try_to_owned() {
788+
let same_ptr = core::ptr::eq(entry.value.load(Ordering::Relaxed), ptr);
789+
if same_ptr && entry.end_read(seq1) {
752790
return Some(cloned);
753791
}
754792
drop(cloned);
793+
continue;
755794
}
795+
break;
756796
}
757797
}
758798

@@ -777,16 +817,17 @@ impl PyType {
777817
{
778818
let idx = type_cache_hash(assigned, name);
779819
let entry = &TYPE_CACHE[idx];
820+
let name_ptr = name as *const _ as *mut _;
821+
entry.begin_write();
780822
// Invalidate first to prevent readers from seeing partial state
781823
entry.version.store(0, Ordering::Release);
782824
// Swap in new value (refcount held by cache)
783825
let new_ptr = found.clone().into_raw().as_ptr();
784826
let old_ptr = entry.value.swap(new_ptr, Ordering::Relaxed);
785-
entry
786-
.name
787-
.store(name as *const _ as *mut _, Ordering::Relaxed);
827+
entry.name.store(name_ptr, Ordering::Relaxed);
788828
// Activate entry — Release ensures value/name writes are visible
789829
entry.version.store(assigned, Ordering::Release);
830+
entry.end_write();
790831
// Drop previous occupant (its version was already invalidated)
791832
if !old_ptr.is_null() {
792833
unsafe {
@@ -832,20 +873,24 @@ impl PyType {
832873
if version != 0 {
833874
let idx = type_cache_hash(version, name);
834875
let entry = &TYPE_CACHE[idx];
835-
let v1 = entry.version.load(Ordering::Acquire);
836-
if v1 == version
837-
&& core::ptr::eq(
838-
entry.name.load(Ordering::Relaxed),
839-
name as *const _ as *mut _,
840-
)
841-
{
876+
let name_ptr = name as *const _ as *mut _;
877+
loop {
878+
let seq1 = entry.begin_read();
879+
let v1 = entry.version.load(Ordering::Acquire);
880+
let type_version = self.tp_version_tag.load(Ordering::Acquire);
881+
if v1 != type_version
882+
|| !core::ptr::eq(entry.name.load(Ordering::Relaxed), name_ptr)
883+
{
884+
break;
885+
}
842886
let ptr = entry.value.load(Ordering::Acquire);
843-
if !ptr.is_null() {
844-
let v2 = entry.version.load(Ordering::Acquire);
845-
if v2 == v1 {
887+
if entry.end_read(seq1) {
888+
if !ptr.is_null() {
846889
return true;
847890
}
891+
break;
848892
}
893+
continue;
849894
}
850895
}
851896

@@ -1498,8 +1543,8 @@ impl PyType {
14981543
PySetterValue::Assign(ref val) => {
14991544
let key = identifier!(vm, __type_params__);
15001545
self.check_set_special_type_attr(key, vm)?;
1501-
self.attributes.write().insert(key, val.clone().into());
15021546
self.modified();
1547+
self.attributes.write().insert(key, val.clone().into());
15031548
}
15041549
PySetterValue::Delete => {
15051550
// For delete, we still need to check if the type is immutable
@@ -1510,8 +1555,8 @@ impl PyType {
15101555
)));
15111556
}
15121557
let key = identifier!(vm, __type_params__);
1513-
self.attributes.write().shift_remove(&key);
15141558
self.modified();
1559+
self.attributes.write().shift_remove(&key);
15151560
}
15161561
}
15171562
Ok(())

0 commit comments

Comments
 (0)