Skip to content

Commit ec25df2

Browse files
Copilotyouknowone
andcommitted
Implement Unix SemLock and expose sem_unlink
Co-authored-by: youknowone <69878+youknowone@users.noreply.github.com>
1 parent b7fcc92 commit ec25df2

File tree

1 file changed

+69
-60
lines changed

1 file changed

+69
-60
lines changed

crates/stdlib/src/multiprocessing.rs

Lines changed: 69 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ mod _multiprocessing {
4545
#[pymodule]
4646
mod _multiprocessing {
4747
use crate::vm::{
48-
Context, FromArgs, Py, PyBaseExceptionRef, PyPayload, PyRef, PyResult, VirtualMachine,
49-
builtins::PyTypeRef,
50-
function::OptionalArg,
48+
Context, FromArgs, Py, PyPayload, PyResult, VirtualMachine,
49+
builtins::{PyBaseExceptionRef, PyType, PyTypeRef},
50+
function::{FuncArgs, OptionalArg},
5151
types::Constructor,
5252
};
5353
use libc::sem_t;
@@ -58,19 +58,21 @@ mod _multiprocessing {
5858
time::Duration,
5959
};
6060

61-
const RECURSIVE_MUTEX: i32 = 0;
61+
const RECURSIVE_MUTEX_KIND: i32 = 0;
62+
const SEMAPHORE_KIND: i32 = 1;
63+
const SEM_VALUE_MAX_CONST: i32 = 32_767;
6264

6365
#[derive(FromArgs)]
6466
struct SemLockArgs {
65-
#[pyarg(positional_only)]
67+
#[pyarg(positional)]
6668
kind: i32,
67-
#[pyarg(positional_only)]
69+
#[pyarg(positional)]
6870
value: i32,
69-
#[pyarg(positional_only)]
71+
#[pyarg(positional)]
7072
maxvalue: i32,
71-
#[pyarg(positional_only)]
73+
#[pyarg(positional)]
7274
name: String,
73-
#[pyarg(positional_only)]
75+
#[pyarg(positional)]
7476
unlink: bool,
7577
}
7678

@@ -111,12 +113,7 @@ mod _multiprocessing {
111113
) -> PyResult<(Self, Option<String>)> {
112114
let cname = semaphore_name(vm, name)?;
113115
let raw = unsafe {
114-
libc::sem_open(
115-
cname.as_ptr(),
116-
libc::O_CREAT | libc::O_EXCL,
117-
0o600,
118-
value,
119-
)
116+
libc::sem_open(cname.as_ptr(), libc::O_CREAT | libc::O_EXCL, 0o600, value)
120117
};
121118
if raw == libc::SEM_FAILED {
122119
let err = Errno::last();
@@ -158,37 +155,8 @@ mod _multiprocessing {
158155
}
159156
}
160157

161-
#[extend_class]
162-
fn extend_class(ctx: &Context, class: &'static Py<crate::vm::builtins::PyType>) {
163-
class.set_attr(
164-
"SEM_VALUE_MAX",
165-
ctx.new_int(libc::SEM_VALUE_MAX),
166-
ctx,
167-
);
168-
}
169-
170158
#[pyclass(with(Constructor))]
171159
impl SemLock {
172-
#[pyslot]
173-
fn slot_new(cls: PyTypeRef, args: SemLockArgs, vm: &VirtualMachine) -> PyResult {
174-
if args.value < 0 || args.value > args.maxvalue {
175-
return Err(vm.new_value_error("semaphore or lock value out of range".to_owned()));
176-
}
177-
let value = u32::try_from(args.value).map_err(|_| {
178-
vm.new_value_error("semaphore or lock value out of range".to_owned())
179-
})?;
180-
let (handle, name) = SemHandle::create(&args.name, value, args.unlink, vm)?;
181-
let zelf = SemLock {
182-
handle,
183-
kind: args.kind,
184-
maxvalue: args.maxvalue,
185-
name,
186-
owner: AtomicU64::new(0),
187-
count: AtomicUsize::new(0),
188-
};
189-
zelf.into_ref_with_type(vm, cls).map(Into::into)
190-
}
191-
192160
#[pygetset]
193161
fn handle(&self) -> isize {
194162
self.handle.as_ptr() as isize
@@ -223,9 +191,7 @@ mod _multiprocessing {
223191
}
224192

225193
let tid = current_thread_id();
226-
if self.kind == RECURSIVE_MUTEX
227-
&& self.owner.load(Ordering::Acquire) == tid
228-
{
194+
if self.kind == RECURSIVE_MUTEX_KIND && self.owner.load(Ordering::Acquire) == tid {
229195
self.count.fetch_add(1, Ordering::Relaxed);
230196
return Ok(true);
231197
}
@@ -254,16 +220,17 @@ mod _multiprocessing {
254220
#[pymethod]
255221
fn release(&self, vm: &VirtualMachine) -> PyResult<()> {
256222
let tid = current_thread_id();
257-
if self.kind == RECURSIVE_MUTEX && self.owner.load(Ordering::Acquire) != tid {
223+
if self.kind == RECURSIVE_MUTEX_KIND && self.owner.load(Ordering::Acquire) != tid {
258224
return Err(vm.new_value_error("cannot release un-acquired lock".to_owned()));
259225
}
260226

261-
if self.owner.load(Ordering::Acquire) == tid {
227+
let owner_tid = self.owner.load(Ordering::Acquire);
228+
if owner_tid == tid {
262229
let current = self.count.load(Ordering::Acquire);
263230
if current == 0 {
264231
return Err(vm.new_value_error("cannot release un-acquired lock".to_owned()));
265232
}
266-
if self.kind == RECURSIVE_MUTEX && current > 1 {
233+
if self.kind == RECURSIVE_MUTEX_KIND && current > 1 {
267234
self.count.store(current - 1, Ordering::Release);
268235
return Ok(());
269236
}
@@ -272,6 +239,11 @@ mod _multiprocessing {
272239
if new_val == 0 {
273240
self.owner.store(0, Ordering::Release);
274241
}
242+
} else if self.kind != RECURSIVE_MUTEX_KIND {
243+
// releasing semaphore or non-recursive lock from another thread;
244+
// drop ownership information.
245+
self.owner.store(0, Ordering::Release);
246+
self.count.store(0, Ordering::Release);
275247
}
276248

277249
let res = unsafe { libc::sem_post(self.handle.as_ptr()) };
@@ -282,8 +254,7 @@ mod _multiprocessing {
282254
Ok(())
283255
}
284256

285-
#[pymethod]
286-
#[pyclass(name = "__enter__")]
257+
#[pymethod(name = "__enter__")]
287258
fn enter(&self, vm: &VirtualMachine) -> PyResult<bool> {
288259
self.acquire(
289260
AcquireArgs {
@@ -295,14 +266,15 @@ mod _multiprocessing {
295266
}
296267

297268
#[pymethod]
298-
fn __exit__(&self, _args: crate::vm::function::FuncArgs, vm: &VirtualMachine) -> PyResult<()> {
269+
fn __exit__(&self, _args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> {
299270
self.release(vm)
300271
}
301272

273+
#[pyclassmethod]
302274
#[pymethod(name = "_rebuild")]
303275
fn rebuild(
304276
cls: PyTypeRef,
305-
handle: isize,
277+
_handle: isize,
306278
kind: i32,
307279
maxvalue: i32,
308280
name: Option<String>,
@@ -320,8 +292,6 @@ mod _multiprocessing {
320292
owner: AtomicU64::new(0),
321293
count: AtomicUsize::new(0),
322294
};
323-
// handle is unused but kept for compatibility
324-
let _ = handle;
325295
zelf.into_ref_with_type(vm, cls).map(Into::into)
326296
}
327297

@@ -362,6 +332,24 @@ mod _multiprocessing {
362332
}
363333
}
364334

335+
#[extend_class]
336+
fn extend_class(ctx: &Context, class: &Py<PyType>) {
337+
class.set_attr(
338+
ctx.interned_str("RECURSIVE_MUTEX")
339+
.expect("intern RECURSIVE_MUTEX"),
340+
ctx.new_int(RECURSIVE_MUTEX_KIND).into(),
341+
);
342+
class.set_attr(
343+
ctx.interned_str("SEMAPHORE").expect("intern SEMAPHORE"),
344+
ctx.new_int(SEMAPHORE_KIND).into(),
345+
);
346+
class.set_attr(
347+
ctx.interned_str("SEM_VALUE_MAX")
348+
.expect("intern SEM_VALUE_MAX"),
349+
ctx.new_int(SEM_VALUE_MAX_CONST).into(),
350+
);
351+
}
352+
365353
fn wait(&self, vm: &VirtualMachine) -> PyResult<()> {
366354
loop {
367355
let res = unsafe { libc::sem_wait(self.handle.as_ptr()) };
@@ -390,8 +378,7 @@ mod _multiprocessing {
390378

391379
fn wait_timeout(&self, duration: Duration, vm: &VirtualMachine) -> PyResult<bool> {
392380
let mut ts = current_timespec(vm)?;
393-
let nsec_total =
394-
ts.tv_nsec as i64 + i64::from(duration.subsec_nanos());
381+
let nsec_total = ts.tv_nsec as i64 + i64::from(duration.subsec_nanos());
395382
ts.tv_sec = ts
396383
.tv_sec
397384
.saturating_add(duration.as_secs() as libc::time_t + nsec_total / 1_000_000_000);
@@ -411,6 +398,28 @@ mod _multiprocessing {
411398
}
412399
}
413400

401+
impl Constructor for SemLock {
402+
type Args = SemLockArgs;
403+
404+
fn py_new(_cls: &Py<PyType>, args: Self::Args, vm: &VirtualMachine) -> PyResult<Self> {
405+
if args.value < 0 || args.value > args.maxvalue {
406+
return Err(vm.new_value_error("semaphore or lock value out of range".to_owned()));
407+
}
408+
let value = u32::try_from(args.value).map_err(|_| {
409+
vm.new_value_error("semaphore or lock value out of range".to_owned())
410+
})?;
411+
let (handle, name) = SemHandle::create(&args.name, value, args.unlink, vm)?;
412+
Ok(SemLock {
413+
handle,
414+
kind: args.kind,
415+
maxvalue: args.maxvalue,
416+
name,
417+
owner: AtomicU64::new(0),
418+
count: AtomicUsize::new(0),
419+
})
420+
}
421+
}
422+
414423
#[pyfunction]
415424
fn sem_unlink(name: String, vm: &VirtualMachine) -> PyResult<()> {
416425
let cname = semaphore_name(vm, &name)?;
@@ -461,11 +470,11 @@ mod _multiprocessing {
461470
};
462471
let text = msg.unwrap_or_else(|| err.desc().to_owned());
463472
vm.new_os_subtype_error(exc_type, Some(err as i32), text)
464-
.into()
473+
.upcast()
465474
}
466475

467476
fn current_thread_id() -> u64 {
468-
std::thread::current().id().as_u64().get()
477+
unsafe { libc::pthread_self() as u64 }
469478
}
470479
}
471480

0 commit comments

Comments
 (0)