Direct inversion and linear systems solutions for small matrices #63992
Open
Labels: module: cuda, module: linear algebra, module: performance, triaged
🚀 Feature
Performance improvements to linalg on GPUs
Motivation
After noticing that one of my functions was running unusually slowly, I timed it line by line and found that a batched inversion of 2 x 2 matrices was hundreds of times slower than an 8k-point FFT on GPU. Experiments show that a 2 x 2 inversion formula implemented in Python may be hundreds of times faster on GPU than the linalg implementations of inv and solve in PyTorch 1.9.0.
Pitch
The linalg module is an active development area in PyTorch and may have many opportunities for improvement; it seems I have identified an important one.
Alternatives
Implementing this feature in a wrapper, as I did, is probably not a good idea, since it would likely hurt performance for small batches.
Additional context
I prepared a notebook that can be visualised here, where I describe the implementation with the analytic inversion formula and run the experiments, so that you can review my methodology.
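For readers without access to the notebook, here is a minimal sketch of what an analytic 2 x 2 inversion can look like in PyTorch. It is not the notebook's code: the names `inv2x2` and `solve2x2` are hypothetical, and the sketch assumes inputs of shape `(..., 2, 2)` that are non-singular (it performs no check on the determinant).

```python
import torch


def inv2x2(A: torch.Tensor) -> torch.Tensor:
    """Invert a batch of 2 x 2 matrices of shape (..., 2, 2).

    Uses the closed-form adjugate formula
        inv([[a, b], [c, d]]) = [[d, -b], [-c, a]] / (a*d - b*c).
    Assumption: every matrix in the batch is non-singular.
    """
    a, b = A[..., 0, 0], A[..., 0, 1]
    c, d = A[..., 1, 0], A[..., 1, 1]
    det = a * d - b * c
    # Build the adjugate row by row, then stack the rows into (..., 2, 2).
    row0 = torch.stack((d, -b), dim=-1)
    row1 = torch.stack((-c, a), dim=-1)
    adj = torch.stack((row0, row1), dim=-2)
    return adj / det[..., None, None]


def solve2x2(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Solve A X = B for a batch of 2 x 2 systems; B has shape (..., 2, k)."""
    return inv2x2(A) @ B
```

Everything here is elementwise arithmetic plus one batched matrix product, so it runs as a handful of large GPU kernels regardless of batch size. For ill-conditioned matrices the result may be less accurate than a pivoted factorisation.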
Just to give an idea, here is the performance of a batched solve, measured in 2 x 2 matrix inversions per microsecond.
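As a rough illustration of that metric, the sketch below times a batched operation and reports matrices processed per microsecond. It is an assumption about how such a number could be measured, not the notebook's methodology; the helper name `inversions_per_microsecond`, the batch size, and the repeat count are all hypothetical.

```python
import time

import torch


def inversions_per_microsecond(fn, A: torch.Tensor, repeats: int = 10) -> float:
    """Return how many 2 x 2 matrices of batch `A` are processed by `fn` per microsecond."""
    fn(A)  # warm-up call, so one-off setup costs are not timed
    if A.is_cuda:
        torch.cuda.synchronize()  # CUDA kernels launch asynchronously
    start = time.perf_counter()
    for _ in range(repeats):
        fn(A)
    if A.is_cuda:
        torch.cuda.synchronize()  # wait for all queued kernels to finish
    elapsed_us = (time.perf_counter() - start) * 1e6
    n_matrices = A.numel() // 4  # each matrix has 2 * 2 entries
    return repeats * n_matrices / elapsed_us


device = "cuda" if torch.cuda.is_available() else "cpu"
# Diagonally dominant batch, so every matrix is safely invertible.
A = torch.rand(10_000, 2, 2, device=device) + 3 * torch.eye(2, device=device)
rate = inversions_per_microsecond(torch.linalg.inv, A)
print(f"torch.linalg.inv on {device}: {rate:.3f} inversions per microsecond")
```

The explicit `torch.cuda.synchronize()` calls matter on GPU: without them the timer would mostly measure kernel launch overhead rather than the work itself.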
The proposed implementation is not very appealing on CPU, but it gives massive gains on the GPU backend. The only scenario where it was considerably worse than the linalg implementation was CPU with backpropagation for complex128. I don't believe this is a very important case, since GPUs are generally available when training.
The proposed method wins by a significant margin, which suggests that there may be gains for larger systems as well, but I did not test this.
cc @VitalyFedyunin @ngimel @heitorschueroff @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano