Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/161241
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (2 Unrelated Failures)As of commit e05eae3 with merge base a99d8d3 ( BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
This pull request was exported from Phabricator. Differential Revision: D80771648 |
|
@exclamaforte has imported this pull request. If you are a Meta employee, you can view this in D80771648. |
4a42dca to
1bb3352
Compare
|
@exclamaforte has imported this pull request. If you are a Meta employee, you can view this in D80771648. |
7c850eb to
911a8b0
Compare
eellison
left a comment
There was a problem hiding this comment.
looks good ! cc @PaulZhang12
d79d6ff to
6fd9b1b
Compare
|
@exclamaforte can you do this in the same way we did #161026 ? this will allow us to integrate with everything else nicely
|
|
@pytorchbot merge |
|
@exclamaforte has imported this pull request. If you are a Meta employee, you can view this in D80771648. |
Merge failedReason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/rexporting the PR! Details for Dev Infra teamRaised by workflow job |
|
@exclamaforte your PR has been successfully reverted. |
|
@exclamaforte has imported this pull request. If you are a Meta employee, you can view this in D80771648. |
|
Verified that I fixed this test now on an AMD devserver |
|
@pytorchbot merge |
Merge failedReason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/rexporting the PR! Details for Dev Infra teamRaised by workflow job |
|
@exclamaforte has imported this pull request. If you are a Meta employee, you can view this in D80771648. |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
## Summary
Adds a subgraph decomposition for addmm and mm that performs well on large `K` compared to `M` and `N`, and functions well as an alternative to `split-k` on AMD (transposed only), which does not support AMD currently.
## Background
On AMD (MI300x), for a matmul A * B, if B is non-contiguous, the resulting matmul is quite a bit slower.
For example:
```
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[1, 178176]))
))
```
is a lot slower than:
```
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[6144, 1]))
))
```
This PR adds a subgraph decomposition to test out whether making B contiguous is faster than just using the normal kernels.
## Data
I ran this on unique non-contiguous shapes from torchbench/huggingface and got these speedups:
```
Parsed 420 unique shapes from benchmark output
addmm improvements when best:
addmm_16448x512x2048: +0.14%
addmm_128x2048x2048: +0.01%
addmm_128x768x1000: +0.75%
addmm_12672x3072x768: +1.08%
addmm_512x768x32000: +0.62%
addmm_12608x384x384: +0.00%
addmm_4160x1024x4096: +0.90%
addmm_16x768x2: +0.56%
addmm_12608x3072x768: +0.09%
addmm_64x4096x1000: +2.77%
addmm_256x1024x512: +1.99%
addmm_30x256x256: +1.12%
addmm_100480x128x384: +0.91%
addmm_6400x2048x512: +0.25%
addmm_61568x1024x256: +0.08%
addmm_1x768x768: +0.93%
addmm_12544x384x384: +0.19%
addmm_128x512x1000: +0.77%
addmm_2048x128x128: +1.32%
addmm_128x3072x1000: +0.24%
addmm_7936x512x2048: +0.07%
addmm_8192x512x2048: +0.33%
addmm_64x1024x1000: +1.43%
addmm_128x2304x1000: +0.01%
addmm_32768x256x2: +0.75%
addmm_64x384x1152: +0.79%
addmm_64x640x1000: +0.01%
addmm_100480x128x128: +0.87%
addmm_1152x3072x768: +1.13%
addmm_8192x256x2048: +1.40%
addmm_4096x128x768: +0.01%
addmm_128x2560x1000: +0.01%
addmm_12544x2048x512: +0.43%
addmm_200704x24x96: +0.14%
addmm_8448x512x2048: +0.96%
addmm_50176x256x1024: +0.62%
addmm_4160x4096x1024: +0.22%
addmm_4096x768x768: +0.32%
addmm_220x2048x512: +0.56%
addmm_8x2048x1000: +1.12%
addmm_256x197951x512: +26.99%
addmm_401536x64x192: +0.60%
addmm_2040x2048x512: +0.47%
addmm_512x1024x256: +1.32%
addmm_128x4096x1000: +1.67%
addmm_12672x768x768: +0.34%
addmm_128x368x1000: +0.77%
addmm_96x1280x1000: +0.01%
addmm_12544x512x2048: +0.41%
addmm_6272x320x1280: +0.76%
addmm_12544x3072x768: +0.09%
addmm_64x384x1000: +0.39%
mm improvements when best:
mm_200704x128x512: +1.29%
mm_663552x16x16: +0.80%
mm_4096x768x768: +0.51%
mm_131072x64x31: +0.24%
mm_12544x1152x384: +0.11%
mm_128x2048x2: +0.46%
mm_262144x16x23: +0.62%
mm_50176x576x192: +0.37%
mm_131072x16x31: +0.26%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================
Operation: addmm
----------------------------------------
Total shapes analyzed: 247
Average Subgraph placement: 3.38
Median Subgraph placement: 2.0
Subgraph is best choice: 52/247 shapes (21.1%)
Average improvement when best: 1.15%
Median improvement when best: 0.58%
Largest improvement when best: +26.99%
Operation: bmm
----------------------------------------
Total shapes analyzed: 85
Average Subgraph placement: 24.00
Median Subgraph placement: 21.0
Subgraph is best choice: 0/85 shapes (0.0%)
Average improvement when best: N/A (never best)
Median improvement when best: N/A (never best)
Largest improvement when best: N/A (never best)
Operation: mm
----------------------------------------
Total shapes analyzed: 88
Average Subgraph placement: 15.08
Median Subgraph placement: 4.0
Subgraph is best choice: 9/88 shapes (10.2%)
Average improvement when best: 0.52%
Median improvement when best: 0.46%
Largest improvement when best: +1.29%
```
## Results
The largest shape gain, `256,197951,512`, seemed to be driven by a case where the extern kernel is way faster than the best triton configs on the recursive autotune:
```
addmm,Extern,extern_kernels.addmm,256,197951,512,0.38024500012397766
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.005444049835205
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.04189395904541
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.1911399364471436
addmm,Triton,256,197951,512,64,128,32,2,4,8,2.496040105819702
addmm,Triton,256,197951,512,64,128,64,2,8,16,2.9306790828704834
addmm,Triton,256,197951,512,64,64,32,2,4,8,3.0347819328308105
...
```
Compared to the non-transposed autotune:
```
addmm,Subgraph,contiguous_addmm_1384,256,197951,512,0.5024129748344421
addmm,Extern,extern_kernels.addmm,256,197951,512,0.6881489753723145
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.5115010738372803
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.5167479515075684
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.9507460594177246
addmm,Triton,256,197951,512,64,256,64,2,8,4,2.9673290252685547
addmm,Triton,256,197951,512,64,128,64,2,8,16,3.3906331062316895
addmm,Triton,256,197951,512,64,128,32,2,4,8,3.496859073638916
```
It seems to perform really well for high values of `K` vs `N` and `M`.
Testing this hypothesis with some custom shapes:
```
Parsed 64 unique shapes from benchmark output
addmm improvements when best:
addmm_128x16384x128: +0.18%
addmm_128x262144x256: +38.24%
addmm_128x200000x512: +14.76%
addmm_256x800000x128: +0.06%
addmm_131072x128x256: +0.27%
addmm_128x256x131072: +0.25%
addmm_2048x200000x64: +12.45%
mm improvements when best:
mm_128x16384x128: +0.18%
mm_128x262144x256: +38.05%
mm_128x200000x512: +9.47%
mm_256x800000x128: +0.99%
mm_512x6400000x256: +3.17%
mm_524288x64x64: +0.29%
mm_2048x200000x64: +11.19%
mm_8192x1000000x256: +34.14%
mm_128x4096x100000: +0.40%
mm_128x3072x150000: +0.27%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================
Operation: addmm
----------------------------------------
Total shapes analyzed: 33
Average Subgraph placement: 4.39
Median Subgraph placement: 2.0
Subgraph is best choice: 7/33 shapes (21.2%)
Average improvement when best: 9.46%
Median improvement when best: 0.27%
Largest improvement when best: +38.24%
Operation: mm
----------------------------------------
Total shapes analyzed: 30
Average Subgraph placement: 7.63
Median Subgraph placement: 2.0
Subgraph is best choice: 10/30 shapes (33.3%)
Average improvement when best: 9.81%
Median improvement when best: 2.08%
Largest improvement when best: +38.05%
```
## Conclusion
Contiguous Subgraph Decompositionseems worthwhile for `mm` and `addmm`, but not `bmm`, and has a very large improvment on low `M`, low `N`, and high `K` shapes.
Data gathering scripts:
https://gist.github.com/exclamaforte/4a896c064d301b27bf5ca0a4f8fc3866
## Test Plan:
New unit tests.
Differential Revision: D80771648
Pull Request resolved: pytorch#161241
Approved by: https://github.com/eellison
This reverts commit d647185. Reverted pytorch#161241 on behalf of https://github.com/jeffdaily due to breaks rocm mi300 tests ([comment](pytorch#161241 (comment)))
## Summary
Adds a subgraph decomposition for addmm and mm that performs well on large `K` compared to `M` and `N`, and functions well as an alternative to `split-k` on AMD (transposed only), which does not support AMD currently.
## Background
On AMD (MI300x), for a matmul A * B, if B is non-contiguous, the resulting matmul is quite a bit slower.
For example:
```
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[1, 178176]))
))
```
is a lot slower than:
```
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[6144, 1]))
))
```
This PR adds a subgraph decomposition to test out whether making B contiguous is faster than just using the normal kernels.
## Data
I ran this on unique non-contiguous shapes from torchbench/huggingface and got these speedups:
```
Parsed 420 unique shapes from benchmark output
addmm improvements when best:
addmm_16448x512x2048: +0.14%
addmm_128x2048x2048: +0.01%
addmm_128x768x1000: +0.75%
addmm_12672x3072x768: +1.08%
addmm_512x768x32000: +0.62%
addmm_12608x384x384: +0.00%
addmm_4160x1024x4096: +0.90%
addmm_16x768x2: +0.56%
addmm_12608x3072x768: +0.09%
addmm_64x4096x1000: +2.77%
addmm_256x1024x512: +1.99%
addmm_30x256x256: +1.12%
addmm_100480x128x384: +0.91%
addmm_6400x2048x512: +0.25%
addmm_61568x1024x256: +0.08%
addmm_1x768x768: +0.93%
addmm_12544x384x384: +0.19%
addmm_128x512x1000: +0.77%
addmm_2048x128x128: +1.32%
addmm_128x3072x1000: +0.24%
addmm_7936x512x2048: +0.07%
addmm_8192x512x2048: +0.33%
addmm_64x1024x1000: +1.43%
addmm_128x2304x1000: +0.01%
addmm_32768x256x2: +0.75%
addmm_64x384x1152: +0.79%
addmm_64x640x1000: +0.01%
addmm_100480x128x128: +0.87%
addmm_1152x3072x768: +1.13%
addmm_8192x256x2048: +1.40%
addmm_4096x128x768: +0.01%
addmm_128x2560x1000: +0.01%
addmm_12544x2048x512: +0.43%
addmm_200704x24x96: +0.14%
addmm_8448x512x2048: +0.96%
addmm_50176x256x1024: +0.62%
addmm_4160x4096x1024: +0.22%
addmm_4096x768x768: +0.32%
addmm_220x2048x512: +0.56%
addmm_8x2048x1000: +1.12%
addmm_256x197951x512: +26.99%
addmm_401536x64x192: +0.60%
addmm_2040x2048x512: +0.47%
addmm_512x1024x256: +1.32%
addmm_128x4096x1000: +1.67%
addmm_12672x768x768: +0.34%
addmm_128x368x1000: +0.77%
addmm_96x1280x1000: +0.01%
addmm_12544x512x2048: +0.41%
addmm_6272x320x1280: +0.76%
addmm_12544x3072x768: +0.09%
addmm_64x384x1000: +0.39%
mm improvements when best:
mm_200704x128x512: +1.29%
mm_663552x16x16: +0.80%
mm_4096x768x768: +0.51%
mm_131072x64x31: +0.24%
mm_12544x1152x384: +0.11%
mm_128x2048x2: +0.46%
mm_262144x16x23: +0.62%
mm_50176x576x192: +0.37%
mm_131072x16x31: +0.26%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================
Operation: addmm
----------------------------------------
Total shapes analyzed: 247
Average Subgraph placement: 3.38
Median Subgraph placement: 2.0
Subgraph is best choice: 52/247 shapes (21.1%)
Average improvement when best: 1.15%
Median improvement when best: 0.58%
Largest improvement when best: +26.99%
Operation: bmm
----------------------------------------
Total shapes analyzed: 85
Average Subgraph placement: 24.00
Median Subgraph placement: 21.0
Subgraph is best choice: 0/85 shapes (0.0%)
Average improvement when best: N/A (never best)
Median improvement when best: N/A (never best)
Largest improvement when best: N/A (never best)
Operation: mm
----------------------------------------
Total shapes analyzed: 88
Average Subgraph placement: 15.08
Median Subgraph placement: 4.0
Subgraph is best choice: 9/88 shapes (10.2%)
Average improvement when best: 0.52%
Median improvement when best: 0.46%
Largest improvement when best: +1.29%
```
## Results
The largest shape gain, `256,197951,512`, seemed to be driven by a case where the extern kernel is way faster than the best triton configs on the recursive autotune:
```
addmm,Extern,extern_kernels.addmm,256,197951,512,0.38024500012397766
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.005444049835205
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.04189395904541
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.1911399364471436
addmm,Triton,256,197951,512,64,128,32,2,4,8,2.496040105819702
addmm,Triton,256,197951,512,64,128,64,2,8,16,2.9306790828704834
addmm,Triton,256,197951,512,64,64,32,2,4,8,3.0347819328308105
...
```
Compared to the non-transposed autotune:
```
addmm,Subgraph,contiguous_addmm_1384,256,197951,512,0.5024129748344421
addmm,Extern,extern_kernels.addmm,256,197951,512,0.6881489753723145
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.5115010738372803
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.5167479515075684
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.9507460594177246
addmm,Triton,256,197951,512,64,256,64,2,8,4,2.9673290252685547
addmm,Triton,256,197951,512,64,128,64,2,8,16,3.3906331062316895
addmm,Triton,256,197951,512,64,128,32,2,4,8,3.496859073638916
```
It seems to perform really well for high values of `K` vs `N` and `M`.
Testing this hypothesis with some custom shapes:
```
Parsed 64 unique shapes from benchmark output
addmm improvements when best:
addmm_128x16384x128: +0.18%
addmm_128x262144x256: +38.24%
addmm_128x200000x512: +14.76%
addmm_256x800000x128: +0.06%
addmm_131072x128x256: +0.27%
addmm_128x256x131072: +0.25%
addmm_2048x200000x64: +12.45%
mm improvements when best:
mm_128x16384x128: +0.18%
mm_128x262144x256: +38.05%
mm_128x200000x512: +9.47%
mm_256x800000x128: +0.99%
mm_512x6400000x256: +3.17%
mm_524288x64x64: +0.29%
mm_2048x200000x64: +11.19%
mm_8192x1000000x256: +34.14%
mm_128x4096x100000: +0.40%
mm_128x3072x150000: +0.27%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================
Operation: addmm
----------------------------------------
Total shapes analyzed: 33
Average Subgraph placement: 4.39
Median Subgraph placement: 2.0
Subgraph is best choice: 7/33 shapes (21.2%)
Average improvement when best: 9.46%
Median improvement when best: 0.27%
Largest improvement when best: +38.24%
Operation: mm
----------------------------------------
Total shapes analyzed: 30
Average Subgraph placement: 7.63
Median Subgraph placement: 2.0
Subgraph is best choice: 10/30 shapes (33.3%)
Average improvement when best: 9.81%
Median improvement when best: 2.08%
Largest improvement when best: +38.05%
```
## Conclusion
Contiguous Subgraph Decompositionseems worthwhile for `mm` and `addmm`, but not `bmm`, and has a very large improvment on low `M`, low `N`, and high `K` shapes.
Data gathering scripts:
https://gist.github.com/exclamaforte/4a896c064d301b27bf5ca0a4f8fc3866
## Test Plan:
New unit tests.
Differential Revision: D80771648
Pull Request resolved: pytorch#161241
Approved by: https://github.com/eellison
## Summary
Adds a subgraph decomposition for addmm and mm that performs well on large `K` compared to `M` and `N`, and functions well as an alternative to `split-k` on AMD (transposed only), which does not support AMD currently.
## Background
On AMD (MI300x), for a matmul A * B, if B is non-contiguous, the resulting matmul is quite a bit slower.
For example:
```
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[1, 178176]))
))
```
is a lot slower than:
```
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[6144, 1]))
))
```
This PR adds a subgraph decomposition to test out whether making B contiguous is faster than just using the normal kernels.
## Data
I ran this on unique non-contiguous shapes from torchbench/huggingface and got these speedups:
```
Parsed 420 unique shapes from benchmark output
addmm improvements when best:
addmm_16448x512x2048: +0.14%
addmm_128x2048x2048: +0.01%
addmm_128x768x1000: +0.75%
addmm_12672x3072x768: +1.08%
addmm_512x768x32000: +0.62%
addmm_12608x384x384: +0.00%
addmm_4160x1024x4096: +0.90%
addmm_16x768x2: +0.56%
addmm_12608x3072x768: +0.09%
addmm_64x4096x1000: +2.77%
addmm_256x1024x512: +1.99%
addmm_30x256x256: +1.12%
addmm_100480x128x384: +0.91%
addmm_6400x2048x512: +0.25%
addmm_61568x1024x256: +0.08%
addmm_1x768x768: +0.93%
addmm_12544x384x384: +0.19%
addmm_128x512x1000: +0.77%
addmm_2048x128x128: +1.32%
addmm_128x3072x1000: +0.24%
addmm_7936x512x2048: +0.07%
addmm_8192x512x2048: +0.33%
addmm_64x1024x1000: +1.43%
addmm_128x2304x1000: +0.01%
addmm_32768x256x2: +0.75%
addmm_64x384x1152: +0.79%
addmm_64x640x1000: +0.01%
addmm_100480x128x128: +0.87%
addmm_1152x3072x768: +1.13%
addmm_8192x256x2048: +1.40%
addmm_4096x128x768: +0.01%
addmm_128x2560x1000: +0.01%
addmm_12544x2048x512: +0.43%
addmm_200704x24x96: +0.14%
addmm_8448x512x2048: +0.96%
addmm_50176x256x1024: +0.62%
addmm_4160x4096x1024: +0.22%
addmm_4096x768x768: +0.32%
addmm_220x2048x512: +0.56%
addmm_8x2048x1000: +1.12%
addmm_256x197951x512: +26.99%
addmm_401536x64x192: +0.60%
addmm_2040x2048x512: +0.47%
addmm_512x1024x256: +1.32%
addmm_128x4096x1000: +1.67%
addmm_12672x768x768: +0.34%
addmm_128x368x1000: +0.77%
addmm_96x1280x1000: +0.01%
addmm_12544x512x2048: +0.41%
addmm_6272x320x1280: +0.76%
addmm_12544x3072x768: +0.09%
addmm_64x384x1000: +0.39%
mm improvements when best:
mm_200704x128x512: +1.29%
mm_663552x16x16: +0.80%
mm_4096x768x768: +0.51%
mm_131072x64x31: +0.24%
mm_12544x1152x384: +0.11%
mm_128x2048x2: +0.46%
mm_262144x16x23: +0.62%
mm_50176x576x192: +0.37%
mm_131072x16x31: +0.26%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================
Operation: addmm
----------------------------------------
Total shapes analyzed: 247
Average Subgraph placement: 3.38
Median Subgraph placement: 2.0
Subgraph is best choice: 52/247 shapes (21.1%)
Average improvement when best: 1.15%
Median improvement when best: 0.58%
Largest improvement when best: +26.99%
Operation: bmm
----------------------------------------
Total shapes analyzed: 85
Average Subgraph placement: 24.00
Median Subgraph placement: 21.0
Subgraph is best choice: 0/85 shapes (0.0%)
Average improvement when best: N/A (never best)
Median improvement when best: N/A (never best)
Largest improvement when best: N/A (never best)
Operation: mm
----------------------------------------
Total shapes analyzed: 88
Average Subgraph placement: 15.08
Median Subgraph placement: 4.0
Subgraph is best choice: 9/88 shapes (10.2%)
Average improvement when best: 0.52%
Median improvement when best: 0.46%
Largest improvement when best: +1.29%
```
## Results
The largest shape gain, `256,197951,512`, seemed to be driven by a case where the extern kernel is way faster than the best triton configs on the recursive autotune:
```
addmm,Extern,extern_kernels.addmm,256,197951,512,0.38024500012397766
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.005444049835205
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.04189395904541
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.1911399364471436
addmm,Triton,256,197951,512,64,128,32,2,4,8,2.496040105819702
addmm,Triton,256,197951,512,64,128,64,2,8,16,2.9306790828704834
addmm,Triton,256,197951,512,64,64,32,2,4,8,3.0347819328308105
...
```
Compared to the non-transposed autotune:
```
addmm,Subgraph,contiguous_addmm_1384,256,197951,512,0.5024129748344421
addmm,Extern,extern_kernels.addmm,256,197951,512,0.6881489753723145
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.5115010738372803
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.5167479515075684
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.9507460594177246
addmm,Triton,256,197951,512,64,256,64,2,8,4,2.9673290252685547
addmm,Triton,256,197951,512,64,128,64,2,8,16,3.3906331062316895
addmm,Triton,256,197951,512,64,128,32,2,4,8,3.496859073638916
```
It seems to perform really well for high values of `K` vs `N` and `M`.
Testing this hypothesis with some custom shapes:
```
Parsed 64 unique shapes from benchmark output
addmm improvements when best:
addmm_128x16384x128: +0.18%
addmm_128x262144x256: +38.24%
addmm_128x200000x512: +14.76%
addmm_256x800000x128: +0.06%
addmm_131072x128x256: +0.27%
addmm_128x256x131072: +0.25%
addmm_2048x200000x64: +12.45%
mm improvements when best:
mm_128x16384x128: +0.18%
mm_128x262144x256: +38.05%
mm_128x200000x512: +9.47%
mm_256x800000x128: +0.99%
mm_512x6400000x256: +3.17%
mm_524288x64x64: +0.29%
mm_2048x200000x64: +11.19%
mm_8192x1000000x256: +34.14%
mm_128x4096x100000: +0.40%
mm_128x3072x150000: +0.27%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================
Operation: addmm
----------------------------------------
Total shapes analyzed: 33
Average Subgraph placement: 4.39
Median Subgraph placement: 2.0
Subgraph is best choice: 7/33 shapes (21.2%)
Average improvement when best: 9.46%
Median improvement when best: 0.27%
Largest improvement when best: +38.24%
Operation: mm
----------------------------------------
Total shapes analyzed: 30
Average Subgraph placement: 7.63
Median Subgraph placement: 2.0
Subgraph is best choice: 10/30 shapes (33.3%)
Average improvement when best: 9.81%
Median improvement when best: 2.08%
Largest improvement when best: +38.05%
```
## Conclusion
Contiguous Subgraph Decompositionseems worthwhile for `mm` and `addmm`, but not `bmm`, and has a very large improvment on low `M`, low `N`, and high `K` shapes.
Data gathering scripts:
https://gist.github.com/exclamaforte/4a896c064d301b27bf5ca0a4f8fc3866
## Test Plan:
New unit tests.
Differential Revision: D80771648
Pull Request resolved: pytorch#161241
Approved by: https://github.com/eellison
This reverts commit d647185. Reverted pytorch#161241 on behalf of https://github.com/jeffdaily due to breaks rocm mi300 tests ([comment](pytorch#161241 (comment)))
## Summary
Adds a subgraph decomposition for addmm and mm that performs well on large `K` compared to `M` and `N`, and functions well as an alternative to `split-k` on AMD (transposed only), which does not support AMD currently.
## Background
On AMD (MI300x), for a matmul A * B, if B is non-contiguous, the resulting matmul is quite a bit slower.
For example:
```
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[1, 178176]))
))
```
is a lot slower than:
```
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[6144, 1]))
))
```
This PR adds a subgraph decomposition to test out whether making B contiguous is faster than just using the normal kernels.
## Data
I ran this on unique non-contiguous shapes from torchbench/huggingface and got these speedups:
```
Parsed 420 unique shapes from benchmark output
addmm improvements when best:
addmm_16448x512x2048: +0.14%
addmm_128x2048x2048: +0.01%
addmm_128x768x1000: +0.75%
addmm_12672x3072x768: +1.08%
addmm_512x768x32000: +0.62%
addmm_12608x384x384: +0.00%
addmm_4160x1024x4096: +0.90%
addmm_16x768x2: +0.56%
addmm_12608x3072x768: +0.09%
addmm_64x4096x1000: +2.77%
addmm_256x1024x512: +1.99%
addmm_30x256x256: +1.12%
addmm_100480x128x384: +0.91%
addmm_6400x2048x512: +0.25%
addmm_61568x1024x256: +0.08%
addmm_1x768x768: +0.93%
addmm_12544x384x384: +0.19%
addmm_128x512x1000: +0.77%
addmm_2048x128x128: +1.32%
addmm_128x3072x1000: +0.24%
addmm_7936x512x2048: +0.07%
addmm_8192x512x2048: +0.33%
addmm_64x1024x1000: +1.43%
addmm_128x2304x1000: +0.01%
addmm_32768x256x2: +0.75%
addmm_64x384x1152: +0.79%
addmm_64x640x1000: +0.01%
addmm_100480x128x128: +0.87%
addmm_1152x3072x768: +1.13%
addmm_8192x256x2048: +1.40%
addmm_4096x128x768: +0.01%
addmm_128x2560x1000: +0.01%
addmm_12544x2048x512: +0.43%
addmm_200704x24x96: +0.14%
addmm_8448x512x2048: +0.96%
addmm_50176x256x1024: +0.62%
addmm_4160x4096x1024: +0.22%
addmm_4096x768x768: +0.32%
addmm_220x2048x512: +0.56%
addmm_8x2048x1000: +1.12%
addmm_256x197951x512: +26.99%
addmm_401536x64x192: +0.60%
addmm_2040x2048x512: +0.47%
addmm_512x1024x256: +1.32%
addmm_128x4096x1000: +1.67%
addmm_12672x768x768: +0.34%
addmm_128x368x1000: +0.77%
addmm_96x1280x1000: +0.01%
addmm_12544x512x2048: +0.41%
addmm_6272x320x1280: +0.76%
addmm_12544x3072x768: +0.09%
addmm_64x384x1000: +0.39%
mm improvements when best:
mm_200704x128x512: +1.29%
mm_663552x16x16: +0.80%
mm_4096x768x768: +0.51%
mm_131072x64x31: +0.24%
mm_12544x1152x384: +0.11%
mm_128x2048x2: +0.46%
mm_262144x16x23: +0.62%
mm_50176x576x192: +0.37%
mm_131072x16x31: +0.26%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================
Operation: addmm
----------------------------------------
Total shapes analyzed: 247
Average Subgraph placement: 3.38
Median Subgraph placement: 2.0
Subgraph is best choice: 52/247 shapes (21.1%)
Average improvement when best: 1.15%
Median improvement when best: 0.58%
Largest improvement when best: +26.99%
Operation: bmm
----------------------------------------
Total shapes analyzed: 85
Average Subgraph placement: 24.00
Median Subgraph placement: 21.0
Subgraph is best choice: 0/85 shapes (0.0%)
Average improvement when best: N/A (never best)
Median improvement when best: N/A (never best)
Largest improvement when best: N/A (never best)
Operation: mm
----------------------------------------
Total shapes analyzed: 88
Average Subgraph placement: 15.08
Median Subgraph placement: 4.0
Subgraph is best choice: 9/88 shapes (10.2%)
Average improvement when best: 0.52%
Median improvement when best: 0.46%
Largest improvement when best: +1.29%
```
## Results
The largest shape gain, `256,197951,512`, seemed to be driven by a case where the extern kernel is way faster than the best triton configs on the recursive autotune:
```
addmm,Extern,extern_kernels.addmm,256,197951,512,0.38024500012397766
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.005444049835205
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.04189395904541
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.1911399364471436
addmm,Triton,256,197951,512,64,128,32,2,4,8,2.496040105819702
addmm,Triton,256,197951,512,64,128,64,2,8,16,2.9306790828704834
addmm,Triton,256,197951,512,64,64,32,2,4,8,3.0347819328308105
...
```
Compared to the non-transposed autotune:
```
addmm,Subgraph,contiguous_addmm_1384,256,197951,512,0.5024129748344421
addmm,Extern,extern_kernels.addmm,256,197951,512,0.6881489753723145
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.5115010738372803
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.5167479515075684
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.9507460594177246
addmm,Triton,256,197951,512,64,256,64,2,8,4,2.9673290252685547
addmm,Triton,256,197951,512,64,128,64,2,8,16,3.3906331062316895
addmm,Triton,256,197951,512,64,128,32,2,4,8,3.496859073638916
```
It seems to perform really well for high values of `K` vs `N` and `M`.
Testing this hypothesis with some custom shapes:
```
Parsed 64 unique shapes from benchmark output
addmm improvements when best:
addmm_128x16384x128: +0.18%
addmm_128x262144x256: +38.24%
addmm_128x200000x512: +14.76%
addmm_256x800000x128: +0.06%
addmm_131072x128x256: +0.27%
addmm_128x256x131072: +0.25%
addmm_2048x200000x64: +12.45%
mm improvements when best:
mm_128x16384x128: +0.18%
mm_128x262144x256: +38.05%
mm_128x200000x512: +9.47%
mm_256x800000x128: +0.99%
mm_512x6400000x256: +3.17%
mm_524288x64x64: +0.29%
mm_2048x200000x64: +11.19%
mm_8192x1000000x256: +34.14%
mm_128x4096x100000: +0.40%
mm_128x3072x150000: +0.27%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================
Operation: addmm
----------------------------------------
Total shapes analyzed: 33
Average Subgraph placement: 4.39
Median Subgraph placement: 2.0
Subgraph is best choice: 7/33 shapes (21.2%)
Average improvement when best: 9.46%
Median improvement when best: 0.27%
Largest improvement when best: +38.24%
Operation: mm
----------------------------------------
Total shapes analyzed: 30
Average Subgraph placement: 7.63
Median Subgraph placement: 2.0
Subgraph is best choice: 10/30 shapes (33.3%)
Average improvement when best: 9.81%
Median improvement when best: 2.08%
Largest improvement when best: +38.05%
```
## Conclusion
Contiguous Subgraph Decompositionseems worthwhile for `mm` and `addmm`, but not `bmm`, and has a very large improvment on low `M`, low `N`, and high `K` shapes.
Data gathering scripts:
https://gist.github.com/exclamaforte/4a896c064d301b27bf5ca0a4f8fc3866
## Test Plan:
New unit tests.
Differential Revision: D80771648
Pull Request resolved: pytorch#161241
Approved by: https://github.com/eellison
## Summary
Adds a subgraph decomposition for addmm and mm that performs well on large `K` compared to `M` and `N`, and functions well as an alternative to `split-k` on AMD (transposed only), which does not support AMD currently.
## Background
On AMD (MI300x), for a matmul A * B, if B is non-contiguous, the resulting matmul is quite a bit slower.
For example:
```
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[1, 178176]))
))
```
is a lot slower than:
```
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[6144, 1]))
))
```
This PR adds a subgraph decomposition to test out whether making B contiguous is faster than just using the normal kernels.
## Data
I ran this on unique non-contiguous shapes from torchbench/huggingface and got these speedups:
```
Parsed 420 unique shapes from benchmark output
addmm improvements when best:
addmm_16448x512x2048: +0.14%
addmm_128x2048x2048: +0.01%
addmm_128x768x1000: +0.75%
addmm_12672x3072x768: +1.08%
addmm_512x768x32000: +0.62%
addmm_12608x384x384: +0.00%
addmm_4160x1024x4096: +0.90%
addmm_16x768x2: +0.56%
addmm_12608x3072x768: +0.09%
addmm_64x4096x1000: +2.77%
addmm_256x1024x512: +1.99%
addmm_30x256x256: +1.12%
addmm_100480x128x384: +0.91%
addmm_6400x2048x512: +0.25%
addmm_61568x1024x256: +0.08%
addmm_1x768x768: +0.93%
addmm_12544x384x384: +0.19%
addmm_128x512x1000: +0.77%
addmm_2048x128x128: +1.32%
addmm_128x3072x1000: +0.24%
addmm_7936x512x2048: +0.07%
addmm_8192x512x2048: +0.33%
addmm_64x1024x1000: +1.43%
addmm_128x2304x1000: +0.01%
addmm_32768x256x2: +0.75%
addmm_64x384x1152: +0.79%
addmm_64x640x1000: +0.01%
addmm_100480x128x128: +0.87%
addmm_1152x3072x768: +1.13%
addmm_8192x256x2048: +1.40%
addmm_4096x128x768: +0.01%
addmm_128x2560x1000: +0.01%
addmm_12544x2048x512: +0.43%
addmm_200704x24x96: +0.14%
addmm_8448x512x2048: +0.96%
addmm_50176x256x1024: +0.62%
addmm_4160x4096x1024: +0.22%
addmm_4096x768x768: +0.32%
addmm_220x2048x512: +0.56%
addmm_8x2048x1000: +1.12%
addmm_256x197951x512: +26.99%
addmm_401536x64x192: +0.60%
addmm_2040x2048x512: +0.47%
addmm_512x1024x256: +1.32%
addmm_128x4096x1000: +1.67%
addmm_12672x768x768: +0.34%
addmm_128x368x1000: +0.77%
addmm_96x1280x1000: +0.01%
addmm_12544x512x2048: +0.41%
addmm_6272x320x1280: +0.76%
addmm_12544x3072x768: +0.09%
addmm_64x384x1000: +0.39%
mm improvements when best:
mm_200704x128x512: +1.29%
mm_663552x16x16: +0.80%
mm_4096x768x768: +0.51%
mm_131072x64x31: +0.24%
mm_12544x1152x384: +0.11%
mm_128x2048x2: +0.46%
mm_262144x16x23: +0.62%
mm_50176x576x192: +0.37%
mm_131072x16x31: +0.26%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================
Operation: addmm
----------------------------------------
Total shapes analyzed: 247
Average Subgraph placement: 3.38
Median Subgraph placement: 2.0
Subgraph is best choice: 52/247 shapes (21.1%)
Average improvement when best: 1.15%
Median improvement when best: 0.58%
Largest improvement when best: +26.99%
Operation: bmm
----------------------------------------
Total shapes analyzed: 85
Average Subgraph placement: 24.00
Median Subgraph placement: 21.0
Subgraph is best choice: 0/85 shapes (0.0%)
Average improvement when best: N/A (never best)
Median improvement when best: N/A (never best)
Largest improvement when best: N/A (never best)
Operation: mm
----------------------------------------
Total shapes analyzed: 88
Average Subgraph placement: 15.08
Median Subgraph placement: 4.0
Subgraph is best choice: 9/88 shapes (10.2%)
Average improvement when best: 0.52%
Median improvement when best: 0.46%
Largest improvement when best: +1.29%
```
## Results
The largest shape gain, `256,197951,512`, seemed to be driven by a case where the extern kernel is way faster than the best triton configs on the recursive autotune:
```
addmm,Extern,extern_kernels.addmm,256,197951,512,0.38024500012397766
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.005444049835205
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.04189395904541
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.1911399364471436
addmm,Triton,256,197951,512,64,128,32,2,4,8,2.496040105819702
addmm,Triton,256,197951,512,64,128,64,2,8,16,2.9306790828704834
addmm,Triton,256,197951,512,64,64,32,2,4,8,3.0347819328308105
...
```
Compared to the non-transposed autotune:
```
addmm,Subgraph,contiguous_addmm_1384,256,197951,512,0.5024129748344421
addmm,Extern,extern_kernels.addmm,256,197951,512,0.6881489753723145
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.5115010738372803
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.5167479515075684
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.9507460594177246
addmm,Triton,256,197951,512,64,256,64,2,8,4,2.9673290252685547
addmm,Triton,256,197951,512,64,128,64,2,8,16,3.3906331062316895
addmm,Triton,256,197951,512,64,128,32,2,4,8,3.496859073638916
```
It seems to perform really well for high values of `K` vs `N` and `M`.
Testing this hypothesis with some custom shapes:
```
Parsed 64 unique shapes from benchmark output
addmm improvements when best:
addmm_128x16384x128: +0.18%
addmm_128x262144x256: +38.24%
addmm_128x200000x512: +14.76%
addmm_256x800000x128: +0.06%
addmm_131072x128x256: +0.27%
addmm_128x256x131072: +0.25%
addmm_2048x200000x64: +12.45%
mm improvements when best:
mm_128x16384x128: +0.18%
mm_128x262144x256: +38.05%
mm_128x200000x512: +9.47%
mm_256x800000x128: +0.99%
mm_512x6400000x256: +3.17%
mm_524288x64x64: +0.29%
mm_2048x200000x64: +11.19%
mm_8192x1000000x256: +34.14%
mm_128x4096x100000: +0.40%
mm_128x3072x150000: +0.27%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================
Operation: addmm
----------------------------------------
Total shapes analyzed: 33
Average Subgraph placement: 4.39
Median Subgraph placement: 2.0
Subgraph is best choice: 7/33 shapes (21.2%)
Average improvement when best: 9.46%
Median improvement when best: 0.27%
Largest improvement when best: +38.24%
Operation: mm
----------------------------------------
Total shapes analyzed: 30
Average Subgraph placement: 7.63
Median Subgraph placement: 2.0
Subgraph is best choice: 10/30 shapes (33.3%)
Average improvement when best: 9.81%
Median improvement when best: 2.08%
Largest improvement when best: +38.05%
```
## Conclusion
Contiguous Subgraph Decompositionseems worthwhile for `mm` and `addmm`, but not `bmm`, and has a very large improvment on low `M`, low `N`, and high `K` shapes.
Data gathering scripts:
https://gist.github.com/exclamaforte/4a896c064d301b27bf5ca0a4f8fc3866
## Test Plan:
New unit tests.
Differential Revision: D80771648
Pull Request resolved: pytorch#161241
Approved by: https://github.com/eellison
This reverts commit d647185. Reverted pytorch#161241 on behalf of https://github.com/jeffdaily due to breaks rocm mi300 tests ([comment](pytorch#161241 (comment)))
## Summary
Adds a subgraph decomposition for addmm and mm that performs well on large `K` compared to `M` and `N`, and functions well as an alternative to `split-k` on AMD (transposed only), which does not support AMD currently.
## Background
On AMD (MI300x), for a matmul A * B, if B is non-contiguous, the resulting matmul is quite a bit slower.
For example:
```
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[1, 178176]))
))
```
is a lot slower than:
```
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg0_1', layout=FixedLayout('cuda:0', torch.float16, size=[1024, 178176], stride=[178176, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda:0', torch.float16, size=[178176, 6144], stride=[6144, 1]))
))
```
This PR adds a subgraph decomposition to test out whether making B contiguous is faster than just using the normal kernels.
## Data
I ran this on unique non-contiguous shapes from torchbench/huggingface and got these speedups:
```
Parsed 420 unique shapes from benchmark output
addmm improvements when best:
addmm_16448x512x2048: +0.14%
addmm_128x2048x2048: +0.01%
addmm_128x768x1000: +0.75%
addmm_12672x3072x768: +1.08%
addmm_512x768x32000: +0.62%
addmm_12608x384x384: +0.00%
addmm_4160x1024x4096: +0.90%
addmm_16x768x2: +0.56%
addmm_12608x3072x768: +0.09%
addmm_64x4096x1000: +2.77%
addmm_256x1024x512: +1.99%
addmm_30x256x256: +1.12%
addmm_100480x128x384: +0.91%
addmm_6400x2048x512: +0.25%
addmm_61568x1024x256: +0.08%
addmm_1x768x768: +0.93%
addmm_12544x384x384: +0.19%
addmm_128x512x1000: +0.77%
addmm_2048x128x128: +1.32%
addmm_128x3072x1000: +0.24%
addmm_7936x512x2048: +0.07%
addmm_8192x512x2048: +0.33%
addmm_64x1024x1000: +1.43%
addmm_128x2304x1000: +0.01%
addmm_32768x256x2: +0.75%
addmm_64x384x1152: +0.79%
addmm_64x640x1000: +0.01%
addmm_100480x128x128: +0.87%
addmm_1152x3072x768: +1.13%
addmm_8192x256x2048: +1.40%
addmm_4096x128x768: +0.01%
addmm_128x2560x1000: +0.01%
addmm_12544x2048x512: +0.43%
addmm_200704x24x96: +0.14%
addmm_8448x512x2048: +0.96%
addmm_50176x256x1024: +0.62%
addmm_4160x4096x1024: +0.22%
addmm_4096x768x768: +0.32%
addmm_220x2048x512: +0.56%
addmm_8x2048x1000: +1.12%
addmm_256x197951x512: +26.99%
addmm_401536x64x192: +0.60%
addmm_2040x2048x512: +0.47%
addmm_512x1024x256: +1.32%
addmm_128x4096x1000: +1.67%
addmm_12672x768x768: +0.34%
addmm_128x368x1000: +0.77%
addmm_96x1280x1000: +0.01%
addmm_12544x512x2048: +0.41%
addmm_6272x320x1280: +0.76%
addmm_12544x3072x768: +0.09%
addmm_64x384x1000: +0.39%
mm improvements when best:
mm_200704x128x512: +1.29%
mm_663552x16x16: +0.80%
mm_4096x768x768: +0.51%
mm_131072x64x31: +0.24%
mm_12544x1152x384: +0.11%
mm_128x2048x2: +0.46%
mm_262144x16x23: +0.62%
mm_50176x576x192: +0.37%
mm_131072x16x31: +0.26%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================
Operation: addmm
----------------------------------------
Total shapes analyzed: 247
Average Subgraph placement: 3.38
Median Subgraph placement: 2.0
Subgraph is best choice: 52/247 shapes (21.1%)
Average improvement when best: 1.15%
Median improvement when best: 0.58%
Largest improvement when best: +26.99%
Operation: bmm
----------------------------------------
Total shapes analyzed: 85
Average Subgraph placement: 24.00
Median Subgraph placement: 21.0
Subgraph is best choice: 0/85 shapes (0.0%)
Average improvement when best: N/A (never best)
Median improvement when best: N/A (never best)
Largest improvement when best: N/A (never best)
Operation: mm
----------------------------------------
Total shapes analyzed: 88
Average Subgraph placement: 15.08
Median Subgraph placement: 4.0
Subgraph is best choice: 9/88 shapes (10.2%)
Average improvement when best: 0.52%
Median improvement when best: 0.46%
Largest improvement when best: +1.29%
```
## Results
The largest shape gain, `256,197951,512`, seemed to be driven by a case where the extern kernel is way faster than the best triton configs on the recursive autotune:
```
addmm,Extern,extern_kernels.addmm,256,197951,512,0.38024500012397766
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.005444049835205
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.04189395904541
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.1911399364471436
addmm,Triton,256,197951,512,64,128,32,2,4,8,2.496040105819702
addmm,Triton,256,197951,512,64,128,64,2,8,16,2.9306790828704834
addmm,Triton,256,197951,512,64,64,32,2,4,8,3.0347819328308105
...
```
Compared to the non-transposed autotune:
```
addmm,Subgraph,contiguous_addmm_1384,256,197951,512,0.5024129748344421
addmm,Extern,extern_kernels.addmm,256,197951,512,0.6881489753723145
addmm,Triton,256,197951,512,32,256,16,2,2,4,2.5115010738372803
addmm,Triton,256,197951,512,32,128,32,2,4,8,2.5167479515075684
addmm,Triton,256,197951,512,64,128,16,2,4,8,2.9507460594177246
addmm,Triton,256,197951,512,64,256,64,2,8,4,2.9673290252685547
addmm,Triton,256,197951,512,64,128,64,2,8,16,3.3906331062316895
addmm,Triton,256,197951,512,64,128,32,2,4,8,3.496859073638916
```
It seems to perform really well for high values of `K` vs `N` and `M`.
Testing this hypothesis with some custom shapes:
```
Parsed 64 unique shapes from benchmark output
addmm improvements when best:
addmm_128x16384x128: +0.18%
addmm_128x262144x256: +38.24%
addmm_128x200000x512: +14.76%
addmm_256x800000x128: +0.06%
addmm_131072x128x256: +0.27%
addmm_128x256x131072: +0.25%
addmm_2048x200000x64: +12.45%
mm improvements when best:
mm_128x16384x128: +0.18%
mm_128x262144x256: +38.05%
mm_128x200000x512: +9.47%
mm_256x800000x128: +0.99%
mm_512x6400000x256: +3.17%
mm_524288x64x64: +0.29%
mm_2048x200000x64: +11.19%
mm_8192x1000000x256: +34.14%
mm_128x4096x100000: +0.40%
mm_128x3072x150000: +0.27%
================================================================================
BENCHMARK ANALYSIS RESULTS
================================================================================
Operation: addmm
----------------------------------------
Total shapes analyzed: 33
Average Subgraph placement: 4.39
Median Subgraph placement: 2.0
Subgraph is best choice: 7/33 shapes (21.2%)
Average improvement when best: 9.46%
Median improvement when best: 0.27%
Largest improvement when best: +38.24%
Operation: mm
----------------------------------------
Total shapes analyzed: 30
Average Subgraph placement: 7.63
Median Subgraph placement: 2.0
Subgraph is best choice: 10/30 shapes (33.3%)
Average improvement when best: 9.81%
Median improvement when best: 2.08%
Largest improvement when best: +38.05%
```
## Conclusion
Contiguous Subgraph Decompositionseems worthwhile for `mm` and `addmm`, but not `bmm`, and has a very large improvment on low `M`, low `N`, and high `K` shapes.
Data gathering scripts:
https://gist.github.com/exclamaforte/4a896c064d301b27bf5ca0a4f8fc3866
## Test Plan:
New unit tests.
Differential Revision: D80771648
Pull Request resolved: pytorch#161241
Approved by: https://github.com/eellison
Summary
Adds a subgraph decomposition for addmm and mm that performs well on large
Kcompared toMandN, and functions well as an alternative tosplit-kon AMD (transposed only), which does not support AMD currently.Background
On AMD (MI300x), for a matmul A * B, if B is non-contiguous, the resulting matmul is quite a bit slower.
For example:
is a lot slower than:
This PR adds a subgraph decomposition to test out whether making B contiguous is faster than just using the normal kernels.
Data
I ran this on unique non-contiguous shapes from torchbench/huggingface and got these speedups:
Results
The largest shape gain,
256,197951,512, seemed to be driven by a case where the extern kernel is way faster than the best triton configs on the recursive autotune:Compared to the non-transposed autotune:
It seems to perform really well for high values of
KvsNandM.Testing this hypothesis with some custom shapes:
Conclusion
Contiguous Subgraph Decompositionseems worthwhile for
mmandaddmm, but notbmm, and has a very large improvment on lowM, lowN, and highKshapes.Data gathering scripts:
https://gist.github.com/exclamaforte/4a896c064d301b27bf5ca0a4f8fc3866
Test Plan:
New unit tests.
Differential Revision: D80771648
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben