[MPS] Fix index_kernel for large tensors#158064
[MPS] Fix index_kernel for large tensors#158064malfet wants to merge 6 commits intogh/malfet/435/basefrom
index_kernel for large tensors#158064Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158064
Note: Links to docs will display an error until the docs builds have been completed. ⏳ 1 Pending, 1 Unrelated FailureAs of commit 3a193a4 with merge base dd93883 ( UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
[ghstack-poisoned]
Move `MetalShaderLibrary::bind_tensors` private method to OperatorUtils.h and extract `iter_tensor_offset` method, that returns an offset from the start of the storage associated with given tensor inside the iterator Fixes #153560 [ghstack-poisoned]
Move `MetalShaderLibrary::bind_tensors` private method to OperatorUtils.h and extract `iter_tensor_offset` method, that returns an offset from the start of the storage associated with given tensor inside the iterator
Migrated `index`, `index_put[_accumulate][_serial]` to the new paradigm that does not require additional tensor for indices nor special handling for 32 vs 64-bit offset, which resulted in almost 2x perf gain for 2000x2000 tensor, see results below before
```
[------------------------------------------------------------ -----------------------------------------------------------]
| 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
__getitem__ (torch.int8, torch.int64) | 383.5 | 379.8 | 470.9 | 1232.9 | 4410.3
__getitem__ (torch.float16, torch.int64) | 379.6 | 354.5 | 533.2 | 1290.3 | 4442.2
__getitem__ (torch.float32, torch.int64) | 360.8 | 338.6 | 478.6 | 1348.9 | 4870.4
Times are in microseconds (us).
```
and after
```
[------------------------------------------------------------ -----------------------------------------------------------]
| 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
__getitem__ (torch.int8, torch.int64) | 349.8 | 330.5 | 432.6 | 764.5 | 1961.2
__getitem__ (torch.float16, torch.int64) | 342.5 | 330.7 | 434.7 | 741.0 | 1969.4
__getitem__ (torch.float32, torch.int64) | 332.2 | 326.1 | 445.4 | 751.3 | 1972.6
Times are in microseconds (us).
```
Fixes #153560
[ghstack-poisoned]
Move `MetalShaderLibrary::bind_tensors` private method to OperatorUtils.h and extract `iter_tensor_offset` method, that returns an offset from the start of the storage associated with given tensor inside the iterator
Migrated `index`, `index_put[_accumulate][_serial]` to the new paradigm that does not require additional tensor for indices nor special handling for 32 vs 64-bit offset, which resulted in almost 2x perf gain for 2000x2000 tensor, see results below before
```
[------------------------------------------------------------ -----------------------------------------------------------]
| 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
__getitem__ (torch.int8, torch.int64) | 383.5 | 379.8 | 470.9 | 1232.9 | 4410.3
__getitem__ (torch.float16, torch.int64) | 379.6 | 354.5 | 533.2 | 1290.3 | 4442.2
__getitem__ (torch.float32, torch.int64) | 360.8 | 338.6 | 478.6 | 1348.9 | 4870.4
Times are in microseconds (us).
```
and after
```
[------------------------------------------------------------ -----------------------------------------------------------]
| 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
__getitem__ (torch.int8, torch.int64) | 349.8 | 330.5 | 432.6 | 764.5 | 1961.2
__getitem__ (torch.float16, torch.int64) | 342.5 | 330.7 | 434.7 | 741.0 | 1969.4
__getitem__ (torch.float32, torch.int64) | 332.2 | 326.1 | 445.4 | 751.3 | 1972.6
Times are in microseconds (us).
```
Fixes #153560
[ghstack-poisoned]
Move `MetalShaderLibrary::bind_tensors` private method to OperatorUtils.h and extract `iter_tensor_offset` method, that returns an offset from the start of the storage associated with given tensor inside the iterator
Migrated `index`, `index_put[_accumulate][_serial]` to the new paradigm that does not require additional tensor for indices nor special handling for 32 vs 64-bit offset, which resulted in almost 2x perf gain for 2000x2000 tensor, see results below before
```
[------------------------------------------------------------ -----------------------------------------------------------]
| 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
__getitem__ (torch.int8, torch.int64) | 383.5 | 379.8 | 470.9 | 1232.9 | 4410.3
__getitem__ (torch.float16, torch.int64) | 379.6 | 354.5 | 533.2 | 1290.3 | 4442.2
__getitem__ (torch.float32, torch.int64) | 360.8 | 338.6 | 478.6 | 1348.9 | 4870.4
Times are in microseconds (us).
```
and after
```
[------------------------------------------------------------ -----------------------------------------------------------]
| 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
__getitem__ (torch.int8, torch.int64) | 349.8 | 330.5 | 432.6 | 764.5 | 1961.2
__getitem__ (torch.float16, torch.int64) | 342.5 | 330.7 | 434.7 | 741.0 | 1969.4
__getitem__ (torch.float32, torch.int64) | 332.2 | 326.1 | 445.4 | 751.3 | 1972.6
Times are in microseconds (us).
```
While migrating also fixed index_put_accumulate for boolean types, by using compare_and_exchange trick over uint
Fixes #153560
[ghstack-poisoned]
Benchmark before
```
[------------------------------------------------------------ -----------------------------------------------------------]
| 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
__getitem__ (torch.int8, torch.int64) | 383.5 | 379.8 | 470.9 | 1232.9 | 4410.3
__getitem__ (torch.float16, torch.int64) | 379.6 | 354.5 | 533.2 | 1290.3 | 4442.2
__getitem__ (torch.float32, torch.int64) | 360.8 | 338.6 | 478.6 | 1348.9 | 4870.4
Times are in microseconds (us).
```
After
```
[------------------------------------------------------------ -----------------------------------------------------------]
| 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
__getitem__ (torch.int8, torch.int64) | 349.8 | 330.5 | 432.6 | 764.5 | 1961.2
__getitem__ (torch.float16, torch.int64) | 342.5 | 330.7 | 434.7 | 741.0 | 1969.4
__getitem__ (torch.float32, torch.int64) | 332.2 | 326.1 | 445.4 | 751.3 | 1972.6
Times are in microseconds (us).
```
ghstack-source-id: 6301e68
Pull Request resolved: #158064
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
|
@pytorchbot merge -f "Lint + MPS are green" |
|
The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Namely `index_get_offsets`, giving thread index computes offsets into input, output and indices tensors And `index_apply_indices` applies offests to either input or output tensor index Pull Request resolved: #158178 Approved by: https://github.com/dcci, https://github.com/Skylion007 ghstack dependencies: #158064
That fixes `index_put(..., accumulate=True)` for all dtypes int64 operation is not really atomic, but eventually consistent from the `index_put_accumulate` kernel point of view: i.e. by the end of the operation results in the global memory are indeed accumulation of the operands at given indices Pull Request resolved: #158179 Approved by: https://github.com/dcci, https://github.com/Skylion007 ghstack dependencies: #158064, #158178
|
@pytorchbot cherry-pick --onto release/2.8 -c critical |
Move `MetalShaderLibrary::bind_tensors` private method to OperatorUtils.h and extract `iter_tensor_offset` method, that returns an offset from the start of the storage associated with given tensor inside the iterator
Migrated `index`, `index_put[_accumulate][_serial]` to the new paradigm that does not require additional tensor for indices nor special handling for 32 vs 64-bit offset, which resulted in almost 2x perf gain for 2000x2000 tensor, see results below before
```
[------------------------------------------------------------ -----------------------------------------------------------]
| 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
__getitem__ (torch.int8, torch.int64) | 383.5 | 379.8 | 470.9 | 1232.9 | 4410.3
__getitem__ (torch.float16, torch.int64) | 379.6 | 354.5 | 533.2 | 1290.3 | 4442.2
__getitem__ (torch.float32, torch.int64) | 360.8 | 338.6 | 478.6 | 1348.9 | 4870.4
Times are in microseconds (us).
```
and after
```
[------------------------------------------------------------ -----------------------------------------------------------]
| 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
__getitem__ (torch.int8, torch.int64) | 349.8 | 330.5 | 432.6 | 764.5 | 1961.2
__getitem__ (torch.float16, torch.int64) | 342.5 | 330.7 | 434.7 | 741.0 | 1969.4
__getitem__ (torch.float32, torch.int64) | 332.2 | 326.1 | 445.4 | 751.3 | 1972.6
Times are in microseconds (us).
```
While migrating also fixed index_put_accumulate for boolean types, by using compare_and_exchange trick over uint
Fixes #153560
Pull Request resolved: #158064
Approved by: https://github.com/dcci
(cherry picked from commit beed033)
Cherry picking #158064The cherry pick PR is at #158239 and it is recommended to link a critical cherry pick PR with an issue. The following tracker issues are updated: Details for Dev Infra teamRaised by workflow job |
[MPS] Fix `index_kernel` for large tensors (#158064) Move `MetalShaderLibrary::bind_tensors` private method to OperatorUtils.h and extract `iter_tensor_offset` method, that returns an offset from the start of the storage associated with given tensor inside the iterator Migrated `index`, `index_put[_accumulate][_serial]` to the new paradigm that does not require additional tensor for indices nor special handling for 32 vs 64-bit offset, which resulted in almost 2x perf gain for 2000x2000 tensor, see results below before ``` [------------------------------------------------------------ -----------------------------------------------------------] | 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000 1 threads: ---------------------------------------------------------------------------------------------------------------- __getitem__ (torch.int8, torch.int64) | 383.5 | 379.8 | 470.9 | 1232.9 | 4410.3 __getitem__ (torch.float16, torch.int64) | 379.6 | 354.5 | 533.2 | 1290.3 | 4442.2 __getitem__ (torch.float32, torch.int64) | 360.8 | 338.6 | 478.6 | 1348.9 | 4870.4 Times are in microseconds (us). ``` and after ``` [------------------------------------------------------------ -----------------------------------------------------------] | 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000 1 threads: ---------------------------------------------------------------------------------------------------------------- __getitem__ (torch.int8, torch.int64) | 349.8 | 330.5 | 432.6 | 764.5 | 1961.2 __getitem__ (torch.float16, torch.int64) | 342.5 | 330.7 | 434.7 | 741.0 | 1969.4 __getitem__ (torch.float32, torch.int64) | 332.2 | 326.1 | 445.4 | 751.3 | 1972.6 Times are in microseconds (us). ``` While migrating also fixed index_put_accumulate for boolean types, by using compare_and_exchange trick over uint Fixes #153560 Pull Request resolved: #158064 Approved by: https://github.com/dcci (cherry picked from commit beed033) Co-authored-by: Nikita Shulga <nikita.shulga@gmail.com>
[MPS] Fix `index_kernel` for large tensors (pytorch#158064) Move `MetalShaderLibrary::bind_tensors` private method to OperatorUtils.h and extract `iter_tensor_offset` method, that returns an offset from the start of the storage associated with given tensor inside the iterator Migrated `index`, `index_put[_accumulate][_serial]` to the new paradigm that does not require additional tensor for indices nor special handling for 32 vs 64-bit offset, which resulted in almost 2x perf gain for 2000x2000 tensor, see results below before ``` [------------------------------------------------------------ -----------------------------------------------------------] | 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000 1 threads: ---------------------------------------------------------------------------------------------------------------- __getitem__ (torch.int8, torch.int64) | 383.5 | 379.8 | 470.9 | 1232.9 | 4410.3 __getitem__ (torch.float16, torch.int64) | 379.6 | 354.5 | 533.2 | 1290.3 | 4442.2 __getitem__ (torch.float32, torch.int64) | 360.8 | 338.6 | 478.6 | 1348.9 | 4870.4 Times are in microseconds (us). ``` and after ``` [------------------------------------------------------------ -----------------------------------------------------------] | 11x50x50 | 11x100x100 | 11x500x500 | 11x1000x1000 | 11x2000x2000 1 threads: ---------------------------------------------------------------------------------------------------------------- __getitem__ (torch.int8, torch.int64) | 349.8 | 330.5 | 432.6 | 764.5 | 1961.2 __getitem__ (torch.float16, torch.int64) | 342.5 | 330.7 | 434.7 | 741.0 | 1969.4 __getitem__ (torch.float32, torch.int64) | 332.2 | 326.1 | 445.4 | 751.3 | 1972.6 Times are in microseconds (us). ``` While migrating also fixed index_put_accumulate for boolean types, by using compare_and_exchange trick over uint Fixes pytorch#153560 Pull Request resolved: pytorch#158064 Approved by: https://github.com/dcci (cherry picked from commit beed033) Co-authored-by: Nikita Shulga <nikita.shulga@gmail.com>
Stack from ghstack (oldest at bottom):
index_kernelfor large tensors #158064Move
MetalShaderLibrary::bind_tensorsprivate method to OperatorUtils.h and extractiter_tensor_offsetmethod, that returns an offset from the start of the storage associated with given tensor inside the iteratorMigrated
index,index_put[_accumulate][_serial]to the new paradigm that does not require additional tensor for indices nor special handling for 32 vs 64-bit offset, which resulted in almost 2x perf gain for 2000x2000 tensor, see results below beforeand after
While migrating also fixed index_put_accumulate for boolean types, by using compare_and_exchange trick over uint
Fixes #153560