
[MPS] linalg.solve backward gives wrong gradients for both A and B #175192

@npinto

Description


🐛 Describe the bug

On MPS, the forward pass of torch.linalg.solve closely matches CPU, but the backward pass produces wrong gradients for both the matrix A and the right-hand side B. This also happens with contiguous inputs, so it is not a stride issue.
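For reference, differentiating AX = B gives the standard adjoint formulas that a correct backward should reproduce. This assumes real-valued inputs, and the bar denotes the gradient of the loss L with respect to that tensor. With L = X.sum() as in the repro, the upstream gradient of X is a tensor of ones:

$$
\mathrm{d}X = A^{-1}\left(\mathrm{d}B - \mathrm{d}A\,X\right)
\quad\Longrightarrow\quad
\bar{B} = A^{-\top}\bar{X}, \qquad \bar{A} = -\bar{B}\,X^{\top}
$$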

```python
import torch

torch.manual_seed(0)
# Shift the diagonal by 3 so that A is usually reasonably well-conditioned
A = (torch.randn(3, 4, 4) + 3 * torch.eye(4).unsqueeze(0)).requires_grad_(True)
B = torch.randn(3, 4, 2).requires_grad_(True)

# Identical inputs on MPS, as fresh leaf tensors
A_mps = A.detach().to("mps").requires_grad_(True)
B_mps = B.detach().to("mps").requires_grad_(True)

torch.linalg.solve(A, B).sum().backward()
torch.linalg.solve(A_mps, B_mps).sum().backward()
torch.mps.synchronize()

print("A.grad diff:", (A.grad - A_mps.grad.cpu()).abs().max().item())
print("B.grad diff:", (B.grad - B_mps.grad.cpu()).abs().max().item())
```

Output:

```
A.grad diff: 14.83
B.grad diff: 1.45
```
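Comparing MPS only against CPU float32 cannot tell whether part of the difference comes from CPU rounding on ill-conditioned seeds. A minimal sketch, assuming real float32 inputs with the repro's shapes, computes reference gradients in float64 on CPU from the adjoint formulas (B_bar = A^{-T} G, A_bar = -B_bar X^T) and measures each device's autograd result against them. The helper names `analytic_solve_grads` and `autograd_solve_grads` are hypothetical, not PyTorch APIs.

```python
import torch

def analytic_solve_grads(A: torch.Tensor, B: torch.Tensor):
    """Reference grads of L = linalg.solve(A, B).sum(), computed in float64 on CPU."""
    A64 = A.detach().to("cpu", torch.float64)
    B64 = B.detach().to("cpu", torch.float64)
    X = torch.linalg.solve(A64, B64)    # forward: solve A X = B
    G = torch.ones_like(X)              # dL/dX for L = X.sum()
    gB = torch.linalg.solve(A64.mT, G)  # B_bar = A^{-T} G
    gA = -gB @ X.mT                     # A_bar = -B_bar X^T
    return gA, gB

def autograd_solve_grads(A: torch.Tensor, B: torch.Tensor, device: str):
    """Autograd grads of the same loss on `device`, returned as float64 CPU tensors."""
    A_d = A.detach().to(device).requires_grad_(True)
    B_d = B.detach().to(device).requires_grad_(True)
    torch.linalg.solve(A_d, B_d).sum().backward()
    return A_d.grad.cpu().double(), B_d.grad.cpu().double()

torch.manual_seed(0)
A = torch.randn(3, 4, 4) + 3 * torch.eye(4).unsqueeze(0)
B = torch.randn(3, 4, 2)
ref_A, ref_B = analytic_solve_grads(A, B)

# Always check CPU; add MPS only when this machine has it
devices = ["cpu"] + (["mps"] if torch.backends.mps.is_available() else [])
for dev in devices:
    gA, gB = autograd_solve_grads(A, B, dev)
    print(f"{dev}: A.grad err={(gA - ref_A).abs().max().item():.3g}, "
          f"B.grad err={(gB - ref_B).abs().max().item():.3g}")
```

If CPU lands close to the float64 reference while MPS does not, the discrepancy lies in the MPS backward rather than in ordinary float32 error.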

I tested 10 different seeds and every one gives wrong gradients. The error magnitude varies widely: well-conditioned matrices may be off by ~0.2, while near-singular ones can be off by thousands (seed=1 gives ig_A=13820).
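A sketch of the seed sweep described above, printing the batch's largest condition number next to each error so the near-singular cases stand out. The helper name `grad_diffs` is hypothetical, and the shapes and diagonal shift are assumed to match the repro. On a machine without MPS it compares CPU against itself, so both diffs should come out as zero.

```python
import torch

def grad_diffs(seed: int, device: str) -> tuple[float, float, float]:
    """Return (largest cond(A) in the batch, max |A.grad diff|, max |B.grad diff|)
    between CPU and `device` for one seed. Shapes mirror the repro (assumption)."""
    torch.manual_seed(seed)
    A = torch.randn(3, 4, 4) + 3 * torch.eye(4).unsqueeze(0)
    B = torch.randn(3, 4, 2)
    grads = []
    for dev in ("cpu", device):
        # clone() gives a fresh leaf on each pass, even when dev == "cpu"
        A_d = A.clone().to(dev).requires_grad_(True)
        B_d = B.clone().to(dev).requires_grad_(True)
        torch.linalg.solve(A_d, B_d).sum().backward()
        grads.append((A_d.grad.cpu(), B_d.grad.cpu()))
    (cpu_gA, cpu_gB), (dev_gA, dev_gB) = grads
    cond = torch.linalg.cond(A).max().item()  # 2-norm condition number per matrix
    return (cond,
            (cpu_gA - dev_gA).abs().max().item(),
            (cpu_gB - dev_gB).abs().max().item())

if __name__ == "__main__":
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    for seed in range(10):
        cond, dA, dB = grad_diffs(seed, device)
        print(f"seed={seed:2d} cond(A)={cond:10.1f} "
              f"A.grad diff={dA:.3g} B.grad diff={dB:.3g}")
```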

CUDA is not affected (tested on an A100 with 50 seeds, all within fp32 tolerance).

PYTORCH_ENABLE_MPS_FALLBACK=1 works as a workaround since it falls back to CPU for the backward.
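Where setting the environment variable is inconvenient, one hypothetical alternative is a custom `torch.autograd.Function` that keeps the forward on the input device but computes the backward on CPU with the adjoint formulas. This is a sketch, not an official API or a fix for the underlying bug, and it has not been verified on MPS here. It assumes real-valued A of shape (..., n, n) and a matrix right-hand side B of shape (..., n, k) with the same batch shape: no broadcasting and no vector B. The class name `SolveCPUBackward` is illustrative.

```python
import torch

class SolveCPUBackward(torch.autograd.Function):
    """Hypothetical workaround: linalg.solve forward on A's device, backward on CPU."""

    @staticmethod
    def forward(ctx, A, B):
        X = torch.linalg.solve(A, B)
        ctx.save_for_backward(A, X)
        return X

    @staticmethod
    def backward(ctx, grad_X):
        A, X = ctx.saved_tensors
        device = A.device
        A_c, X_c, G_c = A.cpu(), X.cpu(), grad_X.cpu()
        gB = torch.linalg.solve(A_c.mT, G_c)  # B_bar = A^{-T} G
        gA = -gB @ X_c.mT                     # A_bar = -B_bar X^T
        return gA.to(device), gB.to(device)

if __name__ == "__main__":
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    torch.manual_seed(0)
    A = (torch.randn(3, 4, 4) + 3 * torch.eye(4)).to(device).requires_grad_(True)
    B = torch.randn(3, 4, 2).to(device).requires_grad_(True)
    SolveCPUBackward.apply(A, B).sum().backward()
    print(A.grad.shape, B.grad.shape)
```

Unlike the environment variable, this confines the CPU round-trip to the one op, at the cost of a device-to-host copy on every backward call.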

Reproduces on 2.10.0 and nightly 2.12.0.dev20260217.

Versions

Collecting environment information...
PyTorch version: 2.10.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.8.3 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.6)
CMake version: version 4.2.3
Libc version: N/A

Python version: 3.13.11 (main, Dec  5 2025, 16:06:33) [Clang 16.0.0 (clang-1600.0.26.6)] (64-bit runtime)
Python platform: macOS-14.8.3-arm64-arm-64bit-Mach-O
Is CUDA available: False
Is XPU available: False
Is XNNPACK available: True

CPU:
Apple M2 Max

Also reproduces on nightly 2.12.0.dev20260217 (same machine).

cc @jianyuh @nikitaved @mruberry @walterddr @xwang233 @lezcano @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk

Labels: bot-mislabeled, bot-triaged, module: linear algebra, module: mps, triage review
