Skip to content

Commit 800dcd4

Browse files
committed
more threading fix
1 parent b89cca3 commit 800dcd4

File tree

2 files changed

+63
-6
lines changed

2 files changed

+63
-6
lines changed

Lib/test/test_threading.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,8 @@ def run(self):
343343
t.join()
344344
# else the thread is still running, and we have no way to kill it
345345

346+
# TODO: RUSTPYTHON: threading._start_new_thread is not exposed
347+
@unittest.skip("TODO: RUSTPYTHON; threading._start_new_thread not exposed")
346348
def test_limbo_cleanup(self):
347349
# Issue 7481: Failure to start thread should cleanup the limbo map.
348350
def fail_new_thread(*args):
@@ -742,6 +744,8 @@ def f():
742744
rc, out, err = assert_python_ok("-c", code)
743745
self.assertEqual(err, b"")
744746

747+
# TODO: RUSTPYTHON: Thread._tstate_lock is a CPython implementation detail
748+
@unittest.skip("TODO: RUSTPYTHON; Thread._tstate_lock not implemented")
745749
def test_tstate_lock(self):
746750
# Test an implementation detail of Thread objects.
747751
started = _thread.allocate_lock()

crates/vm/src/stdlib/thread.rs

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ pub(crate) mod _thread {
1919
lock_api::{RawMutex as RawMutexT, RawMutexTimed, RawReentrantMutex},
2020
};
2121
use std::thread;
22-
use thread_local::ThreadLocal;
2322

2423
// PYTHREAD_NAME: show current thread name
2524
pub const PYTHREAD_NAME: Option<&str> = {
@@ -575,23 +574,77 @@ pub(crate) mod _thread {
575574
Ok(())
576575
}
577576

577+
// Thread-local storage for cleanup guards
578+
// When a thread terminates, the guard is dropped, which triggers cleanup
579+
thread_local! {
580+
static LOCAL_GUARDS: std::cell::RefCell<Vec<LocalGuard>> = const { std::cell::RefCell::new(Vec::new()) };
581+
}
582+
583+
// Guard that removes thread-local data when dropped
584+
struct LocalGuard {
585+
local: std::sync::Weak<LocalData>,
586+
thread_id: std::thread::ThreadId,
587+
}
588+
589+
impl Drop for LocalGuard {
590+
fn drop(&mut self) {
591+
// eprintln!("[DEBUG] LocalGuard::drop called for thread {:?}", self.thread_id);
592+
if let Some(local_data) = self.local.upgrade() {
593+
local_data.data.lock().remove(&self.thread_id);
594+
}
595+
}
596+
}
597+
598+
// Shared data structure for Local
599+
struct LocalData {
600+
data: parking_lot::Mutex<std::collections::HashMap<std::thread::ThreadId, PyDictRef>>,
601+
}
602+
603+
impl std::fmt::Debug for LocalData {
604+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
605+
f.debug_struct("LocalData").finish_non_exhaustive()
606+
}
607+
}
608+
578609
#[pyattr]
579610
#[pyclass(module = "thread", name = "_local")]
580611
#[derive(Debug, PyPayload)]
581612
struct Local {
582-
data: ThreadLocal<PyDictRef>,
613+
inner: std::sync::Arc<LocalData>,
583614
}
584615

585616
#[pyclass(with(GetAttr, SetAttr), flags(BASETYPE))]
586617
impl Local {
587618
fn l_dict(&self, vm: &VirtualMachine) -> PyDictRef {
588-
self.data.get_or(|| vm.ctx.new_dict()).clone()
619+
let thread_id = std::thread::current().id();
620+
let mut data = self.inner.data.lock();
621+
622+
if let Some(dict) = data.get(&thread_id) {
623+
return dict.clone();
624+
}
625+
626+
// Create new dict for this thread
627+
let dict = vm.ctx.new_dict();
628+
data.insert(thread_id, dict.clone());
629+
630+
// Register cleanup guard for this thread
631+
let guard = LocalGuard {
632+
local: std::sync::Arc::downgrade(&self.inner),
633+
thread_id,
634+
};
635+
LOCAL_GUARDS.with(|guards| {
636+
guards.borrow_mut().push(guard);
637+
});
638+
639+
dict
589640
}
590641

591642
#[pyslot]
592643
fn slot_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult {
593644
Self {
594-
data: ThreadLocal::new(),
645+
inner: std::sync::Arc::new(LocalData {
646+
data: parking_lot::Mutex::new(std::collections::HashMap::new()),
647+
}),
595648
}
596649
.into_ref_with_type(vm, cls)
597650
.map(Into::into)
@@ -706,8 +759,8 @@ pub(crate) mod _thread {
706759
}
707760

708761
#[pymethod]
709-
fn join(&self, timeout: OptionalArg<f64>, vm: &VirtualMachine) -> PyResult<()> {
710-
let _timeout_val = timeout.unwrap_or(-1.0);
762+
fn join(&self, timeout: OptionalArg<Option<f64>>, vm: &VirtualMachine) -> PyResult<()> {
763+
let _timeout_val = timeout.flatten().unwrap_or(-1.0);
711764

712765
// Check for self-join and if already joined
713766
let join_handle = {

0 commit comments

Comments
 (0)