torch.linalg.eigh: very slow for batched inputs #174674

@nikitaved

Description

🐛 Describe the bug

First observed in #174601 (discovered by @alexshtf)

torch.linalg.eigh on batched inputs is much slower than cupy.linalg.eigh on the same data.

cmp.py:

import torch
import cupy

def get_max_error(inp, eigvecs, eigvals):
    return (inp @ eigvecs - eigvals.unsqueeze(-2) * eigvecs).abs().max()

q, _ = torch.linalg.qr(torch.randn(500, 100, 100, device="cuda"))
inp_torch = q + q.mH
inp_cupy = cupy.array(inp_torch.cpu().numpy())

eigvals_t, eigvecs_t = torch.linalg.eigh(inp_torch)
print(f"Torch error: {get_max_error(inp_torch, eigvecs_t, eigvals_t)}")

eigvals_c, eigvecs_c = cupy.linalg.eigh(inp_cupy)
print(f"CuPy error: {get_max_error(torch.as_tensor(inp_cupy), torch.as_tensor(eigvecs_c), torch.as_tensor(eigvals_c))}")
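The residual that `get_max_error` computes (the maximum entry of `A @ V - V * diag(w)`) can also be sanity-checked on CPU with NumPy's batched `eigh`. A minimal sketch; the batch shape, seed, and tolerance below are illustrative choices of mine, not from the report:

```python
import numpy as np

def max_residual(a, eigvals, eigvecs):
    # For each matrix in the batch, A @ V should equal V scaled
    # column-wise by the eigenvalues w.
    return np.abs(a @ eigvecs - eigvals[..., None, :] * eigvecs).max()

rng = np.random.default_rng(0)
a = rng.standard_normal((8, 16, 16))
a = a + a.transpose(0, 2, 1)   # symmetrize each matrix in the batch
w, v = np.linalg.eigh(a)       # batched symmetric eigendecomposition
print(max_residual(a, w, v))   # residual should be tiny for float64
```

The same broadcasting pattern (`eigvals[..., None, :]`) is what the PyTorch snippet above expresses with `eigvals.unsqueeze(-2)`.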

When we run it:

ipython -i cmp.py 
...
Torch error: 2.8014183044433594e-05
CuPy error: 1.7702579498291016e-05

In [1]: %timeit torch.linalg.eigh(inp_torch); torch.cuda.synchronize()
914 ms ± 925 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [2]: %timeit cupy.linalg.eigh(inp_cupy); torch.cuda.synchronize()
8.36 ms ± 1.83 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

CuPy is over 100× faster here (914 ms vs. 8.36 ms for a batch of 500 matrices of size 100×100). It seems we should reconsider the heuristics used to choose among the cuSOLVER drivers.
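One knob worth trying while triaging: PyTorch exposes `torch.backends.cuda.preferred_linalg_library()` to steer ops like `eigh` between cuSOLVER and MAGMA. A hedged sketch (untested here; whether switching backends actually closes the gap depends on the build and GPU, and MAGMA must be compiled into your PyTorch):

```python
import torch

# Compare eigh across the available CUDA linalg backends. This only probes
# whether the driver/backend choice, rather than the kernel itself, is
# responsible for the slowdown; it is not a proposed fix.
if torch.cuda.is_available():
    a = torch.randn(500, 100, 100, device="cuda")
    a = a + a.mH  # symmetrize

    for backend in ("cusolver", "magma"):
        torch.backends.cuda.preferred_linalg_library(backend)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        torch.linalg.eigh(a)
        end.record()
        torch.cuda.synchronize()
        print(f"{backend}: {start.elapsed_time(end):.1f} ms")

    torch.backends.cuda.preferred_linalg_library("default")  # restore
```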

Versions

PyTorch version: 2.11.0a0+gitc84efa2
Is debug build: False
CUDA used to build PyTorch: 12.9

Python version: 3.10.19 | packaged by conda-forge | (main, Jan 26 2026, 23:45:08) [GCC 14.3.0] (64-bit runtime)

Is CUDA available: True
CUDA runtime version: 12.9.86
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3

Nvidia driver version: 580.105.08

CuPy version: 13.6.0 installed with pip

cc @ptrblck @msaroufim @eqy @jerryzh168 @tinglvv @nWEIdia @jianyuh @mruberry @walterddr @xwang233 @lezcano

Labels: bot-triaged, module: cuda, module: linear algebra, topic: performance, triaged
