
[flex attention][triton pin] triton_helpers shim for TMA apis #154858

Closed
davidberard98 wants to merge 4 commits into gh/davidberard98/359/base from gh/davidberard98/359/head

Conversation

davidberard98 (Contributor) commented Jun 2, 2025

Stack from ghstack (oldest at bottom):

Triton 3.4 will remove the experimental TMA apis: triton-lang/triton#6488

To allow compatibility across different triton versions, we implement a shim layer which calls the new API if available, and otherwise falls back to the experimental API.

Test: `python test/inductor/test_flex_attention.py TestFlexAttentionCUDA.test_GQA_causal_mask_cuda`, which previously failed with triton-lang/triton@cda4229558c5dca7f7c4734bedd3e596ebcae0b8 but now passes.

Note: we'll need to apply this to other things in inductor; this PR only does it for flex attention.
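For illustration, here is a minimal sketch of the dispatch idea (the helper name, the workspace argument, and the experimental-API signature are assumptions for this sketch, not the actual triton_helpers code): feature-detect the stable descriptor API at import time and route both Triton versions through one jitted helper.

```python
import triton
import triton.language as tl

# Triton >= 3.4 exposes the stable descriptor API; older releases only ship the
# experimental one that triton-lang/triton#6488 removes.
HAS_STABLE_TMA = hasattr(tl, "make_tensor_descriptor")

if HAS_STABLE_TMA:

    @triton.jit
    def make_2d_tma_descriptor(workspace_ptr, base_ptr, size0, size1,
                               BLOCK0: tl.constexpr, BLOCK1: tl.constexpr):
        # Stable API: strides are assumed contiguous and the element type comes
        # from base_ptr; workspace_ptr is unused on this path.
        return tl.make_tensor_descriptor(
            base_ptr,
            shape=[size0, size1],
            strides=[size1, 1],
            block_shape=[BLOCK0, BLOCK1],
        )

else:

    @triton.jit
    def make_2d_tma_descriptor(workspace_ptr, base_ptr, size0, size1,
                               BLOCK0: tl.constexpr, BLOCK1: tl.constexpr):
        # Experimental fallback (assumed signature for Triton 3.2/3.3): build
        # the tensormap in a caller-provided workspace and fence before use.
        tl.extra.cuda.experimental_device_tensormap_create2d(
            desc_ptr=workspace_ptr,
            global_address=base_ptr,
            load_size=[BLOCK0, BLOCK1],
            global_size=[size0, size1],
            element_ty=base_ptr.dtype.element_ty,
        )
        tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(workspace_ptr)
        return workspace_ptr
```

Load and store sites need a matching wrapper as well, since stable descriptors expose `.load()`/`.store()` while the experimental path goes through `tl._experimental_descriptor_load`/`_store`.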

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

pytorch-bot bot commented Jun 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154858

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 6e91f74 with merge base 0d0058d:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…pis"

Triton 3.4 will remove the experimental TMA apis: triton-lang/triton#6488

To allow compatibility across different triton versions, we implement a shim layer which calls the new API if available, and otherwise falls back to the experimental API.

Test: `python test/inductor/test_flex_attention.py TestFlexAttentionCUDA.test_GQA_causal_mask_cuda` which previously fails w/ triton-lang/tritoncda4229558c5dca7f7c4734bedd3e596ebcae0b8, but now passes.

Note: we'll need to apply this for other things in inductor, this just does it for flex attention.

[ghstack-poisoned]
…pis"

Triton 3.4 will remove the experimental TMA apis: triton-lang/triton#6488

To allow compatibility across different triton versions, we implement a shim layer which calls the new API if available, and otherwise falls back to the experimental API.

Test: `python test/inductor/test_flex_attention.py TestFlexAttentionCUDA.test_GQA_causal_mask_cuda` which previously fails w/ triton-lang/tritoncda4229558c5dca7f7c4734bedd3e596ebcae0b8, but now passes.

Note: we'll need to apply this for other things in inductor, this just does it for flex attention.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
davidberard98 requested review from NikhilAPatel and removed request for NikhilAPatel June 2, 2025 19:45
…pis"

Triton 3.4 will remove the experimental TMA apis: triton-lang/triton#6488

To allow compatibility across different triton versions, we implement a shim layer which calls the new API if available, and otherwise falls back to the experimental API.

Test: `python test/inductor/test_flex_attention.py TestFlexAttentionCUDA.test_GQA_causal_mask_cuda` which previously fails w/ triton-lang/tritoncda4229558c5dca7f7c4734bedd3e596ebcae0b8, but now passes.

Note: we'll need to apply this for other things in inductor, this just does it for flex attention.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Jun 2, 2025

ghstack-source-id: ec2a0a1
Pull Request resolved: #154858
drisspg (Contributor) left a comment

Looks good

davidberard98 (Contributor, Author) commented:

@pytorchbot merge

pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Jun 3, 2025
pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

pytorchmergebot pushed a commit to Eliasj42/pytorch that referenced this pull request Jun 3, 2025
Pull Request resolved: pytorch#154858
Approved by: https://github.com/NikhilAPatel, https://github.com/drisspg
iupaikov-amd pushed a commit to ROCm/pytorch that referenced this pull request Jun 4, 2025
angelayi pushed a commit to angelayi/pytorch that referenced this pull request Jun 5, 2025
davidberard98 added a commit that referenced this pull request Jun 5, 2025
Follow-up to #154858.

Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API.

This PR updates the TMA usage in mm.py and mm_scaled_grouped.py to use the TMA shim so that they will work with either Triton version.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Jun 9, 2025
Follow-up to #154858.

Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API.

First, this refactors the TMA shim to drop args that aren't supported from Triton 3.2 to Triton 3.4: in particular, strides (Triton 3.2 version doesn't accept non-contiguous inputs, so we just infer contiguous strides in Triton 3.4) and element_ty (Triton 3.4 doesn't support this arg, so in Triton 3.2 we just infer it from base_ptr).

Second, this updates mm.py to use the TMA shim.

mm_scaled_grouped.py still needs to be updated, but requires some work around recent changes in #150944.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
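To illustrate the call-site side of that change (hypothetical names, not the actual mm.py or triton_helpers code): the two descriptor types are also read differently inside a kernel, so the templates can go through a small load wrapper rather than branching on the Triton version themselves.

```python
import triton
import triton.language as tl

# Same feature check as the descriptor-creation sketch earlier in this thread.
HAS_STABLE_TMA = hasattr(tl, "make_tensor_descriptor")

if HAS_STABLE_TMA:

    @triton.jit
    def tma_load_2d(desc, off0, off1,
                    BLOCK0: tl.constexpr, BLOCK1: tl.constexpr, dtype: tl.constexpr):
        # Stable descriptors already know their block shape and element type, so
        # the extra arguments exist only to keep one signature across versions.
        return desc.load([off0, off1])

else:

    @triton.jit
    def tma_load_2d(desc, off0, off1,
                    BLOCK0: tl.constexpr, BLOCK1: tl.constexpr, dtype: tl.constexpr):
        # Experimental path (assumed signature): block shape and dtype must be
        # spelled out at every load.
        return tl._experimental_descriptor_load(desc, [off0, off1], [BLOCK0, BLOCK1], dtype)
```

A template tile load then reads the same either way, e.g. `a = tma_load_2d(a_desc, offs_m, offs_k, BLOCK_M, BLOCK_K, tl.float16)`.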
pytorchmergebot pushed a commit that referenced this pull request Jun 10, 2025
Follow-up to #154858.

Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API.

First, this refactors the TMA shim to drop args that aren't supported from Triton 3.2 to Triton 3.4: in particular, strides (Triton 3.2 version doesn't accept non-contiguous inputs, so we just infer contiguous strides in Triton 3.4) and element_ty (Triton 3.4 doesn't support this arg, so in Triton 3.2 we just infer it from base_ptr).

Second, this updates mm.py & mm_scaled_grouped.py to use the TMA shim.

Differential Revision: [D76318784](https://our.internmc.facebook.com/intern/diff/D76318784)
Pull Request resolved: #155182
Approved by: https://github.com/drisspg
pytorchmergebot pushed a commit that referenced this pull request Jun 11, 2025
This reverts commit ea7b233.

It fails internal tests in fbcode, but those tests weren't running, so we didn't get any signal.

Reverting with a PR/diff because the PR landed ~1 week ago and is too old to revert directly from internal.

Differential Revision: [D76380887](https://our.internmc.facebook.com/intern/diff/D76380887)
Pull Request resolved: #155640
Approved by: https://github.com/nmacchioni, https://github.com/danzimm
pytorchmergebot pushed a commit that referenced this pull request Jun 12, 2025
Triton 3.4 will remove the experimental TMA APIs: triton-lang/triton#6488

For mm.py templates, this PR adds support for using the new APIs when they are available (and otherwise falls back to the experimental APIs).

For flex_attention, we'll remove TMA support for Triton 3.2 and 3.3 (versions of triton that don't have the new API).

For mm_scaled_grouped.py, #150944 will remove TMA support for Triton 3.2.

Note: we attempted this earlier with #154858, but this broke TMA usage in Triton 3.2.

Differential Revision: [D76444471](https://our.internmc.facebook.com/intern/diff/D76444471)
Pull Request resolved: #155723
Approved by: https://github.com/NikhilAPatel
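As a hedged sketch of the gating the final approach implies (the helper name here is hypothetical, not the inductor code): flex attention only emits TMA code paths when the installed Triton has the stable API, and otherwise falls back to plain `tl.load`-based indexing.

```python
def can_use_tma_for_flex_attention() -> bool:
    # True only when the installed Triton exposes the stable descriptor API
    # (Triton >= 3.4); on 3.2/3.3 flex attention skips TMA entirely.
    try:
        import triton.language as tl
    except ImportError:
        return False
    return hasattr(tl, "make_tensor_descriptor")
```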
github-actions bot deleted the gh/davidberard98/359/head branch July 4, 2025 02:20
ngimel added a commit that referenced this pull request Nov 20, 2025
… weak pointer of tensor storage object (#154858)"

This reverts commit b9b84d8.

Successfully merging this pull request may close these issues.

[Upstream Triton] experimental_tensormap_fenceproxy_acquire: Triton has removed the experimental descriptor API
