🐛 Describe the bug
torch.linalg.eigh suffers a catastrophic performance cliff when the matrix size crosses n=32 → 33 for batched inputs on CUDA. The root cause appears to be an outdated matrix-size gate in PyTorch's dispatch logic that only routes to the batched cuSOLVER API (cusolverDnXsyevBatched) for n ≤ 32, falling back to a sequential for-loop over individual matrices for n > 32.
Minimal reproducer
import torch
torch.manual_seed(42)
device = "cuda"
B = 128
dtype = torch.float64
print(f'torch version: {torch.__version__}')
print(f'B={B}, dtype={dtype}, matrix size n')
w = 8
print(f" {'n':>{w}} | {'eigh (ms)':>{w}} ")
print("-" * (2 * w + 9))
for n in [1, 2, 4, 8, 16, 32, 33, 34, 64, 96, 128]:
    x = torch.randn(B, n, n, device=device, dtype=dtype)
    a = (x + x.mT) / 2
    # warmup
    for _ in range(1):
        torch.linalg.eigh(a)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(5):
        eigvals, eigvecs = torch.linalg.eigh(a)
    end.record()
    torch.cuda.synchronize()
    t = start.elapsed_time(end) / 5
    marker = " <-- n=32 threshold" if n == 32 else ""
    print(f" {n:{w}d} | {t:{w}.2f} ms{marker}")
Output (PyTorch 2.10.0+cu126, RTX 4080):
torch version: 2.10.0+cu126
B=128, dtype=torch.float64, matrix size n
n | eigh (ms)
-------------------------
1 | 0.12 ms
2 | 0.33 ms
4 | 0.20 ms
8 | 0.30 ms
16 | 0.62 ms
32 | 1.65 ms <-- n=32 threshold
33 | 137.13 ms
34 | 139.70 ms
64 | 228.77 ms
96 | 457.44 ms
128 | 580.13 ms
Going from n=32 to n=33, the time jumps from 1.65 ms to 137 ms, an ~83x cliff.
Root cause
The dispatch lives in aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp, function linalg_eigh_cusolver (~line 1614):
void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors,
                          const Tensor& infos, bool upper, bool compute_eigenvectors) {
#if defined(USE_ROCM)
  ...
#else
  if (batchCount(eigenvectors) > 1 && eigenvectors.size(-1) <= 32) {
    // Use syevjBatched for batched matrix operation when matrix size <= 32
    linalg_eigh_cusolver_syevj_batched(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
  } else if (eigenvectors.scalar_type() == at::kFloat &&
             eigenvectors.size(-1) >= 32 && eigenvectors.size(-1) <= 512) {
    linalg_eigh_cusolver_syevj(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
  } else {
    linalg_eigh_cusolver_syevd(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
  }
#endif
}
For batched inputs with n ≤ 32, this routes to linalg_eigh_cusolver_syevj_batched(), which (since PR #155695) calls cusolverDnXsyevBatched, a batched API that handles the whole batch at once. Everything else falls through to the syevj or syevd branches, which solve one matrix at a time in a loop with separate kernel launches. The <= 32 gate is a leftover from when the old syevjBatched Jacobi API performed poorly for larger matrices (see PR #53040). PR #155695 (June 2025) replaced the underlying call with the newer cusolverDnXsyevBatched, which in my tests performs much better for larger matrices, but the top-level dispatch condition was never updated.
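A quick way to see which path is taken is to compare the kernel trace for n=32 vs n=33 (a rough sketch using torch.profiler; the exact solver kernel names depend on the cuSOLVER version):

import torch
from torch.profiler import profile, ProfilerActivity

device = "cuda"
B = 128
for n in (32, 33):
    x = torch.randn(B, n, n, device=device, dtype=torch.float64)
    a = (x + x.mT) / 2
    torch.linalg.eigh(a)  # warmup / one-time cuSOLVER init
    torch.cuda.synchronize()
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        torch.linalg.eigh(a)
        torch.cuda.synchronize()
    # n=32 should show a handful of launches from the batched branch;
    # n=33 shows a solver kernel sequence repeated per matrix in the batch.
    print(f"n={n}")
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=5))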
Proposed fix
Remove the <= 32 size gate so all batched eigh calls route through cusolverDnXsyevBatched:
  if (batchCount(eigenvectors) > 1) {
    // cusolverDnXsyevBatched works efficiently for all n when batch > 1
    linalg_eigh_cusolver_syevj_batched(eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
  } else if ...
Benchmark results
I verified the fix by building PyTorch from source with only this one-line change. All benchmarks on NVIDIA GeForce RTX 4080, CUDA 12.8, Ubuntu 22.04.5 LTS (WSL2). Timing via CUDA events (3 warmup, 5 timed iterations).
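The timing loop is essentially the following helper (a minimal sketch of the methodology above; time_eigh is an illustrative name, not part of the patch):

import torch

def time_eigh(a, warmup=3, iters=5):
    # CUDA-event timing of torch.linalg.eigh on an already-symmetrized,
    # GPU-resident batch `a`; returns average milliseconds per call.
    for _ in range(warmup):
        torch.linalg.eigh(a)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        torch.linalg.eigh(a)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters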
Main result: B=128, float64
| n | default (ms) | fixed (ms) | speedup |
|------|--------------|------------|---------|
| 8 | 0.32 | 0.31 | 1.0x |
| 16 | 0.66 | 0.63 | 1.1x |
| 24 | 1.22 | 1.19 | 1.0x |
| 28 | 1.44 | 1.49 | 1.0x |
| 30 | 1.56 | 1.66 | 0.9x |
| 31 | 1.77 | 1.64 | 1.1x |
| 32 | 1.54 | 1.59 | 1.0x |
| 33 | 137.7 | 3.04 | 45x |
| 34 | 145.8 | 2.63 | 55x |
| 40 | 160.9 | 3.19 | 50x |
| 48 | 180.6 | 3.47 | 52x |
| 56 | 204.4 | 4.12 | 50x |
| 64 | 241.4 | 4.67 | 52x |
| 80 | 387.9 | 10.5 | 37x |
| 96 | 458.0 | 12.9 | 36x |
| 128 | 605.1 | 18.9 | 32x |
| 256 | 714.2 | 68.3 | 10x |
| 400 | 1113.7 | 198.9 | 6x |
| 512 | 1382.3 | 322.5 | 4x |
| 513 | 1455.7 | 391.0 | 4x |
| 600 | 1793.8 | 547.6 | 3x |
| 1024 | 3555.8 | 1962.0 | 1.8x |
Consistent speedup for n > 32, no regression for n ≤ 32.
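Independently of the timings, the output of the batched path can be sanity-checked by reconstructing the input from the decomposition (an illustrative sketch, not part of the benchmark above):

import torch

B, n = 128, 64
x = torch.randn(B, n, n, device="cuda", dtype=torch.float64)
a = (x + x.mT) / 2
w, v = torch.linalg.eigh(a)
# eigh should satisfy a = V diag(w) V^T (up to rounding) for every matrix
# in the batch, regardless of which cuSOLVER routine was dispatched.
recon = v @ torch.diag_embed(w) @ v.mT
print(torch.allclose(recon, a, atol=1e-10))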
Environment
- GPU: NVIDIA GeForce RTX 4080
- CUDA: 12.6 / 12.8
- PyTorch: 2.10.0+cu126 (default) / 2.12.0a0+git67c428c (with fix)
- OS: Ubuntu 22.04.5 LTS (WSL2)
Related
- PR #155695: replaced syevjBatched with cusolverDnXsyevBatched (but kept the <= 32 gate)
- PR #53040: introduced the <= 32 heuristic
Versions
PyTorch version: 2.10.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Clang version: Could not collect
CMake version: version 4.1.0
Libc version: glibc-2.35
Python version: 3.12.9 (main, Feb 12 2025, 14:50:50) [Clang 19.1.6 ] (64-bit runtime)
Python platform: Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.8.61
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4080
Nvidia driver version: 591.74
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 20
On-line CPU(s) list: 0-19
Vendor ID: GenuineIntel
Model name: 12th Gen Intel(R) Core(TM) i7-12700K
CPU family: 6
Model: 151
Thread(s) per core: 2
Core(s) per socket: 10
Socket(s): 1
Stepping: 2
BogoMIPS: 7219.19
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni vnmi umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 480 KiB (10 instances)
L1i cache: 320 KiB (10 instances)
L2 cache: 12.5 MiB (10 instances)
L3 cache: 25 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-19
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect
cc @jerryzh168 @ptrblck @msaroufim @eqy @tinglvv @nWEIdia @jianyuh @nikitaved @mruberry @walterddr @xwang233 @lezcano