Add cuSOLVER path for torch.linalg.lstsq #57317
IvanYashchuk wants to merge 7 commits into gh/ivanyashchuk/29/base
This PR implements a QR-based least squares solver using the geqrf, ormqr, and triangular_solve operations. The internal code of triangular_solve was fixed to correctly handle larger rectangular arrays.
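For readers unfamiliar with the approach, here is a minimal Python sketch of a QR-based least squares solve for the full-rank case with m >= n, built from the public `torch.geqrf` and `torch.ormqr` functions. It is an illustration under stated assumptions, not the PR's C++ implementation: the helper name `lstsq_qr` is hypothetical, and `torch.linalg.solve_triangular` stands in for `triangular_solve`.

```python
import torch


def lstsq_qr(A, B):
    """Hypothetical helper: solve min ||A X - B|| for full-rank A with m >= n."""
    m, n = A.shape
    # geqrf stores R in the upper triangle of `a` and the Householder
    # reflectors that define Q below the diagonal; `tau` holds their scalars.
    a, tau = torch.geqrf(A)
    # ormqr applies Q^H to B without forming Q explicitly.
    QhB = torch.ormqr(a, tau, B, left=True, transpose=True)
    # Only the top n x n block of R and the first n rows of Q^H B are needed.
    R = a[:n, :n].triu()
    return torch.linalg.solve_triangular(R, QhB[:n], upper=True)


torch.manual_seed(0)
A = torch.randn(6, 3, dtype=torch.float64)
B = torch.randn(6, 2, dtype=torch.float64)
X = lstsq_qr(A, B)
# Reference solution from the public API, for comparison only.
X_ref = torch.linalg.lstsq(A, B).solution
```

The sketch handles a single 2-D matrix; batching and error reporting through `infos` are left out.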
@xwang233 and @lezcano and/or @nikitaved, would you review this, please?
lezcano left a comment
LGTM! The logic is as clean as it can be. I just left a small comment on a bit that I found slightly more difficult to understand.
  :attr:`driver` chooses the LAPACK/MAGMA function that will be used.
  For CPU inputs the valid values are `'gels'`, `'gelsy'`, `'gelsd`, `'gelss'`.
- For CUDA input, the only valid driver is `'gels'`, which assumes that :attr:`A` is full-rank and `m < n`.
+ For CUDA input, the only valid driver is `'gels'`, which assumes that :attr:`A` is full-rank.
    const_cast<Tensor&>(infos),
    upper, transpose, conjugate_transpose, unitriangular);

    B.narrow(-2, m, n - m).zero_();
This is because triangular_solve_kernel writes its output into the first m elements of B, right? Could you leave a comment explaining this here?
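The role of the zeroing step can be illustrated with a small Python sketch of the underdetermined case (m < n). It assumes the solver factors A^H, writes the triangular solve result into the first m rows of an n-row buffer, and then applies Q to that buffer. The helper name `lstsq_min_norm` and the explicit buffer `buf` are hypothetical; the PR's C++ code works in place on `B`, and `torch.linalg.solve_triangular` stands in for `triangular_solve`.

```python
import torch


def lstsq_min_norm(A, B):
    """Hypothetical helper: minimum-norm solution of A X = B for full-rank A, m < n."""
    m, n = A.shape
    assert m < n
    # QR of A^H (an n x m matrix): A^H = Q R, hence A = R^H Q^H.
    a, tau = torch.geqrf(A.mH.contiguous())
    R = a[:m, :m].triu()
    # Solve R^H Z = B. Z has only m rows, so it fills the first m rows
    # of an n-row buffer.
    buf = torch.zeros(n, B.shape[1], dtype=B.dtype)
    buf[:m] = torch.linalg.solve_triangular(R.mH, B, upper=False)
    # The trailing n - m rows must be zero before Q is applied; this is the
    # counterpart of B.narrow(-2, m, n - m).zero_() in the snippet above.
    return torch.ormqr(a, tau, buf, left=True, transpose=False)


torch.manual_seed(0)
A = torch.randn(3, 6, dtype=torch.float64)
B = torch.randn(3, 2, dtype=torch.float64)
X = lstsq_min_norm(A, B)
```

If the trailing rows held leftover data instead of zeros, `A @ X` would still equal `B`, but `X` would no longer be the minimum-norm solution.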
xwang233 left a comment
LGTM. Thanks for the PR!
  # cases m < n are only supported on CPU and for cuSOLVER path on CUDA
  m_l_n_sizes = [(m // 2, m) for m in ms]
- matrix_sizes = m_ge_n_sizes + (m_l_n_sizes if device == 'cpu' else [])
+ matrix_sizes = m_ge_n_sizes + (m_l_n_sizes if cusolver_available else [])
maybe use (cusolver_available or device == 'cpu') to test both?
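The suggested condition could look roughly like the sketch below. The function name `lstsq_test_sizes` and the definition of `m_ge_n_sizes` are assumptions made for illustration; only `m_l_n_sizes` and the `matrix_sizes` expression appear in the diff.

```python
def lstsq_test_sizes(ms, device, cusolver_available):
    """Hypothetical helper that builds the (m, n) shapes to test."""
    # Assumed definition: tall and square shapes, supported everywhere.
    m_ge_n_sizes = [(m, m // 2) for m in ms] + [(m, m) for m in ms]
    # Wide shapes (m < n), as in the diff.
    m_l_n_sizes = [(m // 2, m) for m in ms]
    # Wide shapes run on CPU, and on CUDA only when cuSOLVER is used.
    wide_supported = device == 'cpu' or cusolver_available
    return m_ge_n_sizes + (m_l_n_sizes if wide_supported else [])


cpu_sizes = lstsq_test_sizes([4, 8], 'cpu', cusolver_available=False)
cuda_magma_sizes = lstsq_test_sizes([4, 8], 'cuda', cusolver_available=False)
cuda_cusolver_sizes = lstsq_test_sizes([4, 8], 'cuda', cusolver_available=True)
```

With this condition the CPU run keeps its m < n coverage even on machines where cuSOLVER is unavailable.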
This PR implements a QR-based least squares solver using the geqrf, ormqr, and triangular_solve operations. The internal code of triangular_solve was fixed to correctly handle larger rectangular arrays. ghstack-source-id: e7d2246 Pull Request resolved: pytorch#57317
mruberry left a comment
Nice work all! Thanks for reviewing, @nikitaved, @xwang233
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Reverting this PR because it broke one of the Windows test jobs: https://app.circleci.com/pipelines/github/pytorch/pytorch/317376/workflows/463399f8-78ef-4894-a9bf-8b666943efc2/jobs/13217419
This pull request has been reverted by 72ebdd6.
This diff was reverted, but I think the previous commits in the stack were not. Why it was reverted: it broke pytorch_windows_vs2019_py36_cuda10.1_test2, specifically the tests test_linalg_lstsq_input_checks_cuda_complex128, test_linalg_lstsq_input_checks_cuda_complex64, test_linalg_lstsq_input_checks_cuda_float32, and test_linalg_lstsq_input_checks_cuda_float64.
Sample failure snippet:
The easiest way to reland the rest of the stack is probably to rebase the uncommitted PRs on nightly with the fix. We can also run the updated PR through ci/all to validate that this build is fixed.
@mruberry, I fixed the problem with that Windows CUDA 10.1 build. Here is the ci-all PR #57816. The problem was that the condition of cuSOLVER availability was not correct in the test. I think we should consider adding a more robust way to check from Python whether cuSOLVER is used in PyTorch. We use cuSOLVER if CUDA version is >= 10.1.243, but |
Summary: Pull Request resolved: pytorch#57317
This PR implements a QR-based least squares solver using the geqrf, ormqr, and triangular_solve operations. The internal code of triangular_solve was fixed to correctly handle larger rectangular arrays.
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D28242069
Pulled By: mruberry
fbshipit-source-id: 23979d19ccc7f591afa8df4435d0db847e2d0d97
Thanks @IvanYashchuk, and thanks for the thorough analysis. So users with a CUDA version between 10.1 and 10.1.243 will get the correct behavior (we think), but our test suite will report the behavior as incorrect?
Yes, but our test suite doesn't test the behavior for these versions, so the tests will pass.
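To make the version discussion concrete, here is a small Python sketch of a threshold check against 10.1.243. The helper names and the sample version strings are hypothetical, and the sketch does not show how PyTorch itself exposes the CUDA version to Python. It only illustrates that a version string carrying just major.minor cannot decide a threshold that depends on the patch number.

```python
# Threshold quoted in the discussion above.
CUSOLVER_MIN_CUDA = (10, 1, 243)


def parse_version(text):
    """Turn a dotted version string such as '10.1.243' into a tuple of ints."""
    return tuple(int(part) for part in text.split('.'))


def meets_cusolver_threshold(cuda_version):
    """Hypothetical check: is the given CUDA version at least 10.1.243?"""
    return parse_version(cuda_version) >= CUSOLVER_MIN_CUDA


exact = meets_cusolver_threshold('10.1.243')
older_patch = meets_cusolver_threshold('10.1.105')
# A major.minor string compares as (10, 1), which sorts below (10, 1, 243),
# so this check conservatively reports the threshold as not met.
coarse = meets_cusolver_threshold('10.1')
newer = meets_cusolver_threshold('10.2')
```

A test guard built on a coarse version string therefore has to pick one side for the whole 10.1 range, which is the gap the comments above describe.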
Summary: Pull Request resolved: pytorch#57317
This PR implements a QR-based least squares solver using the geqrf, ormqr, and triangular_solve operations. The internal code of triangular_solve was fixed to correctly handle larger rectangular arrays.
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D28312683
Pulled By: mruberry
fbshipit-source-id: dc8ae837a5fb0685d85c8733a47d7d25dc46443a