[MPS] Add native implementation for shift ops (#131813)
Conversation
Similar to how the AND/OR/XOR ops are implemented.

TODO: Consider using MPS method calls rather than Metal kernels.
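For context, a minimal sketch (in Metal Shading Language, which is C++-based) of what such a shift kernel could look like, following the element-wise pattern the existing AND/OR/XOR kernels use. The kernel and buffer names here are illustrative, not the exact ones added by this PR:

```cpp
#include <metal_stdlib>
using namespace metal;

// One thread per element: shift each value of `a` left by the
// matching value of `b` and write the result to `out`.
kernel void bitwise_lshift(device const int* a   [[buffer(0)]],
                           device const int* b   [[buffer(1)]],
                           device int*       out [[buffer(2)]],
                           uint index            [[thread_position_in_grid]]) {
  out[index] = a[index] << b[index];
}
```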
```cpp
// For the rest of unsupported ops the user needs to pass 'PYTORCH_ENABLE_MPS_FALLBACK=1'
// to fallback on CPU, otherwise we will error out.
m.impl("bitwise_left_shift.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("bitwise_right_shift.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
```
If there was a fallback in place, why were things not working before?
Granted, the error was with `__lshift__.Scalar`. Not sure how that magic works.
Oh, the fallback was for the Tensor variant, not for the Scalar ones
Yeah, this only works if one calls `torch.bitwise_left_shift`; not sure what the story is with all those variants.
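For illustration, a small C++ sketch of which schema each call form resolves to (the Python call `torch.bitwise_left_shift` corresponds to `torch::bitwise_left_shift` in the C++ API). The mapping comments are my reading of `native_functions.yaml`, not verified dispatch traces:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  auto t = torch::randint(100, 500, {4}, torch::kInt32);
  auto s = torch::full({4}, 3, torch::kInt32);

  auto a = t.__lshift__(s);                  // __lshift__.Tensor variant
  auto b = t.__lshift__(3);                  // __lshift__.Scalar variant -- the one the fallback did not cover
  auto c = torch::bitwise_left_shift(t, s);  // bitwise_left_shift.Tensor_out under the hood -- covered by the fallback
  std::cout << a << "\n" << b << "\n" << c << "\n";
}
```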
In `native_functions.yaml`:

```diff
   variants: method, function
   dispatch:
-    CPU, CUDA: __lshift__
+    CPU, CUDA, MPS: __lshift__
```
Just curious, where is the magic that ties `__lshift__` to `bitwise_left_shift_out`?
Here: `pytorch/aten/src/ATen/native/BinaryOps.cpp`, line 1292 at 3d7c424.
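Roughly, the glue there drives the same per-device `lshift_stub` that the structured `bitwise_left_shift.Tensor_out` kernel dispatches through, so registering an MPS kernel for the stub serves both entry points. A paraphrased sketch from memory; the exact code at that line may differ:

```cpp
// Paraphrased sketch of aten/src/ATen/native/BinaryOps.cpp (not verbatim):
// __lshift__ funnels into lshift_stub, the same stub the structured
// bitwise_left_shift.Tensor_out implementation calls.
Tensor __lshift__(const Tensor& self, const Tensor& other) {
  Tensor result;
  auto iter = TensorIterator::binary_op(result, self, other);
  lshift_stub(iter.device_type(), iter);
  return iter.output();
}
```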
@pytorchbot merge -f "MPS + Lint tests are green"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
[MPS] Add missing dispatch to rshift.Tensor (#135607)

Missed it while working on #131813

Test plan: `python -c "import torch;print(torch.randint(100, 500, (64,), device='mps') >> torch.tensor([3,], device='mps'))"`

Pull Request resolved: #135607
Approved by: https://github.com/manuelcandales
(cherry picked from commit 3bf6be4)
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
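Presumably the fix is the same one-line dispatch change as the `__lshift__` diff quoted above, this time for `__rshift__`. A guess at the shape of it, not the verbatim commit:

```diff
   dispatch:
-    CPU, CUDA: __rshift__
+    CPU, CUDA, MPS: __rshift__
```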