Skip to content

Commit a3c489f

Browse files
committed
BUG: Ensure lstsq can handle RHS with all sizes.
This fixes a bug in the creation of workspace arrays for a call to `lapack_lite.zgelsd`, which led to segmentation faults when a RHS was passed in that had larger size than the size of the matrix.
1 parent b97f9b0 commit a3c489f

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

numpy/linalg/linalg.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2021,14 +2021,8 @@ def lstsq(a, b, rcond="warn"):
20212021
work = zeros((lwork,), t)
20222022
results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond,
20232023
0, work, -1, rwork, iwork, 0)
2024-
lwork = int(abs(work[0]))
2025-
rwork = zeros((lwork,), real_t)
2026-
a_real = zeros((m, n), real_t)
2027-
bstar_real = zeros((ldb, n_rhs,), real_t)
2028-
results = lapack_lite.dgelsd(m, n, n_rhs, a_real, m,
2029-
bstar_real, ldb, s, rcond,
2030-
0, rwork, -1, iwork, 0)
20312024
lrwork = int(rwork[0])
2025+
lwork = int(work[0].real)
20322026
work = zeros((lwork,), t)
20332027
rwork = zeros((lrwork,), real_t)
20342028
results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond,

numpy/linalg/tests/test_regression.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,18 @@ def test_norm_object_array(self):
137137
assert_raises(TypeError, linalg.norm, testmatrix, ord=-2)
138138
assert_raises(ValueError, linalg.norm, testmatrix, ord=3)
139139

140+
def test_lstsq_complex_larger_rhs(self):
141+
# gh-9891
142+
size = 20
143+
n_rhs = 70
144+
G = np.random.randn(size, size) + 1j * np.random.randn(size, size)
145+
u = np.random.randn(size, n_rhs) + 1j * np.random.randn(size, n_rhs)
146+
b = G.dot(u)
147+
# This should work without segmentation fault.
148+
u_lstsq, res, rank, sv = linalg.lstsq(G, b, rcond=None)
149+
# check results just in case
150+
assert_array_almost_equal(u_lstsq, u)
151+
140152

141153
if __name__ == '__main__':
142154
run_module_suite()

0 commit comments

Comments
 (0)