[MPS] Fix masked_scatter to preserve scalar tensor shape#174381
[MPS] Fix masked_scatter to preserve scalar tensor shape#174381vmoens wants to merge 3 commits intopytorch:mainfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/174381
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit baedb21 with merge base 14f828c ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
8813ad6 to
9164d85
Compare
|
Hello @kulinseth can you give a look at this? |
|
@pytorchbot rebase |
|
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here |
|
Rebase failed due to Command Raised by https://github.com/pytorch/pytorch/actions/runs/22659993105 |
The MPS implementation of masked_scatter incorrectly changed scalar tensor shape from [] to [1]. The function unsqueezes 0-d tensors for processing via index_put_out, but was not restoring the original shape after the operation. The fix records whether self was originally 0-dimensional and squeezes the result back to 0-d at the end if needed. Authored with Claude.
84ca67c to
baedb21
Compare
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 jobs have failed, first few of them are: linux-aarch64 / linux-jammy-aarch64-py3.10 / test (openreg, 1, 1, linux.arm64.m7g.4xlarge) Details for Dev Infra teamRaised by workflow job |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
) ## Summary - Fix MPS `masked_scatter` to preserve scalar tensor shape `[]` instead of incorrectly returning `[1]` - The function now records whether `self` was originally 0-dimensional and squeezes the result back after processing ## Test plan - Added scalar tensor test case to `test_masked_scatter` in `test/test_mps.py` - Verified fix with reproducer script Related error: pytorch/rl#3137 (comment) Pull Request resolved: pytorch#174381 Approved by: https://github.com/kurtamohler
Summary
masked_scatterto preserve scalar tensor shape[]instead of incorrectly returning[1]selfwas originally 0-dimensional and squeezes the result back after processingTest plan
test_masked_scatterintest/test_mps.pyRelated error: pytorch/rl#3137 (comment)