SVD is slow on GPU vs CPU for skinny matrices #41306

@n-gao

🐛 Bug

Performing SVD on the GPU is extremely slow, and as far as I know it is an open research question whether SVD in general gains much from running on the GPU. I therefore propose executing it on the CPU by default.

To Reproduce

CPU:

A = torch.randn(100, 10, 10).cpu()
%timeit torch.svd(A); torch.cuda.synchronize()

Result:

2.02 ms ± 71.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

CUDA:

A = torch.randn(100, 10, 10).cuda()
%timeit torch.svd(A); torch.cuda.synchronize()

Result:

143 ms ± 4.02 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Here is a figure illustrating the behavior depending on the batch_size:
[figure not preserved]
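For anyone who wants to reproduce the batch-size comparison outside IPython, here is a minimal sketch. The helper name `time_svd`, the chosen batch sizes, and the repeat count are my own assumptions, not part of the original measurement; it only times the CUDA path when a GPU is available.

```python
import time
import torch

def time_svd(batch_size, n=10, device="cpu", repeats=5):
    """Return the mean wall-clock seconds per torch.svd call (hypothetical helper)."""
    on_cuda = torch.device(device).type == "cuda"
    A = torch.randn(batch_size, n, n, device=device)
    torch.svd(A)  # warm-up call so one-time setup cost is not measured
    if on_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        torch.svd(A)
    if on_cuda:
        torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
    return (time.perf_counter() - start) / repeats

# Only benchmark CUDA if a GPU is actually present.
devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
for batch_size in (1, 10, 100):
    for device in devices:
        ms = time_svd(batch_size, device=device) * 1e3
        print(f"{device:>4}  batch_size={batch_size:<4} {ms:.3f} ms")
```

The absolute numbers will differ by hardware and PyTorch version, so this is only meant to show the trend across batch sizes.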

Expected behavior

SVD on CUDA should be at least as fast as on the CPU; otherwise there is no point in using CUDA.

Environment

  • PyTorch Version (e.g., 1.0): 1.5
  • OS (e.g., Linux): Windows 10 2004
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.7.4
  • CUDA/cuDNN version: 11.0
  • GPU models and configuration: Nvidia GTX 1050
  • Any other relevant information:

Additional context

I am aware that I could manually move the tensor to the CPU before computing the SVD, but this has several drawbacks:

  1. It is unintuitive that the CUDA version of SVD is slower than the CPU version.
  2. Functions like torch.slogdet use SVD internally during the backward pass if the tensor is singular. This can only be avoided by implementing custom gradients or by performing slogdet on the CPU despite slogdet being faster with CUDA.
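For reference, the manual workaround mentioned above could look roughly like the sketch below. The wrapper name `svd_on_cpu` is hypothetical. It does not address the second drawback, since `torch.slogdet` would still take its own path during the backward pass.

```python
import torch

def svd_on_cpu(A):
    """Run torch.svd on the CPU and move the factors back to A's device (hypothetical helper)."""
    U, S, V = torch.svd(A.cpu())
    return U.to(A.device), S.to(A.device), V.to(A.device)

device = "cuda" if torch.cuda.is_available() else "cpu"
A = torch.randn(100, 10, 10, device=device)
U, S, V = svd_on_cpu(A)

# Sanity check: torch.svd returns factors with A = U diag(S) V^T.
recon = U @ torch.diag_embed(S) @ V.transpose(-2, -1)
print(torch.allclose(recon, A, atol=1e-4))
```

Device transfers are differentiable in autograd, so gradients should still flow through this wrapper, at the cost of two host/device copies per call.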

cc @vincentqb @vishwakftw @ssnl @jianyuh @VitalyFedyunin @ngimel


    Labels

    module: linear algebra, module: performance, triaged
