torch.linalg.eigh: performance cliff at matrix size n=32 for batched inputs on CUDA #175585

@sjdu10

Description

πŸ› Describe the bug

torch.linalg.eigh suffers a catastrophic performance cliff when the matrix size crosses from n=32 to n=33 for batched inputs on CUDA. I traced the root cause to a likely outdated matrix-size gate in PyTorch's dispatch logic, which routes to the batched cuSOLVER API (cusolverDnXsyevBatched) only for n ≤ 32 and falls back to a sequential for-loop over the individual matrices for n > 32.

Minimal reproducer

import torch
torch.manual_seed(42)
device = "cuda"
B = 128
dtype = torch.float64
print(f'torch version: {torch.__version__}')
print(f'B={B}, dtype={dtype}, matrix size n')
w = 8
print(f"  {'n':>{w}} | {'eigh (ms)':>{w}} ")
print("-" * (2 * w + 9))
for n in [1,2,4,8,16,32,33,34,64,96,128]:
    x = torch.randn(B, n, n, device=device, dtype=dtype)
    a = (x + x.mT) / 2  # symmetrize: eigh expects a Hermitian input
    # one warmup call so lazy CUDA/cuSOLVER initialization stays out of the timing
    torch.linalg.eigh(a)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(5):
        eigvals, eigvecs = torch.linalg.eigh(a)
    end.record()
    torch.cuda.synchronize()
    t = start.elapsed_time(end) / 5
    marker = " <-- n=32 threshold" if n == 32 else ""
    print(f"  {n:{w}d} | {t:{w}.2f} ms{marker}")

Output (PyTorch 2.10.0+cu126, RTX 4080):

torch version: 2.10.0+cu126
B=128, dtype=torch.float64, matrix size n
         n | eigh (ms)
-------------------------
         1 |     0.12 ms
         2 |     0.33 ms
         4 |     0.20 ms
         8 |     0.30 ms
        16 |     0.62 ms
        32 |     1.65 ms <-- n=32 threshold
        33 |   137.13 ms
        34 |   139.70 ms
        64 |   228.77 ms
        96 |   457.44 ms
       128 |   580.13 ms

n=32 → n=33 jumps from 1.65 ms to 137 ms, an ~83x cliff.
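
One way to confirm the sequential fallback directly is to profile both sides of the threshold: for n=33 the trace should show the same small solver kernels repeated roughly once per matrix, while n=32 launches a single batched sequence. A rough sketch using torch.profiler (my addition, not part of the original report; event counts are approximate):

import torch
from torch.profiler import profile, ProfilerActivity

for n in (32, 33):
    x = torch.randn(128, n, n, device="cuda", dtype=torch.float64)
    a = (x + x.mT) / 2
    torch.linalg.eigh(a)  # warmup
    torch.cuda.synchronize()
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        torch.linalg.eigh(a)
        torch.cuda.synchronize()
    # With the sequential fallback, each of the 128 matrices is solved by
    # its own sequence of kernel launches, so the event count balloons.
    n_events = sum(e.count for e in prof.key_averages())
    print(f"n={n}: ~{n_events} CUDA events")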

Root cause

The dispatch lives in aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp, function linalg_eigh_cusolver (~line 1614):

void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors,
                          const Tensor& infos, bool upper, bool compute_eigenvectors) {
#if defined(USE_ROCM)
  ...
#else
  if (batchCount(eigenvectors) > 1 && eigenvectors.size(-1) <= 32) {
    // Use syevjBatched for batched matrix operation when matrix size <= 32
    linalg_eigh_cusolver_syevj_batched(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
  } else if (eigenvectors.scalar_type() == at::kFloat &&
             eigenvectors.size(-1) >= 32 && eigenvectors.size(-1) <= 512) {
    linalg_eigh_cusolver_syevj(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
  } else {
    linalg_eigh_cusolver_syevd(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
  }
#endif
}

The <= 32 gate is a leftover from when the old syevjBatched Jacobi API performed poorly for larger matrices (see PR #53040). PR #155695 (June 2025) replaced the underlying call with the newer cusolverDnXsyevBatched, which in my tests performs much better for larger matrices, but the top-level dispatch condition was never updated.
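
If I read the dispatch right, the middle branch only fires for float32, so batched float64 inputs with n > 32 always land in the sequential syevd fallback, making the cliff dtype-dependent. A quick sketch to observe this (my addition, same event-timing pattern as the reproducer):

import torch

B, n = 128, 64
for dtype in (torch.float32, torch.float64):
    x = torch.randn(B, n, n, device="cuda", dtype=dtype)
    a = (x + x.mT) / 2
    torch.linalg.eigh(a)  # warmup
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    torch.linalg.eigh(a)
    end.record()
    torch.cuda.synchronize()
    # float32 takes the syevj branch (32 <= n <= 512); float64 falls
    # through to the per-matrix syevd loop.
    print(f"{dtype}: {start.elapsed_time(end):.2f} ms")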

Proposed fix

Remove the <= 32 size gate so all batched eigh calls route through cusolverDnXsyevBatched:

  if (batchCount(eigenvectors) > 1) {
    // cusolverDnXsyevBatched works efficiently for all n when batch > 1
    linalg_eigh_cusolver_syevj_batched(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
  } else if ...

Benchmark results

I verified the fix by building PyTorch from source with only this one-line change. All benchmarks were run on an NVIDIA GeForce RTX 4080, CUDA 12.8, Ubuntu 22.04.5 LTS (WSL2). Timing uses CUDA events (3 warmup, 5 timed iterations).
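
For reference, the tables below can be reproduced with a timing helper along these lines (my reconstruction of the methodology, not the exact benchmark script):

import torch

def time_eigh_ms(a, warmup=3, iters=5):
    # Average torch.linalg.eigh latency in ms, measured with CUDA events.
    for _ in range(warmup):
        torch.linalg.eigh(a)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        torch.linalg.eigh(a)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters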

Main result: B=128, float64

      n | default (ms) | fixed (ms) | speedup
--------+--------------+------------+--------
      8 |         0.32 |       0.31 |    1.0x
     16 |         0.66 |       0.63 |    1.1x
     24 |         1.22 |       1.19 |    1.0x
     28 |         1.44 |       1.49 |    1.0x
     30 |         1.56 |       1.66 |    0.9x
     31 |         1.77 |       1.64 |    1.1x
     32 |         1.54 |       1.59 |    1.0x
     33 |        137.7 |       3.04 |     45x
     34 |        145.8 |       2.63 |     55x
     40 |        160.9 |       3.19 |     50x
     48 |        180.6 |       3.47 |     52x
     56 |        204.4 |       4.12 |     50x
     64 |        241.4 |       4.67 |     52x
     80 |        387.9 |       10.5 |     37x
     96 |        458.0 |       12.9 |     36x
    128 |        605.1 |       18.9 |     32x
    256 |        714.2 |       68.3 |     10x
    400 |       1113.7 |      198.9 |      6x
    512 |       1382.3 |      322.5 |      4x
    513 |       1455.7 |      391.0 |      4x
    600 |       1793.8 |      547.6 |      3x
   1024 |       3555.8 |     1962.0 |    1.8x

Consistent speedup for n > 32, no regression for n ≤ 32.
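
As a sanity check that the batched path stays numerically correct for n > 32 after the change, a residual test along these lines can be used (my sketch; for float64 both residuals should sit near machine precision):

import torch

torch.manual_seed(0)
B, n = 128, 64
x = torch.randn(B, n, n, device="cuda", dtype=torch.float64)
a = (x + x.mT) / 2
w, v = torch.linalg.eigh(a)
# Decomposition residual: A @ V should equal V @ diag(w) per batch element.
resid = (a @ v - v * w.unsqueeze(-2)).abs().max().item()
# Eigenvectors should be orthonormal: V^T V ~ I.
eye = torch.eye(n, device="cuda", dtype=torch.float64)
ortho = (v.mT @ v - eye).abs().max().item()
print(f"max |A V - V diag(w)| = {resid:.3e}, max |V^T V - I| = {ortho:.3e}")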

Environment

  • GPU: NVIDIA GeForce RTX 4080
  • CUDA: 12.6 / 12.8
  • PyTorch: 2.10.0+cu126 (default) / 2.12.0a0+git67c428c (with fix)
  • OS: Ubuntu 22.04.5 LTS (WSL2)

Versions

PyTorch version: 2.10.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Clang version: Could not collect
CMake version: version 4.1.0
Libc version: glibc-2.35

Python version: 3.12.9 (main, Feb 12 2025, 14:50:50) [Clang 19.1.6 ] (64-bit runtime)
Python platform: Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.8.61
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4080
Nvidia driver version: 591.74
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 20
On-line CPU(s) list: 0-19
Vendor ID: GenuineIntel
Model name: 12th Gen Intel(R) Core(TM) i7-12700K
CPU family: 6
Model: 151
Thread(s) per core: 2
Core(s) per socket: 10
Socket(s): 1
Stepping: 2
BogoMIPS: 7219.19
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni vnmi umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 480 KiB (10 instances)
L1i cache: 320 KiB (10 instances)
L2 cache: 12.5 MiB (10 instances)
L3 cache: 25 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-19
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect

cc @jerryzh168 @ptrblck @msaroufim @eqy @tinglvv @nWEIdia @jianyuh @nikitaved @mruberry @walterddr @xwang233 @lezcano
