[MPS] Implement hardshrink metal kernel #155304
manuelcandales wants to merge 13 commits into gh/manuelcandales/2/base from
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155304
Note: Links to docs will display an error until the docs builds have been completed.
⏳ 1 Pending, 9 Unrelated Failures as of commit 525ae01 with merge base d4d0ede.
FLAKY - The following job failed, but this was likely due to flakiness present on trunk.
BROKEN TRUNK - The following jobs failed, but were already failing on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Attention! native_functions.yaml was changed. If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two: one that adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.
malfet
left a comment
The API looks a bit fragile; consider using std::pair and encoding the extra dtype in the shader name.
Review comment on:

```cpp
const std::optional<c10::Scalar> alpha = std::nullopt,
const std::optional<c10::ScalarType> scalar_arg_type = std::nullopt);
```

This API looks wrong; it should be something like:

```cpp
const std::optional<std::pair<c10::Scalar, c10::ScalarType>> alpha = std::nullopt,
```
malfet
left a comment
Thank you for working on the change
@pytorchbot merge -f "Lint + MPS are green, will submit some cleanups later on"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…155316) Pull Request resolved: #155316 Approved by: https://github.com/Skylion007, https://github.com/malfet ghstack dependencies: #155304
…5462) Pull Request resolved: #155462 Approved by: https://github.com/malfet ghstack dependencies: #155304, #155316
Pull Request resolved: #155479 Approved by: https://github.com/kulinseth, https://github.com/malfet ghstack dependencies: #155304, #155316, #155462
Stack from ghstack (oldest at bottom):
Implements the forward and backward hardshrink operators as Metal kernels.
In order to support the lambda parameter, we extend the `exec_unary_kernel` and `exec_binary_kernel` methods. They now take an optional Scalar and an optional ScalarType argument. When the optional ScalarType is provided, it overrides the type of the Scalar.

We add a new `REGISTER_UNARY_ALPHA_OP` macro, and modify the existing `REGISTER_BINARY_ALPHA_OP` macro to support the new feature.