Skip to content

[MPS] Fix SDPA output shape when value head dim differs#176843

Closed
hvaara wants to merge 1 commit intopytorch:mainfrom
hvaara:mps-sdpa-ev-shape-fix
Closed

[MPS] Fix SDPA output shape when value head dim differs#176843
hvaara wants to merge 1 commit intopytorch:mainfrom
hvaara:mps-sdpa-ev-shape-fix

Conversation

@hvaara
Copy link
Contributor

@hvaara hvaara commented Mar 8, 2026

This fixes MPS SDPA output shape for cases where value.size(-1) != query.size(-1), so output now follows (..., L, Ev) as expected. I also added guards in Metal kernel paths that assume equal qkv head dims.

Added the updated meta shape inference for the sdpa_general_mps path which seems to have been left out initially.

Added regression coverage in test/test_transformers.py covering the shape semantics, and a similar one in test/test_mps.py that also checks for numerical parity with CPU.

Fixes #176767

@hvaara hvaara requested a review from malfet as a code owner March 8, 2026 20:22
@pytorch-bot
Copy link

pytorch-bot bot commented Mar 8, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176843

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cf60249 with merge base 7643509 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@malfet
Copy link
Contributor

malfet commented Mar 9, 2026

@pytorchbot merge -f "Lint + MPS is green"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/mps Run MPS tests (subset of trunk) Merged open source release notes: mps Release notes category topic: bug fixes topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

MPS: scaled_dot_product_attention returns wrong output shape when value dim != query/key dim

4 participants