
[MPS] Fix crash when max/min ops called for complex types#166214

Closed
malfet wants to merge 5 commits into gh/malfet/571/base from gh/malfet/571/head

Conversation

@malfet
Contributor

@malfet malfet commented Oct 25, 2025

Stack from ghstack (oldest at bottom):

Raise an exception, as it's meaningless and results in segfault otherwise:

```
% python -c "import torch;torch.rand(10, dtype=torch.cfloat, device='mps').amax()"
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: error: 'mps.reduction_max' op operand #0 must be tensor of mps native type values, but got 'tensor<10xcomplex<f32>>'
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: note: see current operation: %2 = "mps.reduction_max"(%arg0, %1) <{keep_dims, propagate_nans}> : (tensor<10xcomplex<f32>>, tensor<1xsi32>) -> tensor<1xcomplex<f32>>
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: error: 'mps.reduction_max' op operand #0 must be tensor of mps native type values, but got 'tensor<10xcomplex<f32>>'
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: note: see current operation: %2 = "mps.reduction_max"(%arg0, %1) <{keep_dims, propagate_nans}> : (tensor<10xcomplex<f32>>, tensor<1xsi32>) -> tensor<1xcomplex<f32>>
/AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1347: failed assertion `original module failed verification'
zsh: abort      python -c
```

To be tested by `test_ops.py`
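The idea behind the fix is to validate the input dtype before dispatching to the MPSGraph reduction, since complex numbers have no total ordering and min/max reductions are undefined for them. A minimal Python sketch of that guard (the function name and error message are illustrative, not the actual C++ implementation in the PR):

```python
def check_minmax_reduction_dtype(dtype: str) -> bool:
    """Reject dtypes for which a min/max reduction is meaningless.

    Complex numbers are unordered, so raising a clear error up front is
    preferable to a segfault inside the MPSGraph backend.
    """
    complex_dtypes = {"complex32", "complex64", "complex128"}
    if dtype in complex_dtypes:
        raise TypeError(
            f"min/max reduction is not supported for complex dtype '{dtype}'"
        )
    return True
```

After the fix, a call like `torch.rand(10, dtype=torch.cfloat, device='mps').amax()` should raise a Python-level exception instead of aborting the process.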

[ghstack-poisoned]
@pytorch-bot

pytorch-bot Bot commented Oct 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166214

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b43a83f with merge base d049ed2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot Bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Oct 25, 2025
@malfet malfet requested review from Skylion007 and dcci October 25, 2025 01:46
@malfet malfet added the topic: bug fixes topic category label Oct 25, 2025
@malfet malfet changed the title [MPS] Fix crash when reduce ops called for scalar dtypes [MPS] Fix crash when max/min ops called for complex types Oct 25, 2025
[ghstack-poisoned]
malfet and others added 3 commits October 26, 2025 18:46
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@malfet
Contributor Author

malfet commented Oct 31, 2025

@pytorchbot merge -f "Everything is green"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (`-f`) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use `-f` only as a last resort; consider `-i`/`--ignore-current` instead, which continues the merge while ignoring current failures. This allows currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Oct 31, 2025
And enable fp16 implementation for CPU, which simplifies OpInfo definitions for the op

Pull Request resolved: #166687
Approved by: https://github.com/Skylion007
ghstack dependencies: #166214
pytorchmergebot pushed a commit that referenced this pull request Oct 31, 2025
Or BatchNorm or LayerNorm for Long types

Discovered while trying to enable `test_ops.py` for MPS
Pull Request resolved: #166215
Approved by: https://github.com/dcci, https://github.com/kulinseth, https://github.com/Skylion007
ghstack dependencies: #166214, #166687
BoyuanFeng pushed a commit that referenced this pull request Oct 31, 2025
Raise an exception, as it's meaningless and results in segfault otherwise:
```
% python -c "import torch;torch.rand(10, dtype=torch.cfloat, device='mps').amax()"
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: error: 'mps.reduction_max' op operand #0 must be tensor of mps native type values, but got 'tensor<10xcomplex<f32>>'
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: note: see current operation: %2 = "mps.reduction_max"(%arg0, %1) <{keep_dims, propagate_nans}> : (tensor<10xcomplex<f32>>, tensor<1xsi32>) -> tensor<1xcomplex<f32>>
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: error: 'mps.reduction_max' op operand #0 must be tensor of mps native type values, but got 'tensor<10xcomplex<f32>>'
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: note: see current operation: %2 = "mps.reduction_max"(%arg0, %1) <{keep_dims, propagate_nans}> : (tensor<10xcomplex<f32>>, tensor<1xsi32>) -> tensor<1xcomplex<f32>>
/AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1347: failed assertion `original module failed verification'
zsh: abort      python -c
```

To be tested by `test_ops.py`
Pull Request resolved: #166214
Approved by: https://github.com/dcci, https://github.com/kulinseth, https://github.com/Skylion007
ghstack dependencies: #166272
BoyuanFeng pushed a commit that referenced this pull request Oct 31, 2025
And enable fp16 implementation for CPU, which simplifies OpInfo definitions for the op

Pull Request resolved: #166687
Approved by: https://github.com/Skylion007
ghstack dependencies: #166214
pytorchmergebot pushed a commit that referenced this pull request Nov 3, 2025
Or BatchNorm or LayerNorm for Long types

Discovered while trying to enable `test_ops.py` for MPS
Pull Request resolved: #166215
Approved by: https://github.com/dcci, https://github.com/kulinseth, https://github.com/Skylion007
ghstack dependencies: #166214
etaf pushed a commit to etaf/pytorch-inductor-xpu that referenced this pull request Nov 4, 2025
…6214)

Raise an exception, as it's meaningless and results in segfault otherwise:
```
% python -c "import torch;torch.rand(10, dtype=torch.cfloat, device='mps').amax()"
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: error: 'mps.reduction_max' op operand #0 must be tensor of mps native type values, but got 'tensor<10xcomplex<f32>>'
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: note: see current operation: %2 = "mps.reduction_max"(%arg0, %1) <{keep_dims, propagate_nans}> : (tensor<10xcomplex<f32>>, tensor<1xsi32>) -> tensor<1xcomplex<f32>>
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: error: 'mps.reduction_max' op operand #0 must be tensor of mps native type values, but got 'tensor<10xcomplex<f32>>'
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: note: see current operation: %2 = "mps.reduction_max"(%arg0, %1) <{keep_dims, propagate_nans}> : (tensor<10xcomplex<f32>>, tensor<1xsi32>) -> tensor<1xcomplex<f32>>
/AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1347: failed assertion `original module failed verification'
zsh: abort      python -c
```

To be tested by `test_ops.py`
Pull Request resolved: pytorch#166214
Approved by: https://github.com/dcci, https://github.com/kulinseth, https://github.com/Skylion007
ghstack dependencies: pytorch#166272
etaf pushed a commit to etaf/pytorch-inductor-xpu that referenced this pull request Nov 4, 2025
And enable fp16 implementation for CPU, which simplifies OpInfo definitions for the op

Pull Request resolved: pytorch#166687
Approved by: https://github.com/Skylion007
ghstack dependencies: pytorch#166214
etaf pushed a commit to etaf/pytorch-inductor-xpu that referenced this pull request Nov 4, 2025
Or BatchNorm or LayerNorm for Long types

Discovered while trying to enable `test_ops.py` for MPS
Pull Request resolved: pytorch#166215
Approved by: https://github.com/dcci, https://github.com/kulinseth, https://github.com/Skylion007
ghstack dependencies: pytorch#166214, pytorch#166687
pytorch-bot Bot pushed a commit that referenced this pull request Nov 4, 2025
Or BatchNorm or LayerNorm for Long types

Discovered while trying to enable `test_ops.py` for MPS
Pull Request resolved: #166215
Approved by: https://github.com/dcci, https://github.com/kulinseth, https://github.com/Skylion007
ghstack dependencies: #166214
Khanaksahu pushed a commit to Khanaksahu/pytorch that referenced this pull request Nov 17, 2025
Assert as it's meaningless otherwise

ghstack-source-id: 2d9be90
Pull Request resolved: pytorch/pytorch#166214
Khanaksahu pushed a commit to Khanaksahu/pytorch that referenced this pull request Nov 17, 2025
Assert as it's meaningless otherwise

ghstack-source-id: 2d9be90
Pull Request resolved: pytorch/pytorch#166214
@github-actions github-actions Bot deleted the gh/malfet/571/head branch December 1, 2025 02:20

Labels

ciflow/mps Run MPS tests (subset of trunk) Merged release notes: mps Release notes category topic: bug fixes topic category


5 participants