[MPS] Extend atomic operations to all int types #158179
malfet wants to merge 2 commits into gh/malfet/437/base
The only type not covered right now is int64 [ghstack-poisoned]
      union {
        uint i;
    -   T t[2];
    +   T t[elem_per_enum];
Nit: can this be a real array type, since the array has a constexpr length?
Do you mean metal::array? Or something like char4?
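For readers following the thread, here is a hedged host-side C++ sketch of the general technique the `union { uint i; T t[elem_per_enum]; }` snippet suggests: several sub-word elements share one 32-bit word, and an add to one element is retried with compare-and-swap on the whole word. The name `packed_atomic_add` and the use of `std::atomic` plus `memcpy` (instead of a union) are assumptions made for illustration; this is not the Metal code from the PR.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical helper (not from the PR): atomically add `value` to element
// `offset` of an array of 1- or 2-byte integers stored in 32-bit words.
template <typename T>
void packed_atomic_add(std::atomic<uint32_t>* words, long offset, T value) {
  static_assert(sizeof(T) == 1 || sizeof(T) == 2, "sub-word types only");
  assert(offset >= 0);
  // Number of T elements that fit in one 32-bit word (4 or 2).
  constexpr long elem_per_word = sizeof(uint32_t) / sizeof(T);
  std::atomic<uint32_t>& word = words[offset / elem_per_word];
  const long lane = offset % elem_per_word;

  uint32_t old_bits = word.load(std::memory_order_relaxed);
  uint32_t new_bits;
  do {
    // Unpack the word, update one lane, and repack it.
    T lanes[elem_per_word];
    std::memcpy(lanes, &old_bits, sizeof(old_bits));
    lanes[lane] = static_cast<T>(lanes[lane] + value);
    std::memcpy(&new_bits, lanes, sizeof(new_bits));
    // On failure, compare_exchange_weak reloads old_bits and the loop retries.
  } while (!word.compare_exchange_weak(old_bits, new_bits,
                                       std::memory_order_relaxed));
}
```

Sizing the lane array from `sizeof(T)` rather than hard-coding 2 is what lets the same loop cover both 8-bit and 16-bit element types.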
That fixes index_put for all dtypes but int64 (as Metal only has atomic loads and stores for this dtype) [ghstack-poisoned]
@pytorchbot merge -f "Lint + MPS are green"
Stack from ghstack (oldest at bottom):
- `index_kernel` for large tensors #158064

That fixes `index_put(..., accumulate=True)` for all dtypes.

The int64 operation is not really atomic, but eventually consistent from the `index_put_accumulate` kernel point of view: i.e. by the end of the operation, the results in global memory are indeed the accumulation of the operands at the given indices.
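To make the expected end state concrete, the following is a minimal sequential C++ reference for what an accumulating index put should leave in memory on a 1-D buffer. The function name `index_put_accumulate_ref` and the flat-vector layout are assumptions for illustration only. A GPU kernel would typically process the indices in parallel, which is why duplicate indices require atomic (or at least eventually consistent) adds to reach the same result.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference (not from the PR): every value is added to the
// slot named by its index, so repeated indices accumulate rather than overwrite.
template <typename T>
std::vector<T> index_put_accumulate_ref(std::vector<T> self,
                                        const std::vector<std::size_t>& indices,
                                        const std::vector<T>& values) {
  assert(indices.size() == values.size());
  for (std::size_t i = 0; i < indices.size(); ++i) {
    // .at() bounds-checks the index and throws std::out_of_range if invalid.
    self.at(indices[i]) += values[i];
  }
  return self;
}
```

For example, starting from `{10, 20, 30}` with indices `{0, 2, 0}` and values `{1, 2, 3}`, slot 0 receives both 1 and 3, so the result is `{14, 20, 32}` regardless of the order in which the three updates are applied.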