[MPS] linalg.solve backward gives wrong gradients for both A and B #175192
Description
🐛 Describe the bug
torch.linalg.solve forward on MPS matches CPU closely, but backward produces wrong gradients for both the matrix and the right-hand side. This happens on contiguous inputs too, so it's not a stride issue.
import torch
torch.manual_seed(0)
A = (torch.randn(3, 4, 4) + 3 * torch.eye(4).unsqueeze(0)).requires_grad_(True)
B = torch.randn(3, 4, 2).requires_grad_(True)
A_mps = A.detach().to("mps").requires_grad_(True)
B_mps = B.detach().to("mps").requires_grad_(True)
torch.linalg.solve(A, B).sum().backward()
torch.linalg.solve(A_mps, B_mps).sum().backward()
torch.mps.synchronize()
print("A.grad diff:", (A.grad - A_mps.grad.cpu()).abs().max().item())
print("B.grad diff:", (B.grad - B_mps.grad.cpu()).abs().max().item())

Output:
A.grad diff: 14.83
B.grad diff: 1.45
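To decide which side is right without trusting either backend's autograd, the gradients can be checked against the closed form. For X = solve(A, B) and upstream gradient G, grad_B = A^{-T} G and grad_A = -grad_B X^T. The sketch below is a minimal CPU-only check in float64; `solve_grads_reference` is a hypothetical helper name, not a PyTorch API.

```python
import torch

def solve_grads_reference(A, B, G):
    """Closed-form gradients of X = solve(A, B) for an upstream gradient G.

    Hypothetical helper for illustration only.
    """
    X = torch.linalg.solve(A, B)
    grad_B = torch.linalg.solve(A.mT, G)  # A^{-T} G
    grad_A = -grad_B @ X.mT               # -(A^{-T} G) X^T
    return grad_A, grad_B

torch.manual_seed(0)
eye = torch.eye(4, dtype=torch.float64)
A = (torch.randn(3, 4, 4, dtype=torch.float64) + 3 * eye).requires_grad_(True)
B = torch.randn(3, 4, 2, dtype=torch.float64).requires_grad_(True)

X = torch.linalg.solve(A, B)
X.sum().backward()  # upstream gradient of sum() is all ones

ref_A, ref_B = solve_grads_reference(A.detach(), B.detach(), torch.ones_like(X))
print("A.grad vs closed form:", (A.grad - ref_A).abs().max().item())
print("B.grad vs closed form:", (B.grad - ref_B).abs().max().item())
```

The same reference can be compared against `A_mps.grad.cpu()` and `B_mps.grad.cpu()` from the repro, using float32 inputs since MPS does not support float64.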
I tested 10 different seeds and every one is wrong. The size of the error varies a lot: well-conditioned matrices might be off by ~0.2, while near-singular ones can be off by thousands (seed=1 gives ig_A=13820).
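A seed sweep along these lines could look like the sketch below. It is not the exact script used for the numbers above; `grads_on` and `grad_diffs` are hypothetical helper names, and the device falls back to CPU when MPS is unavailable (in which case the differences are trivially zero).

```python
import torch

def grads_on(A, B, device):
    """Run solve + backward on `device` and return (A.grad, B.grad) on CPU."""
    a = A.detach().clone().to(device).requires_grad_(True)
    b = B.detach().clone().to(device).requires_grad_(True)
    torch.linalg.solve(a, b).sum().backward()
    return a.grad.cpu(), b.grad.cpu()

def grad_diffs(seed, device):
    """Max absolute gradient difference between CPU and `device` for one seed."""
    torch.manual_seed(seed)
    A = torch.randn(3, 4, 4) + 3 * torch.eye(4)
    B = torch.randn(3, 4, 2)
    gA_cpu, gB_cpu = grads_on(A, B, "cpu")
    gA_dev, gB_dev = grads_on(A, B, device)
    return ((gA_cpu - gA_dev).abs().max().item(),
            (gB_cpu - gB_dev).abs().max().item())

# Assumption: compare against MPS when present, otherwise CPU vs CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
for seed in range(10):
    dA, dB = grad_diffs(seed, device)
    print(f"seed={seed} device={device} A.grad diff={dA:.4g} B.grad diff={dB:.4g}")
```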
CUDA is not affected (tested on A100 with 50 seeds, all within fp32 tolerance).
PYTORCH_ENABLE_MPS_FALLBACK=1 works as a workaround since it falls back to CPU for the backward.
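If setting the environment variable is inconvenient, another option is to route just this op through the CPU explicitly. Device transfers are differentiable, so the backward of the solve then also runs on CPU. This is only a sketch of that idea: `solve_on_cpu` is a hypothetical helper, and it pays for a device round trip on every call.

```python
import torch

def solve_on_cpu(A, B):
    """Hypothetical workaround: solve on CPU, return the result on A's device."""
    X = torch.linalg.solve(A.cpu(), B.cpu())
    return X.to(A.device)

# Assumption: use MPS when present, otherwise this degenerates to plain CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

torch.manual_seed(0)
A = (torch.randn(3, 4, 4) + 3 * torch.eye(4)).requires_grad_(True)
B = torch.randn(3, 4, 2).requires_grad_(True)
A_dev = A.detach().clone().to(device).requires_grad_(True)
B_dev = B.detach().clone().to(device).requires_grad_(True)

torch.linalg.solve(A, B).sum().backward()   # CPU reference
solve_on_cpu(A_dev, B_dev).sum().backward()  # workaround path

diff_A = (A.grad - A_dev.grad.cpu()).abs().max().item()
diff_B = (B.grad - B_dev.grad.cpu()).abs().max().item()
print("A.grad diff:", diff_A)
print("B.grad diff:", diff_B)
```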
Reproduces on 2.10.0 and nightly 2.12.0.dev20260217.
Versions
Collecting environment information...
PyTorch version: 2.10.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.8.3 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.6)
CMake version: version 4.2.3
Libc version: N/A
Python version: 3.13.11 (main, Dec 5 2025, 16:06:33) [Clang 16.0.0 (clang-1600.0.26.6)] (64-bit runtime)
Python platform: macOS-14.8.3-arm64-arm-64bit-Mach-O
Is CUDA available: False
Is XPU available: False
Is XNNPACK available: True
CPU:
Apple M2 Max
Also reproduces on nightly 2.12.0.dev20260217 (same machine).
cc @jianyuh @nikitaved @mruberry @walterddr @xwang233 @lezcano @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk