Skip to content

Commit 7ff9863

Browse files
authored
Merge pull request #30736 from charris/backport-30667
BUG: fix thread safety of `array_getbuffer` (#30667)
2 parents 18bdb2e + 431fffb commit 7ff9863

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

numpy/_core/src/multiarray/buffer.c

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -793,8 +793,10 @@ array_getbuffer(PyObject *obj, Py_buffer *view, int flags)
793793
}
794794

795795
/* Fill in information (and add it to _buffer_info if necessary) */
796+
Py_BEGIN_CRITICAL_SECTION(self);
796797
info = _buffer_get_info(
797798
&((PyArrayObject_fields *)self)->_buffer_info, obj, flags);
799+
Py_END_CRITICAL_SECTION();
798800
if (info == NULL) {
799801
goto fail;
800802
}
@@ -880,7 +882,10 @@ void_getbuffer(PyObject *self, Py_buffer *view, int flags)
880882
* to find the correct format. This format must also be stored, since
881883
* at least in theory it can change (in practice it should never change).
882884
*/
883-
_buffer_info_t *info = _buffer_get_info(&scalar->_buffer_info, self, flags);
885+
_buffer_info_t *info = NULL;
886+
Py_BEGIN_CRITICAL_SECTION(scalar);
887+
info = _buffer_get_info(&scalar->_buffer_info, self, flags);
888+
Py_END_CRITICAL_SECTION();
884889
if (info == NULL) {
885890
Py_DECREF(self);
886891
return -1;

numpy/_core/tests/test_multithreading.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import concurrent.futures
2+
import sys
23
import threading
34

45
import pytest
@@ -375,3 +376,31 @@ def replace_list_items(b):
375376
finally:
376377
if len(tasks) < 5:
377378
b.abort()
379+
380+
@pytest.mark.skipif(sys.version_info < (3, 12), reason="Python >= 3.12 required")
381+
def test_array__buffer__thread_safety():
382+
import inspect
383+
arr = np.arange(1000)
384+
flags = [inspect.BufferFlags.STRIDED, inspect.BufferFlags.READ]
385+
386+
def func(b):
387+
b.wait()
388+
for i in range(100):
389+
arr.__buffer__(flags[i % 2])
390+
391+
run_threaded(func, max_workers=8, pass_barrier=True)
392+
393+
@pytest.mark.skipif(sys.version_info < (3, 12), reason="Python >= 3.12 required")
394+
def test_void_dtype__buffer__thread_safety():
395+
import inspect
396+
dt = np.dtype([('name', np.str_, 16), ('grades', np.float64, (2,))])
397+
x = np.array(('ndarray_scalar', (1.2, 3.0)), dtype=dt)[()]
398+
assert isinstance(x, np.void)
399+
flags = [inspect.BufferFlags.STRIDES, inspect.BufferFlags.READ]
400+
401+
def func(b):
402+
b.wait()
403+
for i in range(100):
404+
x.__buffer__(flags[i % 2])
405+
406+
run_threaded(func, max_workers=8, pass_barrier=True)

0 commit comments

Comments
 (0)