
[ROCm][Inductor] Enable pipelining for FlexAttention #176676

Closed
nithinsubbiah wants to merge 3 commits into pytorch:main from nithinsubbiah:flexattn_perf

Conversation

@nithinsubbiah
Contributor

@nithinsubbiah nithinsubbiah commented Mar 6, 2026

Changing the `num_stages` value from 1 to 2 enables more efficient pipelining in the Triton backend, which improves performance. Here are some benchmark numbers for comparison, run on an MI350X.

| Attn Type      | Shape (B,Hq,M,Hkv,N,D)        | stages=1 (μs) | stages=2 (μs) | Speedup   |
|----------------|-------------------------------|---------------|---------------|-----------|
| causal         | (2, 16, 512, 16, 512, 64)     | 37.6          | 35.8          | 1.05x     |
| causal         | (2, 16, 512, 2, 512, 128)     | 35.7          | 35.1          | 1.02x     |
| causal         | (2, 16, 1024, 16, 1024, 64)   | 39.5          | 31.4          | 1.26x     |
| causal         | (2, 16, 4096, 16, 4096, 128)  | 680.3         | 580.6         | 1.17x     |
| causal         | (2, 16, 4096, 2, 4096, 64)    | 259.0         | 238.4         | 1.09x     |
| noop           | (8, 16, 1024, 16, 1024, 128)  | 196.2         | 183.3         | 1.07x     |
| causal         | (8, 16, 1024, 2, 1024, 64)    | 79.7          | 75.5          | 1.06x     |
| alibi          | (8, 16, 4096, 16, 4096, 64)   | 2017.7        | 1727.3        | 1.17x     |
| causal         | (8, 16, 4096, 16, 4096, 128)  | 2686.0        | 2258.7        | 1.19x     |
| sliding_window | (8, 16, 4096, 2, 4096, 64)    | 610.4         | 559.3         | 1.09x     |
| causal         | (16, 16, 512, 16, 512, 128)   | 111.6         | 99.0          | 1.13x     |
| alibi          | (16, 16, 1024, 2, 1024, 128)  | 391.6         | 335.3         | 1.17x     |
| causal         | (16, 16, 1024, 16, 1024, 64)  | 163.6         | 142.6         | 1.15x     |
| noop           | (16, 16, 4096, 16, 4096, 128) | 6260.5        | 5130.3        | 1.22x     |
| causal         | (16, 16, 4096, 2, 4096, 64)   | 2084.5        | 1780.5        | 1.17x     |
| causal         | (1, 32, 16384, 4, 16384, 64)  | 2687.9        | 2472.8        | 1.09x     |
| **Geo-mean**   |                               |               |               | **1.13x** |

All configs: num_warps=4, dtype=bfloat16, fwd only. Benchmarked with attention-gym on ROCm.
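The geo-mean speedup can be reproduced from the per-shape timings in the table above; a quick sanity check in Python:

```python
import math

# (stages=1 μs, stages=2 μs) pairs from the benchmark table above.
timings = [
    (37.6, 35.8), (35.7, 35.1), (39.5, 31.4), (680.3, 580.6),
    (259.0, 238.4), (196.2, 183.3), (79.7, 75.5), (2017.7, 1727.3),
    (2686.0, 2258.7), (610.4, 559.3), (111.6, 99.0), (391.6, 335.3),
    (163.6, 142.6), (6260.5, 5130.3), (2084.5, 1780.5), (2687.9, 2472.8),
]

# Per-shape speedup, then the geometric mean across all shapes.
speedups = [s1 / s2 for s1, s2 in timings]
geo_mean = math.exp(sum(map(math.log, speedups)) / len(speedups))
print(f"{geo_mean:.2f}x")  # → 1.13x
```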

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @jataylo @hongxiayang @naromero77amd @pragupta @jerrymannil @xinyazhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

Adds an additional tile size `256` to the `BLOCK_M` tuning config for FlexAttention on the Triton AMD backend. This provides a significant performance boost (~2x) across a broad range of shapes, particularly for larger sequence lengths. The boost is realized when developers pass `max_autotune=True` to `torch.compile`.
Setting `num_stages=2` in the config enables double buffering, which yields a ~20-25% improvement in FlexAttention performance.
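For intuition, `num_stages=2` corresponds to software pipelining with double buffering: the loads for tile `i+1` are issued while tile `i` is still being computed, hiding global-memory latency behind compute. A minimal Python sketch of the idea follows; this is illustrative only (the real scheduling is done by the Triton compiler), and the `load`/`compute` helpers are hypothetical stand-ins:

```python
def load(tile):
    # Stand-in for a global-memory load of one K/V tile.
    return tile

def compute(data):
    # Stand-in for the attention math on one tile.
    return data * 2

def process_tiles_single_stage(tiles):
    # num_stages=1: load and compute are fully serialized per tile.
    out = []
    for t in tiles:
        data = load(t)          # compute waits on every load
        out.append(compute(data))
    return out

def process_tiles_double_buffered(tiles):
    # num_stages=2: issue the load for tile i+1 while computing tile i,
    # so memory latency overlaps with compute (double buffering).
    out = []
    if not tiles:
        return out
    nxt = load(tiles[0])
    for i in range(len(tiles)):
        cur = nxt
        if i + 1 < len(tiles):
            nxt = load(tiles[i + 1])  # prefetch next tile "in flight"
        out.append(compute(cur))
    return out
```

Both variants produce identical results; only the (simulated) overlap of loads with compute differs, which is where the measured speedup comes from on real hardware.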
@pytorch-bot

pytorch-bot bot commented Mar 6, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176676

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (9 Unrelated Failures)

As of commit 2eac6ac with merge base c1943cf:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added module: inductor module: rocm AMD GPU support for Pytorch labels Mar 6, 2026
@pytorch-bot

pytorch-bot bot commented Mar 6, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@nithinsubbiah
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Mar 6, 2026
@albanD albanD added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Mar 6, 2026
@albanD albanD requested a review from eellison March 6, 2026 14:11
@eellison eellison requested a review from drisspg March 6, 2026 16:24
@pytorch-bot pytorch-bot bot added ciflow/inductor ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Mar 7, 2026
@pytorch-bot pytorch-bot bot removed ciflow/inductor ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Mar 9, 2026
@nithinsubbiah nithinsubbiah requested a review from drisspg March 9, 2026 21:50
Signed-off-by: nithinsubbiah <nithinsubbiah@gmail.com>
@nithinsubbiah nithinsubbiah changed the title [ROCm][Inductor] Enable pipelining and efficient tile size for FlexAttention [ROCm][Inductor] Enable pipelining for FlexAttention Mar 9, 2026
@nithinsubbiah
Contributor Author

I removed the addition of `256` to the `BLOCK_M` tuning config since I see regressions for some shapes. Ideally there shouldn't be any regression, since the autotuner should be able to skip non-optimal configs, but there seems to be an issue there. I will investigate this in a follow-up PR.

@nithinsubbiah
Contributor Author

@pytorchbot merge

@pytorch-bot

pytorch-bot bot commented Mar 10, 2026

Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

@jataylo jataylo added ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/rocm-mi355 Trigger "default" config CI on ROCm MI355 runners ciflow/rocm-navi31 Trigger "default" config CI on ROCm Navi31 ciflow/rocm-mi200 Trigger "default" config CI on ROCm MI200 labels Mar 10, 2026
@pytorch-bot

pytorch-bot bot commented Mar 10, 2026

To add the ciflow label ciflow/rocm-mi300 please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot

pytorch-bot bot commented Mar 10, 2026

To add the ciflow label ciflow/rocm-navi31 please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot

pytorch-bot bot commented Mar 10, 2026

To add the ciflow label ciflow/rocm-mi355 please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot

pytorch-bot bot commented Mar 10, 2026

To add the ciflow label ciflow/rocm-mi200 please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/rocm-navi31 Trigger "default" config CI on ROCm Navi31 ciflow/rocm-mi355 Trigger "default" config CI on ROCm MI355 runners ciflow/rocm-mi200 Trigger "default" config CI on ROCm MI200 labels Mar 10, 2026
@jataylo jataylo added ciflow/trunk Trigger trunk jobs on your pull request ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/rocm-navi31 Trigger "default" config CI on ROCm Navi31 ciflow/rocm-mi200 Trigger "default" config CI on ROCm MI200 ciflow/inductor-rocm-mi300 Trigger "inductor" config CI on ROCm MI300/MI325 labels Mar 10, 2026
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: Comment with id 4033515126 not found

Details for Dev Infra team Raised by workflow job

@nithinsubbiah
Contributor Author

@drisspg @jataylo Could we merge this PR? The failures are unrelated, caused by flaky tests; pytorch-bot reports that it can be merged.

@jataylo jataylo requested a review from jeffdaily March 11, 2026 11:22
@jataylo
Collaborator

jataylo commented Mar 12, 2026

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
Pull Request resolved: pytorch#176676
Approved by: https://github.com/drisspg, https://github.com/jeffdaily
