Skip to content

[MPS] Fix index_kernel for large tensors#158064

Closed
malfet wants to merge 6 commits intogh/malfet/435/basefrom
gh/malfet/435/head
Closed

[MPS] Fix index_kernel for large tensors#158064
malfet wants to merge 6 commits intogh/malfet/435/basefrom
gh/malfet/435/head

Conversation

@malfet
Copy link
Contributor

@malfet malfet commented Jul 10, 2025

Stack from ghstack (oldest at bottom):

Move MetalShaderLibrary::bind_tensors private method to OperatorUtils.h and extract iter_tensor_offset method, that returns an offset from the start of the storage associated with given tensor inside the iterator

Migrated index, index_put[_accumulate][_serial] to the new paradigm that does not require additional tensor for indices nor special handling for 32 vs 64-bit offset, which resulted in almost 2x perf gain for 2000x2000 tensor, see results below before

[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   383.5    |    379.8     |    470.9     |     1232.9     |     4410.3
      __getitem__ (torch.float16, torch.int64)  |   379.6    |    354.5     |    533.2     |     1290.3     |     4442.2
      __getitem__ (torch.float32, torch.int64)  |   360.8    |    338.6     |    478.6     |     1348.9     |     4870.4

Times are in microseconds (us).

and after

[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   349.8    |    330.5     |    432.6     |     764.5      |     1961.2
      __getitem__ (torch.float16, torch.int64)  |   342.5    |    330.7     |    434.7     |     741.0      |     1969.4
      __getitem__ (torch.float32, torch.int64)  |   332.2    |    326.1     |    445.4     |     751.3      |     1972.6
    
Times are in microseconds (us).

While migrating also fixed index_put_accumulate for boolean types, by using compare_and_exchange trick over uint

Fixes #153560

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Jul 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158064

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 1 Pending, 1 Unrelated Failure

As of commit 3a193a4 with merge base dd93883 (image):

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Jul 10, 2025
malfet added a commit that referenced this pull request Jul 10, 2025
ghstack-source-id: c6035fe
Pull Request resolved: #158064
@malfet malfet requested a review from dcci July 10, 2025 23:46
@malfet malfet added the topic: bug fixes topic category label Jul 10, 2025
@malfet malfet requested a review from kulinseth as a code owner July 11, 2025 06:37
Move `MetalShaderLibrary::bind_tensors` private method to OperatorUtils.h and extract `iter_tensor_offset` method, that returns an offset from the start of the storage associated with given tensor inside the iterator


Fixes #153560

[ghstack-poisoned]
Move `MetalShaderLibrary::bind_tensors` private method to OperatorUtils.h and extract `iter_tensor_offset` method, that returns an offset from the start of the storage associated with given tensor inside the iterator

Migrated `index`, `index_put[_accumulate][_serial]` to the new paradigm that does not require additional tensor for indices nor special handling for 32 vs 64-bit offset, which resulted in almost 2x perf gain for 2000x2000 tensor, see results below before
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   383.5    |    379.8     |    470.9     |     1232.9     |     4410.3
      __getitem__ (torch.float16, torch.int64)  |   379.6    |    354.5     |    533.2     |     1290.3     |     4442.2
      __getitem__ (torch.float32, torch.int64)  |   360.8    |    338.6     |    478.6     |     1348.9     |     4870.4

Times are in microseconds (us).
```
and after
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   349.8    |    330.5     |    432.6     |     764.5      |     1961.2
      __getitem__ (torch.float16, torch.int64)  |   342.5    |    330.7     |    434.7     |     741.0      |     1969.4
      __getitem__ (torch.float32, torch.int64)  |   332.2    |    326.1     |    445.4     |     751.3      |     1972.6
    
Times are in microseconds (us).
```
    
Fixes #153560

[ghstack-poisoned]
Move `MetalShaderLibrary::bind_tensors` private method to OperatorUtils.h and extract `iter_tensor_offset` method, that returns an offset from the start of the storage associated with given tensor inside the iterator

Migrated `index`, `index_put[_accumulate][_serial]` to the new paradigm that does not require additional tensor for indices nor special handling for 32 vs 64-bit offset, which resulted in almost 2x perf gain for 2000x2000 tensor, see results below before
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   383.5    |    379.8     |    470.9     |     1232.9     |     4410.3
      __getitem__ (torch.float16, torch.int64)  |   379.6    |    354.5     |    533.2     |     1290.3     |     4442.2
      __getitem__ (torch.float32, torch.int64)  |   360.8    |    338.6     |    478.6     |     1348.9     |     4870.4

Times are in microseconds (us).
```
and after
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   349.8    |    330.5     |    432.6     |     764.5      |     1961.2
      __getitem__ (torch.float16, torch.int64)  |   342.5    |    330.7     |    434.7     |     741.0      |     1969.4
      __getitem__ (torch.float32, torch.int64)  |   332.2    |    326.1     |    445.4     |     751.3      |     1972.6
    
Times are in microseconds (us).
```
    
Fixes #153560

[ghstack-poisoned]
Move `MetalShaderLibrary::bind_tensors` private method to OperatorUtils.h and extract `iter_tensor_offset` method, that returns an offset from the start of the storage associated with given tensor inside the iterator

Migrated `index`, `index_put[_accumulate][_serial]` to the new paradigm that does not require additional tensor for indices nor special handling for 32 vs 64-bit offset, which resulted in almost 2x perf gain for 2000x2000 tensor, see results below before
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   383.5    |    379.8     |    470.9     |     1232.9     |     4410.3
      __getitem__ (torch.float16, torch.int64)  |   379.6    |    354.5     |    533.2     |     1290.3     |     4442.2
      __getitem__ (torch.float32, torch.int64)  |   360.8    |    338.6     |    478.6     |     1348.9     |     4870.4

Times are in microseconds (us).
```
and after
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   349.8    |    330.5     |    432.6     |     764.5      |     1961.2
      __getitem__ (torch.float16, torch.int64)  |   342.5    |    330.7     |    434.7     |     741.0      |     1969.4
      __getitem__ (torch.float32, torch.int64)  |   332.2    |    326.1     |    445.4     |     751.3      |     1972.6
    
Times are in microseconds (us).
```

While migrating also fixed index_put_accumulate for boolean types, by using compare_and_exchange trick over uint

Fixes #153560

[ghstack-poisoned]
malfet added a commit that referenced this pull request Jul 11, 2025
Benchmark before
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   383.5    |    379.8     |    470.9     |     1232.9     |     4410.3
      __getitem__ (torch.float16, torch.int64)  |   379.6    |    354.5     |    533.2     |     1290.3     |     4442.2
      __getitem__ (torch.float32, torch.int64)  |   360.8    |    338.6     |    478.6     |     1348.9     |     4870.4

Times are in microseconds (us).
```

After
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   349.8    |    330.5     |    432.6     |     764.5      |     1961.2
      __getitem__ (torch.float16, torch.int64)  |   342.5    |    330.7     |    434.7     |     741.0      |     1969.4
      __getitem__ (torch.float32, torch.int64)  |   332.2    |    326.1     |    445.4     |     751.3      |     1972.6

Times are in microseconds (us).

```

ghstack-source-id: 6301e68
Pull Request resolved: #158064
@malfet
Copy link
Contributor Author

malfet commented Jul 11, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 11, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@malfet
Copy link
Contributor Author

malfet commented Jul 11, 2025

@pytorchbot merge -f "Lint + MPS are green"

@pytorchmergebot
Copy link
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Jul 12, 2025
Namely `index_get_offsets`, giving thread index computes offsets into
input, output and indices tensors
And `index_apply_indices` applies offests to either input or output
tensor index
Pull Request resolved: #158178
Approved by: https://github.com/dcci, https://github.com/Skylion007
ghstack dependencies: #158064
pytorchmergebot pushed a commit that referenced this pull request Jul 14, 2025
That fixes `index_put(..., accumulate=True)` for all dtypes

int64 operation is not really atomic, but eventually consistent from the `index_put_accumulate` kernel point of view: i.e. by the end of the operation results in the global memory are indeed accumulation of the operands at given indices
Pull Request resolved: #158179
Approved by: https://github.com/dcci, https://github.com/Skylion007
ghstack dependencies: #158064, #158178
@Camyll
Copy link
Contributor

Camyll commented Jul 14, 2025

@pytorchbot cherry-pick --onto release/2.8 -c critical

pytorchbot pushed a commit that referenced this pull request Jul 14, 2025
Move `MetalShaderLibrary::bind_tensors` private method to OperatorUtils.h and extract `iter_tensor_offset` method, that returns an offset from the start of the storage associated with given tensor inside the iterator

Migrated `index`, `index_put[_accumulate][_serial]` to the new paradigm that does not require additional tensor for indices nor special handling for 32 vs 64-bit offset, which resulted in almost 2x perf gain for 2000x2000 tensor, see results below before
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   383.5    |    379.8     |    470.9     |     1232.9     |     4410.3
      __getitem__ (torch.float16, torch.int64)  |   379.6    |    354.5     |    533.2     |     1290.3     |     4442.2
      __getitem__ (torch.float32, torch.int64)  |   360.8    |    338.6     |    478.6     |     1348.9     |     4870.4

Times are in microseconds (us).
```
and after
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   349.8    |    330.5     |    432.6     |     764.5      |     1961.2
      __getitem__ (torch.float16, torch.int64)  |   342.5    |    330.7     |    434.7     |     741.0      |     1969.4
      __getitem__ (torch.float32, torch.int64)  |   332.2    |    326.1     |    445.4     |     751.3      |     1972.6

Times are in microseconds (us).
```

While migrating also fixed index_put_accumulate for boolean types, by using compare_and_exchange trick over uint

Fixes #153560
Pull Request resolved: #158064
Approved by: https://github.com/dcci

(cherry picked from commit beed033)
@pytorchbot
Copy link
Collaborator

Cherry picking #158064

The cherry pick PR is at #158239 and it is recommended to link a critical cherry pick PR with an issue. The following tracker issues are updated:

Details for Dev Infra team Raised by workflow job

atalman pushed a commit that referenced this pull request Jul 16, 2025
[MPS] Fix `index_kernel` for large tensors (#158064)

Move `MetalShaderLibrary::bind_tensors` private method to OperatorUtils.h and extract `iter_tensor_offset` method, that returns an offset from the start of the storage associated with given tensor inside the iterator

Migrated `index`, `index_put[_accumulate][_serial]` to the new paradigm that does not require additional tensor for indices nor special handling for 32 vs 64-bit offset, which resulted in almost 2x perf gain for 2000x2000 tensor, see results below before
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   383.5    |    379.8     |    470.9     |     1232.9     |     4410.3
      __getitem__ (torch.float16, torch.int64)  |   379.6    |    354.5     |    533.2     |     1290.3     |     4442.2
      __getitem__ (torch.float32, torch.int64)  |   360.8    |    338.6     |    478.6     |     1348.9     |     4870.4

Times are in microseconds (us).
```
and after
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   349.8    |    330.5     |    432.6     |     764.5      |     1961.2
      __getitem__ (torch.float16, torch.int64)  |   342.5    |    330.7     |    434.7     |     741.0      |     1969.4
      __getitem__ (torch.float32, torch.int64)  |   332.2    |    326.1     |    445.4     |     751.3      |     1972.6

Times are in microseconds (us).
```

While migrating also fixed index_put_accumulate for boolean types, by using compare_and_exchange trick over uint

Fixes #153560
Pull Request resolved: #158064
Approved by: https://github.com/dcci

(cherry picked from commit beed033)

Co-authored-by: Nikita Shulga <nikita.shulga@gmail.com>
@github-actions github-actions bot deleted the gh/malfet/435/head branch August 14, 2025 02:19
tvukovic-amd pushed a commit to ROCm/pytorch that referenced this pull request Aug 20, 2025
[MPS] Fix `index_kernel` for large tensors (pytorch#158064)

Move `MetalShaderLibrary::bind_tensors` private method to OperatorUtils.h and extract `iter_tensor_offset` method, that returns an offset from the start of the storage associated with given tensor inside the iterator

Migrated `index`, `index_put[_accumulate][_serial]` to the new paradigm that does not require additional tensor for indices nor special handling for 32 vs 64-bit offset, which resulted in almost 2x perf gain for 2000x2000 tensor, see results below before
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   383.5    |    379.8     |    470.9     |     1232.9     |     4410.3
      __getitem__ (torch.float16, torch.int64)  |   379.6    |    354.5     |    533.2     |     1290.3     |     4442.2
      __getitem__ (torch.float32, torch.int64)  |   360.8    |    338.6     |    478.6     |     1348.9     |     4870.4

Times are in microseconds (us).
```
and after
```
[------------------------------------------------------------  -----------------------------------------------------------]
                                                |  11x50x50  |  11x100x100  |  11x500x500  |  11x1000x1000  |  11x2000x2000
1 threads: ----------------------------------------------------------------------------------------------------------------
      __getitem__ (torch.int8, torch.int64)     |   349.8    |    330.5     |    432.6     |     764.5      |     1961.2
      __getitem__ (torch.float16, torch.int64)  |   342.5    |    330.7     |    434.7     |     741.0      |     1969.4
      __getitem__ (torch.float32, torch.int64)  |   332.2    |    326.1     |    445.4     |     751.3      |     1972.6

Times are in microseconds (us).
```

While migrating also fixed index_put_accumulate for boolean types, by using compare_and_exchange trick over uint

Fixes pytorch#153560
Pull Request resolved: pytorch#158064
Approved by: https://github.com/dcci

(cherry picked from commit beed033)

Co-authored-by: Nikita Shulga <nikita.shulga@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/mps Run MPS tests (subset of trunk) ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: mps Release notes category topic: bug fixes topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants