[ROCm][Inductor] Enable pipelining for FlexAttention #176676
nithinsubbiah wants to merge 3 commits into pytorch:main
Conversation
Adds an additional tile size `256` to the tuning config for FlexAttention with the Triton backend on AMD. This provides a significant performance boost (~2x) across a broad range of shapes, particularly for larger sequence lengths. The boost is realized when developers pass the `max_autotune=True` option to `torch.compile`.
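As a toy illustration of why adding a larger tile-size candidate can pay off under autotuning, here is a sketch of a candidate sweep. This is not Inductor's actual autotuner: `max_autotune` benchmarks compiled Triton kernels, while the candidate list and cost model below are invented for this example.

```python
# Toy autotune sweep (illustrative only; Inductor's max_autotune actually
# times compiled Triton kernels rather than evaluating a cost model).

CANDIDATE_BLOCK_SIZES = [32, 64, 128, 256]  # 256 is the newly added candidate


def simulated_time_us(block_m: int, seq_len: int) -> float:
    """Synthetic cost model: per-tile launch overhead plus a fixed compute cost."""
    num_tiles = -(-seq_len // block_m)      # ceiling division
    return 2.0 * num_tiles + 0.01 * seq_len


def pick_block_size(seq_len: int) -> int:
    # Autotuning keeps whichever candidate measures fastest.
    return min(CANDIDATE_BLOCK_SIZES, key=lambda b: simulated_time_us(b, seq_len))


print(pick_block_size(16384))  # prints 256: larger tiles amortize per-tile overhead
```

Under this made-up model, long sequences favor the new `256` tile because fewer, larger tiles amortize per-tile overhead, mirroring why the extra candidate helps at larger sequence lengths.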
Setting `num_stages=2` in the config enables double buffering, which yields a ~20-25% improvement in FlexAttention performance.
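A toy cycle-count model (purely illustrative; the load/compute costs are invented, and Triton's `num_stages` controls the real stage buffering in the generated kernel) shows why overlapping the next tile's load with the current tile's compute helps:

```python
# Illustrative cycle model of double buffering; numbers are made up.

def single_stage_cycles(n_tiles: int, load: int, compute: int) -> int:
    # num_stages=1: each tile's load must complete before its compute runs.
    return n_tiles * (load + compute)


def double_buffered_cycles(n_tiles: int, load: int, compute: int) -> int:
    # num_stages=2: while tile i computes, tile i+1 loads, so only the
    # first load and the last compute are fully exposed.
    return load + (n_tiles - 1) * max(load, compute) + compute


single = single_stage_cycles(64, 2, 10)     # 64 * 12 = 768 cycles
double = double_buffered_cycles(64, 2, 10)  # 2 + 63 * 10 + 10 = 642 cycles
print(round(single / double, 2))            # prints 1.2
```

With these arbitrary costs the model lands near a 1.2x speedup, the same ballpark as the ~20-25% improvement reported above, though the real gain depends on actual memory and compute latencies.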
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176676
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (9 unrelated failures.) As of commit 2eac6ac with merge base c1943cf. FLAKY: the following jobs failed, but were likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
Signed-off-by: nithinsubbiah <nithinsubbiah@gmail.com>
Force-pushed from 5cee8c7 to 2eac6ac
I removed the addition of 256 to
@pytorchbot merge
Pull workflow has not been scheduled for the PR yet. This could be because the author doesn't have permission to run those workflows, or because skip-checks keywords were added to the PR/commits; aborting merge. Please get/give approval for the workflows and/or remove skip-ci decorators before the next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.
To add the ciflow label: this helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: Comment with id 4033515126 not found. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Changing the `num_stages` value from 1 to 2 enables more efficient pipelining in the Triton backend, which improves performance. Here are some benchmark numbers for comparison, run on MI350X.

| Attn Type | Shape (B,Hq,M,Hkv,N,D) | stages=1 (μs) | stages=2 (μs) | Speedup |
|----------------|----------------------------------|----------------|----------------|---------|
| causal | (2, 16, 512, 16, 512, 64) | 37.6 | 35.8 | 1.05x |
| causal | (2, 16, 512, 2, 512, 128) | 35.7 | 35.1 | 1.02x |
| causal | (2, 16, 1024, 16, 1024, 64) | 39.5 | 31.4 | 1.26x |
| causal | (2, 16, 4096, 16, 4096, 128) | 680.3 | 580.6 | 1.17x |
| causal | (2, 16, 4096, 2, 4096, 64) | 259.0 | 238.4 | 1.09x |
| noop | (8, 16, 1024, 16, 1024, 128) | 196.2 | 183.3 | 1.07x |
| causal | (8, 16, 1024, 2, 1024, 64) | 79.7 | 75.5 | 1.06x |
| alibi | (8, 16, 4096, 16, 4096, 64) | 2017.7 | 1727.3 | 1.17x |
| causal | (8, 16, 4096, 16, 4096, 128) | 2686.0 | 2258.7 | 1.19x |
| sliding_window | (8, 16, 4096, 2, 4096, 64) | 610.4 | 559.3 | 1.09x |
| causal | (16, 16, 512, 16, 512, 128) | 111.6 | 99.0 | 1.13x |
| alibi | (16, 16, 1024, 2, 1024, 128) | 391.6 | 335.3 | 1.17x |
| causal | (16, 16, 1024, 16, 1024, 64) | 163.6 | 142.6 | 1.15x |
| noop | (16, 16, 4096, 16, 4096, 128) | 6260.5 | 5130.3 | 1.22x |
| causal | (16, 16, 4096, 2, 4096, 64) | 2084.5 | 1780.5 | 1.17x |
| causal | (1, 32, 16384, 4, 16384, 64) | 2687.9 | 2472.8 | 1.09x |
| **Geo-mean** | | | | **1.13x** |

All configs: `num_warps=4`, `dtype=bfloat16`, fwd only. Benchmarked with `attention-gym` on ROCm.

Pull Request resolved: pytorch#176676
Approved by: https://github.com/drisspg, https://github.com/jeffdaily
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @jataylo @hongxiayang @naromero77amd @pragupta @jerrymannil @xinyazhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben