Batched linear system of equations solver (torch.bgesv)#4502
zou3519 wants to merge 7 commits into pytorch:master
Conversation
cc @fritzo
apaszke left a comment:
That looks great, thanks for wrapping this up so quickly! Only a few minor comments.
magma_queue_t magma_queue;
magma_queue_create_from_cuda(
    THError("MAGMA bgesv (gesv_batched) : For batch number %lld: U(%d,%d) is zero, singular U.",
            (long long)batch_count, info, info);
  }
}
  THTensor_(resizeNd)(self, 3, size, stride);
  THTensor_(copy)(self, src);
  return self;
}
x_exp, LU_exp = torch.gesv(b.squeeze(0), A.squeeze(0))
x, LU = torch.bgesv(b, A)
self.assertEqual(x, x_exp.unsqueeze(0))
self.assertEqual(LU, LU_exp.unsqueeze(0))
- Refactored THCTensor_(newBatchedColumnMajor) and THTensor_(cloneBatchedColumnMajor)
- In THCTensor_(bgesv), moved error checking to after freeing resources
- Added a test checking bgesv against gesv (in a loop) for a batch of size 4
A few high-level comments: the equivalent NumPy function is np.linalg.solve.
You don't have to block the PR on these suggestions, but if it's not too much extra work, see if you can merge gesv and bgesv so that we don't unnecessarily expand the public API.
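For context on the merged-API suggestion, np.linalg.solve already handles both the single-matrix and batched cases in one function by broadcasting over leading dimensions. A minimal sketch of that behavior:

```python
import numpy as np

# np.linalg.solve broadcasts over leading batch dimensions when both
# operands are stacks of matrices: A is (batch, n, n), b is (batch, n, k).
rng = np.random.default_rng(0)
A = rng.standard_normal((4, 3, 3)) + 3.0 * np.eye(3)  # keep each system well-conditioned
b = rng.standard_normal((4, 3, 1))

x = np.linalg.solve(A, b)  # all 4 systems solved in one call
x_loop = np.stack([np.linalg.solve(A[i], b[i]) for i in range(4)])

assert np.allclose(x, x_loop)
```

This is the API shape the review is suggesting: one solve function whose batched behavior falls out of the input dimensions, rather than a separate bgesv entry point.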
Closing this in favor of #4612
Fixes pytorch#3164 Picks up from pytorch#4502 I moved `gesv` to ATen. Adds bindings for MAGMA's `gesv_batched` function for CUDA. For CPU, runs `THLapack(gesv)` in a for loop. The new function supports arbitrary batch dimensions (and broadcasting of those dimensions). For example, the 4-d tensor `A x B x M x M` should be treated as having batch-size `(A x B)`. The overhead of creating the magma_queue_t is: ~350000 microseconds the first time it's called and ~6 microseconds every time after that.
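The "arbitrary batch dimensions" handling described above can be sketched in NumPy terms: flatten all leading dimensions into a single batch axis, solve each M x M system (here with a per-system loop, like the CPU fallback), and restore the original shape. `batched_solve` is a hypothetical illustration, not the actual ATen code:

```python
import numpy as np

def batched_solve(A, b):
    """Hypothetical sketch: treat an (..., M, M) stack as one flat batch,
    solve each M x M system in a loop (CPU-style fallback), reshape back."""
    batch_shape = A.shape[:-2]
    m, k = A.shape[-1], b.shape[-1]
    A_flat = A.reshape(-1, m, m)
    b_flat = b.reshape(-1, m, k)
    x_flat = np.stack([np.linalg.solve(Ai, bi) for Ai, bi in zip(A_flat, b_flat)])
    return x_flat.reshape(*batch_shape, m, k)

# e.g. a 4-d (A x B x M x M) tensor is treated as having batch size A*B
rng = np.random.default_rng(1)
A = rng.standard_normal((2, 3, 4, 4)) + 4.0 * np.eye(4)  # well-conditioned systems
b = rng.standard_normal((2, 3, 4, 1))
x = batched_solve(A, b)
assert x.shape == (2, 3, 4, 1)
assert np.allclose(np.matmul(A, x), b)  # each system is actually solved
```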
* Add batched linear solver to torch.gesv() (fixes #3164; picks up from #4502)
* Tests and docs
* Address comments (several rounds)
* Rebase
* Fix rebase
Implements torch.bgesv, a batched linear system of equations solver.
Adds bindings for MAGMA's gesv_batched function for CUDA.
For CPU, runs THLapack(gesv) in a for loop.
I decided not to build this into torch.gesv, but if we want to I can change that. cc @apaszke
Test Plan
New unit tests:
torch.gesv and specifying outputs