Skip to content

Commit bd1c863

Browse files
hawkinspcharris
authored andcommitted
BUG: Fix missing check for PyErr_Occurred() in _pyarray_correlate. (#28898)
When running the scipy 1.15 test suite test signal/tests/test_signaltools.py::test_lfilter_bad_object, with Python built in debug mode, we see the following error: ``` Fatal Python error: _Py_CheckSlotResult: Slot * of type float succeeded with an exception set ``` `None` ends up as the first argument to `dot`, and this triggers an error from PyFloat_Multiply. Once an error has occurred, we must avoid calling multiply again, since it asserts that PyErr_Occurred() is false if the output is a non-error, which will fail if an error was set at entry.
1 parent 9e50659 commit bd1c863

1 file changed

Lines changed: 8 additions & 2 deletions

File tree

numpy/_core/src/multiarray/multiarraymodule.c

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,6 +1215,7 @@ _pyarray_correlate(PyArrayObject *ap1, PyArrayObject *ap2, int typenum,
12151215
goto clean_ret;
12161216
}
12171217

1218+
int needs_pyapi = PyDataType_FLAGCHK(PyArray_DESCR(ret), NPY_NEEDS_PYAPI);
12181219
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ret));
12191220
is1 = PyArray_STRIDES(ap1)[0];
12201221
is2 = PyArray_STRIDES(ap2)[0];
@@ -1225,6 +1226,9 @@ _pyarray_correlate(PyArrayObject *ap1, PyArrayObject *ap2, int typenum,
12251226
n = n - n_left;
12261227
for (i = 0; i < n_left; i++) {
12271228
dot(ip1, is1, ip2, is2, op, n, ret);
1229+
if (needs_pyapi && PyErr_Occurred()) {
1230+
goto done;
1231+
}
12281232
n++;
12291233
ip2 -= is2;
12301234
op += os;
@@ -1236,19 +1240,21 @@ _pyarray_correlate(PyArrayObject *ap1, PyArrayObject *ap2, int typenum,
12361240
op += os * (n1 - n2 + 1);
12371241
}
12381242
else {
1239-
for (i = 0; i < (n1 - n2 + 1); i++) {
1243+
for (i = 0; i < (n1 - n2 + 1) && (!needs_pyapi || !PyErr_Occurred());
1244+
i++) {
12401245
dot(ip1, is1, ip2, is2, op, n, ret);
12411246
ip1 += is1;
12421247
op += os;
12431248
}
12441249
}
1245-
for (i = 0; i < n_right; i++) {
1250+
for (i = 0; i < n_right && (!needs_pyapi || !PyErr_Occurred()); i++) {
12461251
n--;
12471252
dot(ip1, is1, ip2, is2, op, n, ret);
12481253
ip1 += is1;
12491254
op += os;
12501255
}
12511256

1257+
done:
12521258
NPY_END_THREADS_DESCR(PyArray_DESCR(ret));
12531259
if (PyErr_Occurred()) {
12541260
goto clean_ret;

0 commit comments

Comments
 (0)