[ROCm] Improve reduction sum performance#2492
Merged
pruthvistony merged 1 commit intorelease/2.7from Aug 13, 2025
Merged
Conversation
* Use input vectorization for reduction_on_fastest_striding_dimension when dim0 >= 0
**Reproducer:**
```
import time
import torch
shapes = [
(5079670, 128)
]
dims = [
(1)
]
for i, shape in enumerate(shapes):
x = torch.randn(shape, device='cuda', dtype=torch.float)
for _ in range(10):
w = torch.sum(x, dims[i])
torch.cuda.synchronize()
print(w.size())
start_time = time.time()
for _ in range(50):
_ = torch.sum(x, dims[i])
torch.cuda.synchronize()
end_time = time.time()
mean_time = (end_time - start_time)/50
print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```
**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us
**After (MI300X)**
Avg time for shape (5079670, 128): 1008.59 us
cherry-pick of pytorch#160466
|
Jenkins build for 18528f45f38a0cb0eab9c869f1c367156a1d7122 commit finished as FAILURE |
pruthvistony
approved these changes
Aug 13, 2025
Collaborator
|
Please cherry-pick into all required branches. |
Collaborator
Author
|
! cherry-pick --onto release/2.8 |
1 similar comment
Collaborator
|
! cherry-pick --onto release/2.8 |
dhonnappa-amd
pushed a commit
that referenced
this pull request
Aug 13, 2025
* Use input vectorization for reduction_on_fastest_striding_dimension
when dim0 >= 0
**Reproducer:**
```
import time
import torch
shapes = [
(5079670, 128)
]
dims = [
(1)
]
for i, shape in enumerate(shapes):
x = torch.randn(shape, device='cuda', dtype=torch.float)
for _ in range(10):
w = torch.sum(x, dims[i])
torch.cuda.synchronize()
print(w.size())
start_time = time.time()
for _ in range(50):
_ = torch.sum(x, dims[i])
torch.cuda.synchronize()
end_time = time.time()
mean_time = (end_time - start_time)/50
print(f"Avg time for shape {shape}: {mean_time * 1e6:.2f} us")
```
**Before (MI300X):**
Avg time for shape (5079670, 128): 1629.99 us
**After (MI300X)**
Avg time for shape (5079670, 128): 1008.59 us
cherry-pick of pytorch#160466
Fixes SWDEV-546136
|
Created branch autogenerated/release/2.8_cherry-pick_pr-2492 and #2505 |
Collaborator
|
! cherry-pick --onto rocm7.1_internal_testing |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Reproducer:
Before (MI300X):
Avg time for shape (5079670, 128): 1629.99 us
After (MI300X)
Avg time for shape (5079670, 128): 1008.59 us
cherry-pick of pytorch#160466
Fixes SWDEV-546136
Cherry-picked to release/2.8 branch via #2505