Linear algebra GPU backend tracking issue [MAGMA/cuSOLVER/cuBLAS] #47953
Currently, most GPU linear algebra operators use MAGMA as their backend, and only a few use cuSOLVER/cuBLAS instead. To improve performance, we would like to migrate poorly performing MAGMA-backed linear algebra operators to cuSOLVER/cuBLAS where those backends perform better.
This issue tracks which linear algebra operators currently do not use MAGMA as their default GPU backend, along with a list of known poorly performing MAGMA operators that could benefit from cuSOLVER/cuBLAS. Feel free to modify this list and link to this issue if you are aware of any such operators.
We welcome contributions that add cuSOLVER/cuBLAS backends for poorly performing MAGMA operators. Please make sure you add benchmarks to your PR, and add heuristics that dispatch the operator to different backends where necessary.
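For example, a minimal benchmark sketch using `torch.utils.benchmark`, assuming a CUDA build of PyTorch; the shapes and the choice of `torch.linalg.cholesky` are illustrative only:

```python
# Minimal benchmark sketch; shapes and the benchmarked op are illustrative.
import torch
import torch.utils.benchmark as benchmark

device = "cuda"
for b, n in [(1, 512), (8, 512), (256, 64)]:  # (batch, matrix size) examples
    # Build a batch of symmetric positive-definite matrices.
    a = torch.randn(b, n, n, device=device)
    a = a @ a.transpose(-2, -1) + n * torch.eye(n, device=device)
    timer = benchmark.Timer(
        stmt="torch.linalg.cholesky(a)",  # torch is auto-imported by Timer
        globals={"a": a},
        description=f"b={b}, n={n}",
    )
    # blocked_autorange handles CUDA synchronization and picks a sample count.
    print(timer.blocked_autorange(min_run_time=1.0))
```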
(This issue doesn't track CPU or other backends.)
CUDA version requirement for cuSOLVER/cuBLAS
cuSOLVER/cuBLAS backends are only enabled when the CUDA version is >= 10.1.243 (#45452). There is no limitation on GPU architectures.
If your CUDA version is lower than that, everything is dispatched to MAGMA. If MAGMA is not linked in your build, you will get a runtime error when calling these linear algebra operators on GPU.
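A small sketch of a runtime guard, assuming only that `torch.version.cuda` reports the toolkit's major.minor version (the .243 patch level is not exposed there, so this check is necessarily approximate):

```python
# Approximate runtime check against the CUDA >= 10.1.243 requirement.
# torch.version.cuda is a string like "11.1" (or None for CPU-only builds),
# so only the major/minor components can be compared here.
import torch

def cusolver_cublas_possible(minimum=(10, 1)):
    if torch.version.cuda is None:  # CPU-only build
        return False
    major, minor = (int(x) for x in torch.version.cuda.split(".")[:2])
    return (major, minor) >= minimum

print("cuSOLVER/cuBLAS backends possible:", cusolver_cublas_possible())
```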
Operators that currently use non-MAGMA backends
For simplicity, we use `b` for batch size and `m`, `n` for matrix size. A two-dimensional tensor is considered a matrix with batch size 1. Unless noted otherwise, the `b == 1` cases cover both a 2d tensor and a >=3d tensor whose batch dimension equals 1. Also, by default most `torch.linalg.x` operators share the same backend as the corresponding `torch.x` linear algebra operator. (A sketch of two of these heuristics follows the list.)

- `torch.inverse`, `torch.linalg.inv_ex`: `b <= 2`
- `torch.svd`: `if (m <= 32 && n <= 32 && b > 1 && (!some || m == n)) gesvdjBatched; else gesvdj;`
- `torch.cholesky`, `torch.linalg.cholesky_ex`: the `b > 1` case uses cuSOLVER only when CUDA >= 11.3
- `torch.cholesky_solve`: `b == 1`
- `torch.cholesky_inverse`: `b == 1`; uses `cholesky_solve` as the backend
- `torch.orgqr`
- `torch.ormqr`
- `torch.geqrf`: `if (n <= 256 && b >= max(2, n / 16)) cublas_batched; else cusolver_looped`
- `torch.linalg.qr`: uses `geqrf` + `orgqr` as the backend
- `torch.linalg.eigh`
- `torch.lu_solve`: `(b == 1 && n > 512) || (b > 2 && n <= 128)`, where `b` and `n` are the tensor sizes of `LU_data` or matrix "A"
- `torch.lstsq`: uses `geqrf`, `ormqr`, and `triangular_solve`

Last updated at fe4ded0, June 29th, 2021.
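To make the shape-based conditions above concrete, here is a minimal Python sketch of two of the listed heuristics (`torch.geqrf` and `torch.lu_solve`). The actual dispatch logic lives in PyTorch's C++ sources; the function names and return strings here are hypothetical, for illustration only.

```python
# Hypothetical re-expression of two dispatch heuristics from the list above.
# The real implementations are in PyTorch's C++ code; names are illustrative.

def geqrf_backend(b: int, n: int) -> str:
    # Batched cuBLAS wins for many small matrices; otherwise loop over the
    # batch with cuSOLVER.
    if n <= 256 and b >= max(2, n // 16):
        return "cublas_batched"
    return "cusolver_looped"

def lu_solve_backend(b: int, n: int) -> str:
    # cuSOLVER for one large problem or for many small ones; MAGMA otherwise.
    if (b == 1 and n > 512) or (b > 2 and n <= 128):
        return "cusolver"
    return "magma"

assert geqrf_backend(b=64, n=32) == "cublas_batched"
assert lu_solve_backend(b=1, n=1024) == "cusolver"
```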
PyTorch 1.9 linear algebra development plan
See #47953 (comment)
For a detailed explanation of the MAGMA mechanism
See #47953 (comment)
cc @ezyang @gchanan @zou3519 @bdhirsh @ngimel @vishwakftw @jianyuh @nikitaved @pearu @mruberry @heitorschueroff @walterddr @VitalyFedyunin @ptrblck @IvanYashchuk