Skip to content

[MPS] Fix masked_scatter to preserve scalar tensor shape#174381

Closed
vmoens wants to merge 3 commits intopytorch:mainfrom
vmoens:fix-mps-masked-scatter-scalar-shape
Closed

[MPS] Fix masked_scatter to preserve scalar tensor shape#174381
vmoens wants to merge 3 commits intopytorch:mainfrom
vmoens:fix-mps-masked-scatter-scalar-shape

Conversation

@vmoens
Copy link
Contributor

@vmoens vmoens commented Feb 5, 2026

Summary

  • Fix MPS masked_scatter to preserve scalar tensor shape [] instead of incorrectly returning [1]
  • The function now records whether self was originally 0-dimensional and squeezes the result back after processing

Test plan

  • Added scalar tensor test case to test_masked_scatter in test/test_mps.py
  • Verified fix with reproducer script

Related error: pytorch/rl#3137 (comment)

@pytorch-bot
Copy link

pytorch-bot bot commented Feb 5, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/174381

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit baedb21 with merge base 14f828c (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: mps Release notes category label Feb 5, 2026
@linux-foundation-easycla
Copy link

linux-foundation-easycla bot commented Feb 5, 2026

CLA Signed

The committers listed above are authorized under a signed CLA.

@vmoens vmoens force-pushed the fix-mps-masked-scatter-scalar-shape branch from 8813ad6 to 9164d85 Compare February 5, 2026 15:05
@zou3519 zou3519 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Feb 6, 2026
@vmoens
Copy link
Contributor Author

vmoens commented Feb 16, 2026

Hello @kulinseth can you give a look at this?

Copy link
Collaborator

@kurtamohler kurtamohler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR Vincent!

Eventually, this op should probably be updated to avoid reshaping in all cases (currently fails an opinfo test for this reason), but this PR does fix it for the 0-dim case.

Could you please rebase?

@vmoens
Copy link
Contributor Author

vmoens commented Mar 4, 2026

@pytorchbot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/174381/head returned non-zero exit code 1

Rebasing (1/3)
Auto-merging aten/src/ATen/native/mps/operations/Indexing.mm
CONFLICT (content): Merge conflict in aten/src/ATen/native/mps/operations/Indexing.mm
Auto-merging test/test_mps.py
CONFLICT (content): Merge conflict in test/test_mps.py
error: could not apply c1c9ed89fde... [MPS] Fix masked_scatter to preserve scalar tensor shape
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Could not apply c1c9ed89fde... # [MPS] Fix masked_scatter to preserve scalar tensor shape

Raised by https://github.com/pytorch/pytorch/actions/runs/22659993105

vmoens added 3 commits March 4, 2026 09:21
The MPS implementation of masked_scatter incorrectly changed scalar
tensor shape from [] to [1]. The function unsqueezes 0-d tensors for
processing via index_put_out, but was not restoring the original shape
after the operation.

The fix records whether self was originally 0-dimensional and squeezes
the result back to 0-d at the end if needed.

Authored with Claude.
@vmoens vmoens force-pushed the fix-mps-masked-scatter-scalar-shape branch from 84ca67c to baedb21 Compare March 4, 2026 09:29
@kurtamohler kurtamohler added the ciflow/mps Run MPS tests (subset of trunk) label Mar 4, 2026
@kurtamohler
Copy link
Collaborator

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 4, 2026
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: linux-aarch64 / linux-jammy-aarch64-py3.10 / test (openreg, 1, 1, linux.arm64.m7g.4xlarge)

Details for Dev Infra team Raised by workflow job

@kurtamohler
Copy link
Collaborator

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Vighaneshs pushed a commit to Vighaneshs/pytorch that referenced this pull request Mar 5, 2026
)

## Summary
- Fix MPS `masked_scatter` to preserve scalar tensor shape `[]` instead of incorrectly returning `[1]`
- The function now records whether `self` was originally 0-dimensional and squeezes the result back after processing

## Test plan
- Added scalar tensor test case to `test_masked_scatter` in `test/test_mps.py`
- Verified fix with reproducer script

Related error: pytorch/rl#3137 (comment)
Pull Request resolved: pytorch#174381
Approved by: https://github.com/kurtamohler
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/mps Run MPS tests (subset of trunk) ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: mps Release notes category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants