
Add cusolver to build, rewrite MAGMA inverse with cusolver #42403

Closed
xwang233 wants to merge 56 commits into master from ci-all/cusolver-inverse

Conversation

xwang233 (Collaborator) commented Aug 1, 2020

Fixes #42265

This PR adds cusolver to the PyTorch build and enables the use of cusolver/cublas library functions for GPU torch.inverse on certain tensor shapes.

Specifically, when

  • the tensor is two dimensional (single batch), or
  • has >2 dimensions (multiple batches) and batch_size <= 2, or
  • magma is not linked,

cusolver/cublas will be used. In all other cases, the current MAGMA implementation will still be used.

```cpp
Tensor _inverse_helper_cuda(const Tensor& self) {
#ifdef USE_CUSOLVER
  if ((self.dim() == 2) || (/* self.dim() > 2 && */ batchCount(self) <= 2) || !use_magma_) {
    return _inverse_helper_cuda_lib(self);    // cusolver or cublas
  } else {
    return _inverse_helper_cuda_legacy(self); // magma-cuda
  }
#else
  return _inverse_helper_cuda_legacy(self);   // magma-cuda
#endif
}
```
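
For illustration, a minimal sketch of inputs that exercise each path (assuming a CUDA build of PyTorch; shapes chosen per the rules above):

```python
import torch

a = torch.randn(4, 4, device='cuda')      # 2-D, single matrix -> cusolver path
b = torch.randn(2, 8, 8, device='cuda')   # batch_size <= 2    -> cusolver/cublas path
c = torch.randn(16, 8, 8, device='cuda')  # batch_size > 2     -> MAGMA path (if linked)

inv_a, inv_b, inv_c = torch.inverse(a), torch.inverse(b), torch.inverse(c)
```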

The reason for this is that cublasXgetrfBatched and cublasXgetriBatched don't perform very well on tensors with a large batch_size. For batch_size > 1, we launch the cusolver functions on multiple streams. This lets the cusolver functions run in parallel and can greatly improve performance. For batch_size > 2, however, the parallel-launched cusolver functions are slightly slower than the current MAGMA implementation, so we still use the current MAGMA impl there.
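
Conceptually, the multi-stream launch behaves like the following Python-level sketch (an illustration with torch.cuda streams, not the actual C++ implementation; `inverse_loop_launch` is a hypothetical name):

```python
import torch

def inverse_loop_launch(x):
    # One side stream per batch element, so the single-matrix solver
    # calls can overlap on the GPU instead of serializing.
    out = torch.empty_like(x)
    streams = [torch.cuda.Stream() for _ in range(x.shape[0])]
    for i, s in enumerate(streams):
        s.wait_stream(torch.cuda.current_stream())  # order after prior work
        with torch.cuda.stream(s):
            out[i] = torch.inverse(x[i])
    for s in streams:
        torch.cuda.current_stream().wait_stream(s)  # rejoin before out is consumed
    return out
```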

On CUDA 9.2, some numerical issues were detected, so the cusolver impl will not be used there. The cusolver impl will also not be used on platforms other than NVIDIA CUDA.

```cpp
#if defined(CUDART_VERSION) && CUDART_VERSION >= 10000
// Some cusolver functions don't work well on CUDA 9.2, so cusolver is only used on CUDA >= 10.0.
#define USE_CUSOLVER
#endif
```

Note that a new heuristic is applied before the cusolver/cublas calls:

```cpp
// Heuristic:
// cublas_X_batched doesn't work very well for small batch sizes.
// cublas_X_batched is intended for matrices of small sizes, where the launch overhead is a significant factor.
// With use_loop_launch == true, we loop through all batches and launch single-matrix cusolver/cublas kernels.
// (This heuristic was originally tested on getrf + getrs (getri), and may not work well for other kernels.)
inline static bool use_loop_launch(int batch_size, int matrix_size) {
  return (batch_size <= 8) ||
         (/* batch_size > 8 && */ matrix_size >= 512);
}
```

where use_loop_launch = true means launching single-batch cusolver functions in parallel, and use_loop_launch = false means using the cublas_X_batched functions. When MAGMA is enabled (only batch_size <= 2 is dispatched to cusolver/cublas), the heuristic always returns true, and the cusolver calls are faster than the corresponding small-batch_size MAGMA calls. When MAGMA is disabled, this restores the functionality of torch.inverse, which was previously disabled for all shapes (though for large batch_size, cublas performance may not be as good as MAGMA's).
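
As a worked example, a plain-Python mirror of the predicate above (for illustration only):

```python
def use_loop_launch(batch_size: int, matrix_size: int) -> bool:
    return batch_size <= 8 or matrix_size >= 512

use_loop_launch(2, 64)     # True:  small batch -> parallel single-matrix cusolver calls
use_loop_launch(16, 64)    # False: many small matrices -> cublasXgetriBatched
use_loop_launch(16, 1024)  # True:  large matrices amortize the launch overhead
```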

Checklist:

  • Add benchmark, cpu, gpu-before (magma), gpu-after (cusolver)
  • Rewrite single inverse (ndim == 2) with cusolver
  • Rewrite batched inverse (ndim > 2) with cublas
  • Add cusolver to build
  • Clean up functions related to USE_MAGMA define guard
  • Workaround for non-cuda platform
  • Workaround for cuda 9.2
  • Add zero size check
  • Add tests

Next step:

If cusolver doesn't cause any problems in the PyTorch build, and no major performance regressions are reported after this PR is merged, I will start porting other cusolver/cublas linear algebra functions to improve performance.

benchmark 73499c6

benchmark code: https://github.com/xwang233/code-snippet/blob/master/linalg/inverse/inverse-cusolver.ipynb

shape meaning:

  • [] 2 torch.float32 -> torch.randn(2, 2, dtype=torch.float32)
  • [2] 4 torch.float32 -> torch.randn(2, 4, 4, dtype=torch.float32)
| shape | cpu_time (ms) | gpu_time_before (magma) (ms) | gpu_time_after (ms) |
| --- | --- | --- | --- |
| [] 2 torch.float32 | 0.095 | 7.534 | 0.129 |
| [] 4 torch.float32 | 0.009 | 7.522 | 0.129 |
| [] 8 torch.float32 | 0.011 | 7.647 | 0.138 |
| [] 16 torch.float32 | 0.075 | 7.582 | 0.135 |
| [] 32 torch.float32 | 0.073 | 7.573 | 0.191 |
| [] 64 torch.float32 | 0.134 | 7.694 | 0.288 |
| [] 128 torch.float32 | 0.398 | 8.073 | 0.491 |
| [] 256 torch.float32 | 1.054 | 11.860 | 1.074 |
| [] 512 torch.float32 | 5.218 | 14.130 | 2.582 |
| [] 1024 torch.float32 | 19.010 | 18.780 | 6.936 |
| [1] 2 torch.float32 | 0.009 | 0.113 | 0.128 ***regressed |
| [1] 4 torch.float32 | 0.009 | 0.113 | 0.131 ***regressed |
| [1] 8 torch.float32 | 0.011 | 0.116 | 0.129 ***regressed |
| [1] 16 torch.float32 | 0.015 | 0.122 | 0.135 ***regressed |
| [1] 32 torch.float32 | 0.032 | 0.177 | 0.178 ***regressed |
| [1] 64 torch.float32 | 0.070 | 0.420 | 0.281 |
| [1] 128 torch.float32 | 0.328 | 0.816 | 0.490 |
| [1] 256 torch.float32 | 1.125 | 1.690 | 1.084 |
| [1] 512 torch.float32 | 4.344 | 4.305 | 2.576 |
| [1] 1024 torch.float32 | 16.510 | 16.340 | 6.928 |
| [2] 2 torch.float32 | 0.009 | 0.113 | 0.186 ***regressed |
| [2] 4 torch.float32 | 0.011 | 0.115 | 0.184 ***regressed |
| [2] 8 torch.float32 | 0.012 | 0.114 | 0.184 ***regressed |
| [2] 16 torch.float32 | 0.019 | 0.119 | 0.173 ***regressed |
| [2] 32 torch.float32 | 0.050 | 0.170 | 0.240 ***regressed |
| [2] 64 torch.float32 | 0.120 | 0.429 | 0.375 |
| [2] 128 torch.float32 | 0.576 | 0.830 | 0.675 |
| [2] 256 torch.float32 | 2.021 | 1.748 | 1.451 |
| [2] 512 torch.float32 | 9.070 | 4.749 | 3.539 |
| [2] 1024 torch.float32 | 33.655 | 18.240 | 12.220 |
| [4] 2 torch.float32 | 0.009 | 0.112 | 0.318 ***regressed |
| [4] 4 torch.float32 | 0.010 | 0.115 | 0.319 ***regressed |
| [4] 8 torch.float32 | 0.013 | 0.115 | 0.320 ***regressed |
| [4] 16 torch.float32 | 0.027 | 0.120 | 0.331 ***regressed |
| [4] 32 torch.float32 | 0.085 | 0.173 | 0.385 ***regressed |
| [4] 64 torch.float32 | 0.221 | 0.431 | 0.646 ***regressed |
| [4] 128 torch.float32 | 1.102 | 0.834 | 1.055 ***regressed |
| [4] 256 torch.float32 | 4.042 | 1.811 | 2.054 ***regressed |
| [4] 512 torch.float32 | 18.390 | 4.884 | 5.087 ***regressed |
| [4] 1024 torch.float32 | 69.025 | 19.840 | 20.000 ***regressed |

dr-ci Bot commented Aug 1, 2020

💊 CI failures summary and remediations

As of commit 1a0fe5a (more details on the Dr. CI page):


None of the CI failures appear to be your fault 💚



🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch:

Since your merge base is older than viable/strict, run these commands:

```
git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD
```


vishwakftw (Contributor) left a comment


I have some preliminary comments; I can take a pass at this PR once it's complete, too.

Comment threads (5, outdated): aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
vishwakftw self-requested a review August 1, 2020 20:38
IvanYashchuk (Collaborator) left a comment


Hello @xwang233, great start with rewriting the MAGMA calls to cuSOLVER! I have some experience with cuSOLVER, so I thought I'd drop a few comments.
One small issue that is often overlooked is that the leading dimension is not necessarily the same as one of the dimensions of the matrix. The LDA parameter in BLAS and LAPACK is the stride of the matrix as it is laid out in contiguous memory.
For example, a zero-sized matrix is perfectly valid input, but the current implementation will fail with CUSOLVER_STATUS_INVALID_VALUE. For getrs that means invalid parameters were passed (n < 0, or lda < max(1, n), or ldb < max(1, n)).
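
To illustrate both points, a small hedged sketch (Python/torch standing in for the C++ code; the zero-size behavior shown is what this PR's zero-size check is meant to guarantee):

```python
import torch

# Leading dimension: a 5x5 view carved out of an 8x8 buffer keeps the
# buffer's stride, so the distance between consecutive rows/columns is 8,
# not 5. LAPACK/cuSOLVER see that distance as lda, and require lda >= max(1, n).
buf = torch.zeros(8, 8)
a = buf[:5, :5]
print(a.stride())  # (8, 1)

# A zero-sized matrix is valid input and should yield an empty inverse
# rather than CUSOLVER_STATUS_INVALID_VALUE.
x = torch.eye(0, dtype=torch.float64)
print(torch.inverse(x).shape)  # torch.Size([0, 0])
```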

Comment threads (4, outdated): aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
vishwakftw (Contributor) left a comment


Some more minor comments.

Comment threads (outdated): aten/src/ATen/native/cuda/BatchLinearAlgebra.cu (9), aten/src/ATen/native/cuda/MiscUtils.h (1)
vishwakftw (Contributor) commented

@xwang233 If you can generate some benchmark numbers to support this transition, that would be great.

xwang233 (Collaborator, Author) commented Aug 8, 2020

@vishwakftw Thanks for the review! I'll get benchmarks for cpu, gpu-before (magma), and gpu-after (cusolver) once this PR is ready.

xwang233 (Collaborator, Author) commented Aug 8, 2020

Ok, I see where the problem is. To avoid confusion, we usually name our functions differently from the library functions.

|  | magma | cublas |
| --- | --- | --- |
| library | magma_sgetri_batched | cublasSgetriBatched |
| pytorch code | magmaGetriBatched (current) | cublas_getri_batched (this PR) |

Also, the current code sometimes uses magmaGetriBatched to wrap magma_Xgetri_batched, but sometimes uses magmaLuBatched to wrap magma_Xgetrf_batched, which seems pretty confusing to me. I would suggest that we find a unified function naming scheme and keep using it in future code:

  • Use getrf or Lu; potrs or Cholesky solve
  • Use snake_case or camelCase

cc @ngimel @ptrblck

ngimel (Collaborator) left a comment


I have a small question on the test; otherwise it is good to go. Thanks, great work!

Comment thread test/test_torch.py Outdated
```python
matrix_inverse_out = torch.empty(*batches, n, n, dtype=torch.float64, device=device)
torch.inverse(matrix, out=matrix_inverse_out)
self.assertEqual(matrix_inverse_out, matrix_inverse, atol=0, rtol=0)
# second call, now that matrix_inverse_out is transposed
```
ngimel (Collaborator) commented:

I don't see matrix_inverse_out being transposed, at least prior to this PR. Do we need this second case?

xwang233 (Collaborator, Author) replied:

Yes, you are right. I tried that, and out=matrix_inverse_out doesn't seem to be transposed if matrix_inverse_out wasn't transposed before. I'll modify that and remove the second run.

facebook-github-bot (Contributor) left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

mruberry mentioned this pull request Sep 15, 2020
Comment thread aten/src/ATen/cuda/CUDABlas.cpp Outdated

```cpp
template <>
void getrfBatched<double>(
    int _m, int n, double** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize) {
```
Collaborator commented:

sorry, did not notice this earlier, but the linter complains about the unused param _m here (CUDABLAS_GETRF_ARGTYPES probably also needs to be adjusted?)

facebook-github-bot (Contributor) left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot (Contributor) commented

@ngimel merged this pull request in d75c402.

facebook-github-bot pushed a commit that referenced this pull request Nov 18, 2020
…-stream issue (#47026)

Summary:
### test_inverse_singular for cublas failure

Related
#46616 (comment)
https://app.circleci.com/pipelines/github/pytorch/pytorch/232112/workflows/4131d4ca-cd51-44e3-8e6c-b1c3555c62fa/jobs/8523970/tests

The CUDA 11.1 CI container doesn't have the MAGMA library, so the cublas matrix-inverse path is enabled.
```
Oct 27 23:13:47 -- MAGMA not found. Compiling without MAGMA support
```

test_inverse_singular was introduced in #46625, but I forgot to fix that functionality for the cublas path as well.

### cusolver inverse multi-stream failure

fix #47272

The original CUDA event record / stream blocking logic was wrong, which could cause NaNs in the output tensor.

On my machine, the original code observed a NaN within about 50k~500k loops. After this change, no NaN has been observed in more than 2.5M loops.

The performance of batch-2 matrix inverse is still the same as reported in #42403.
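
For reference, a plausible version of the corrected synchronization pattern, as a Python-level sketch (torch.cuda events standing in for the actual C++ code):

```python
import torch

main = torch.cuda.current_stream()
side = torch.cuda.Stream()
done = torch.cuda.Event()

side.wait_stream(main)         # side stream starts after prior work on main
with torch.cuda.stream(side):
    # ... enqueue per-matrix solver work on `side` ...
    done.record(side)          # record only after the work is enqueued

main.wait_event(done)          # main waits before reading the results
```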

Pull Request resolved: #47026

Reviewed By: mruberry

Differential Revision: D24838546

Pulled By: ngimel

fbshipit-source-id: 3b83e4ab8e6b47a8273cba277251765bd6d97911
@facebook-github-bot facebook-github-bot deleted the ci-all/cusolver-inverse branch January 27, 2021 18:26
xsacha pushed a commit to xsacha/pytorch that referenced this pull request Mar 31, 2021
xsacha pushed a commit to xsacha/pytorch that referenced this pull request Mar 31, 2021
xsacha pushed a commit to xsacha/pytorch that referenced this pull request Jun 21, 2021
emcastillo pushed a commit to emcastillo/pytorch that referenced this pull request Mar 16, 2022
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026

Labels

Merged · module: build (Build system issues) · module: linear algebra (Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply / matmul) · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development

Successfully merging this pull request may close these issues.

torch.inverse() performing very poorly on GPU vs CPU

10 participants