[RFC, ready] Batched Inverse #9949
Conversation
I didn't know how getri worked after using getrf, which led to some issues
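For context on the getrf/getri pair: getrf computes the LU factorization of A, and getri then reconstructs A^-1 from those factors. A minimal NumPy sketch of that two-step structure, assuming no pivoting is needed (the real getrf uses partial pivoting), with `lu_nopivot` and `inv_from_lu` as made-up helper names:

```python
import numpy as np

def lu_nopivot(A):
    # "getrf" step (simplified, no pivoting): factor A = L @ U in place.
    LU = A.astype(float).copy()
    n = LU.shape[0]
    for k in range(n - 1):
        LU[k + 1:, k] /= LU[k, k]
        LU[k + 1:, k + 1:] -= np.outer(LU[k + 1:, k], LU[k, k + 1:])
    return LU  # unit-lower L below the diagonal, U on and above it

def inv_from_lu(LU):
    # "getri" step: recover A^-1 from the LU factors by solving
    # L Y = I (forward substitution), then U X = Y (back substitution).
    n = LU.shape[0]
    L = np.tril(LU, -1) + np.eye(n)
    U = np.triu(LU)
    Y = np.linalg.solve(L, np.eye(n))
    return np.linalg.solve(U, Y)

A = np.array([[4.0, 1.0], [2.0, 3.0]])
A_inv = inv_from_lu(lu_nopivot(A))
assert np.allclose(A @ A_inv, np.eye(2))
```

The point is that getri only makes sense on the output of getrf: it consumes the packed L and U factors, not the original matrix.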
For a test: this style transfer code, https://github.com/NVIDIA/FastPhotoStyle/blob/master/photo_smooth.py#L77, batch-inverts matrices of shape HxWx3x3 in NumPy, so one can compare the performance of batched GPU inversion against NumPy.
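A hedged NumPy sketch of that workload: invert an HxW grid of 3x3 matrices in one batched call. The values of `H`, `W`, and the diagonal shift are made up here; the shift just keeps every block well-conditioned.

```python
import numpy as np

rng = np.random.default_rng(0)
H, W = 4, 5
# An (H, W, 3, 3) stack of 3x3 matrices; adding 3*I keeps them invertible.
A = rng.normal(size=(H, W, 3, 3)) + 3.0 * np.eye(3)

# np.linalg.inv broadcasts over the leading (H, W) batch dimensions.
A_inv = np.linalg.inv(A)

# Every 3x3 block times its inverse should be close to the identity.
assert np.allclose(A @ A_inv, np.eye(3), atol=1e-6)
```

This is the baseline a batched GPU `torch.inverse` would be measured against.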
Oops, I didn't notice #9102. Should I close this?
```cpp
scalar_t wkopt;
Tensor work;

for (int64_t i = 0; i < batch_size; i++) {
```
This comment was marked as off-topic.
```cpp
Tensor _inverse_helper_cpu(const Tensor& self) {
  std::vector<int64_t> getrf_infos(batchCount(self), 0);
  std::vector<int64_t> getri_infos(batchCount(self), 0);
```
This comment was marked as off-topic.
```cpp
scalar_t** self_inv_array;

ALLOCATE_ARRAY(getrf_info_array, magma_int_t, batch_size, self);
ALLOCATE_ARRAY(getri_info_array, magma_int_t, batch_size, self);
```
This comment was marked as off-topic.
```cpp
    n, n, self_array, n, ipiv_array, getrf_info_array,
    batch_size, createMagmaQueue(self));

for (int64_t i = 0; i < batch_size; i++) {
```
This comment was marked as off-topic.
```cpp
    n, self_array, n, ipiv_array, self_inv_array,
    n, getri_info_array, batch_size, createMagmaQueue(self));

for (int64_t i = 0; i < batch_size; i++) {
```
This comment was marked as off-topic.
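On the `info` loops above: LAPACK/MAGMA routines report failure per matrix through an `info` code, which is why the code iterates over `batch_size` after each call. NumPy's batched inverse handles the same situation by raising `LinAlgError` if any matrix in the stack is singular, as this sketch shows:

```python
import numpy as np

# A stack of two 2x2 matrices; the second one is deliberately singular.
A = np.stack([np.eye(2), np.zeros((2, 2))])
try:
    np.linalg.inv(A)
    raised = False
except np.linalg.LinAlgError:
    raised = True
assert raised  # the singular member poisons the whole batched call
```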
```cpp
magmaGetriBatched<scalar_t>(
    n, self_array, n, ipiv_array, self_inv_array,
    n, getri_info_array, batch_size, createMagmaQueue(self));
```
This comment was marked as off-topic.
```python
M = cast(torch.randn(5, 5))
MI = torch.inverse(M)
E = torch.eye(5)
self.assertFalse(MI.is_contiguous(), 'MI is contiguous')
```
This comment was marked as off-topic.
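The assertion in the test above expects the inverse to be non-contiguous, consistent with the LAPACK backend writing its result in column-major (Fortran) order. The same distinction can be reproduced in NumPy with a Fortran-ordered copy; the matrix here is arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5)) + 5.0 * np.eye(5)  # well-conditioned 5x5

# Mimic LAPACK's column-major output layout explicitly.
A_inv = np.asfortranarray(np.linalg.inv(A))
assert A_inv.flags['F_CONTIGUOUS'] and not A_inv.flags['C_CONTIGUOUS']
assert np.allclose(A @ A_inv, np.eye(5))
```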
Force-pushed from 9abbad0 to 9a22157.
```cpp
namespace native {

#ifdef USE_MAGMA
static magma_queue_t createMagmaQueue(const Tensor& tensor) {
```
This comment was marked as off-topic.
```cpp
}

// Because this is out-of-place inverse, the predefined macros will
// not work
```
This comment was marked as off-topic.
zou3519 left a comment:

One minor comment; otherwise, LGTM!
There is an optimization available for CUDA batched

@vishwakftw if it's a quick change, feel free to do it here. Otherwise, it'll be easier to review as a separate PR.
- use magma_*getrf_smallsq_shfl for batches of matrices with dim <= 32
- remove destroyMagmaQueue and make magma queue more RAII-like
```cpp
  return magma_queue;
}

static void destroyMagmaQueue(magma_queue_t& existing_queue) {
```
This comment was marked as off-topic.
```cpp
if (self.size(-2) <= 32) {
  magmaGetrfSmallSquareBatched<scalar_t>(
      n, self_array, n, ipiv_array, info_array,
      batch_size, createMagmaQueue(self));
```
This comment was marked as off-topic.
It seems the optimization for small-matrix getrf is done inside MAGMA's getrf function itself; I found this while inspecting the source code. I'll revert that change.
```cpp
#ifdef USE_MAGMA

// RAII for a MAGMA Queue
struct MAGMAQueue {
```
This comment was marked as off-topic.
```cpp
magmaGetriBatched<scalar_t>(
    n, self_array, n, ipiv_array, self_inv_array,
    n, info_array, batch_size, MAGMAQueue(self.get_device()));
```
This comment was marked as off-topic.
```cpp
void magmaGesvBatched(
    magma_int_t n, magma_int_t nrhs, scalar_t** dA_array, magma_int_t ldda,
    magma_int_t** dipiv_array, scalar_t** dB_array, magma_int_t lddb,
    magma_int_t* dinfo_array, magma_int_t batch_count, MAGMAQueue magma_queue) {
```
This comment was marked as off-topic.
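The `magmaGesvBatched` signature above solves `A X = B` for a whole batch of square systems via LU factorization. NumPy's `solve` broadcasts over batch dimensions with the same semantics; the shapes here are made up for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n, nrhs = 8, 3, 2
# A batch of well-conditioned 3x3 systems and 3x2 right-hand sides.
A = rng.normal(size=(batch, n, n)) + 4.0 * np.eye(n)
B = rng.normal(size=(batch, n, nrhs))

# Solves each system in the batch, like the batched gesv call.
X = np.linalg.solve(A, B)
assert np.allclose(A @ X, B, atol=1e-6)
```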
zou3519 left a comment:

Looks great now! One last comment :) and then we can merge this.
The test failures are unrelated.
```cpp
struct MAGMAQueue {

  // Default constructor, does nothing.
  MAGMAQueue() = default;
```
This comment was marked as off-topic.
facebook-github-bot left a comment:

zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:

Complete billing of changes:

Related to Batch Inverse:
- [x] Add batched inverse (CPU)
- [x] Add batched inverse (CUDA)
- [x] Modify autograd entry
- [x] Add tests
  - [x] test_autograd
  - [x] test_cuda
  - [x] test_torch
- [x] Modify docs
- [x] Remove `_batch_inverse` in `MultivariateNormal`
- [x] Allow batch matrices as inputs for negative powers in `matrix_power`

Miscellaneous modifications:
- [x] Move all batch operations to BatchLinearAlgebra.cpp/.cu and provide a general framework for adding more batch ops
- [x] Add a RAII structure for MAGMA queue management

Pull Request resolved: pytorch/pytorch#9949
Differential Revision: D10559089
Pulled By: zou3519
fbshipit-source-id: 7da24977f8a79d97dd42883302e13e708c1726e4
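One checklist item above allows negative powers of batched matrices in `matrix_power`, which rests on the identity A^(-k) = (A^-1)^k built on the batched inverse. A NumPy sketch with a hypothetical helper name `batch_matrix_power_neg`:

```python
import numpy as np

def batch_matrix_power_neg(A, k):
    # A^(-k) for a stack of matrices (k > 0): invert once, then multiply.
    A_inv = np.linalg.inv(A)
    out = np.broadcast_to(np.eye(A.shape[-1]), A.shape).copy()
    for _ in range(k):
        out = out @ A_inv
    return out

rng = np.random.default_rng(0)
A = rng.normal(size=(6, 3, 3)) + 4.0 * np.eye(3)  # well-conditioned batch
P = batch_matrix_power_neg(A, 2)
# P is A^-2, so multiplying by A twice should give the identity.
assert np.allclose(P @ A @ A, np.eye(3), atol=1e-5)
```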
```python
DTD_inv = torch.inverse(DTD + self.lambda3 * torch.eye(self.input_dim).cuda())
```

How can I solve this error?
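A guess at the failure: `torch.inverse` errors out when its input is singular (or not square). Adding `lambda3 * I`, as the snippet already does, regularizes `DTD`, so check that `lambda3 > 0` and that `DTD` is a square tensor on the same device. A NumPy sketch of the regularization effect, with `input_dim` and `lambda3` as made-up values:

```python
import numpy as np

input_dim = 4   # hypothetical value
lambda3 = 0.5   # hypothetical value
DTD = np.zeros((input_dim, input_dim))  # deliberately singular

# Without the shift, inverting DTD would fail; with it, the result is
# simply (lambda3 * I)^-1 = I / lambda3 in this degenerate case.
DTD_inv = np.linalg.inv(DTD + lambda3 * np.eye(input_dim))
assert np.allclose(DTD_inv, np.eye(input_dim) / lambda3)
```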