[flex attention][triton pin] triton_helpers shim for TMA apis#154858
Closed
davidberard98 wants to merge 4 commits into gh/davidberard98/359/base from
Conversation
Triton 3.4 will remove the experimental TMA APIs: triton-lang/triton#6488. To allow compatibility across different Triton versions, we implement a shim layer that calls the new API if available and otherwise falls back to the experimental API. Test: `python test/inductor/test_flex_attention.py TestFlexAttentionCUDA.test_GQA_causal_mask_cuda`, which previously fails with triton-lang/triton@cda4229 but now passes. Note: we'll need to apply this to other things in Inductor; this PR only does it for flex attention. [ghstack-poisoned]
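The "new API if available, otherwise the experimental API" dispatch described above can be sketched in plain Python. The attribute names below (`make_tensor_descriptor`, `_experimental_make_tensor_descriptor`) are illustrative assumptions, not the exact symbols used by Triton or by the real shim in `triton_helpers`:

```python
# Minimal sketch of a feature-detection shim (assumed attribute names):
# prefer the new stable API when the module exposes it, otherwise fall
# back to the experimental API. The real shim in triton_helpers targets
# Triton's actual TMA entry points; these names are placeholders.
from types import SimpleNamespace


def make_tma_descriptor_fn(tl):
    """Return whichever TMA-descriptor constructor this Triton exposes."""
    if hasattr(tl, "make_tensor_descriptor"):  # new API (assumed name)
        return tl.make_tensor_descriptor
    # Older Triton: only the experimental API exists (assumed name).
    return tl._experimental_make_tensor_descriptor


# Stand-ins for an old and a new Triton module, for illustration only.
old_tl = SimpleNamespace(
    _experimental_make_tensor_descriptor=lambda ptr: ("experimental", ptr)
)
new_tl = SimpleNamespace(make_tensor_descriptor=lambda ptr: ("new", ptr))
```

Because the check happens at call time via `hasattr`, the same Inductor code path works against either Triton version without import-time version branching.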
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154858
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure) As of commit 6e91f74 with merge base 0d0058d. UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…pis" Triton 3.4 will remove the experimental TMA apis: triton-lang/triton#6488 To allow compatibility across different triton versions, we implement a shim layer which calls the new API if available, and otherwise falls back to the experimental API. Test: `python test/inductor/test_flex_attention.py TestFlexAttentionCUDA.test_GQA_causal_mask_cuda` which previously fails w/ triton-lang/triton@cda4229558c5dca7f7c4734bedd3e596ebcae0b8, but now passes. Note: we'll need to apply this for other things in inductor, this just does it for flex attention. [ghstack-poisoned]
This was linked to issues on Jun 2, 2025
…pis" Triton 3.4 will remove the experimental TMA apis: triton-lang/triton#6488 To allow compatibility across different triton versions, we implement a shim layer which calls the new API if available, and otherwise falls back to the experimental API. Test: `python test/inductor/test_flex_attention.py TestFlexAttentionCUDA.test_GQA_causal_mask_cuda` which previously fails w/ triton-lang/triton@cda4229558c5dca7f7c4734bedd3e596ebcae0b8, but now passes. Note: we'll need to apply this for other things in inductor, this just does it for flex attention. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
NikhilAPatel approved these changes on Jun 2, 2025
davidberard98 added a commit that referenced this pull request on Jun 2, 2025
Triton 3.4 will remove the experimental TMA apis: triton-lang/triton#6488 To allow compatibility across different triton versions, we implement a shim layer which calls the new API if available, and otherwise falls back to the experimental API. Test: `python test/inductor/test_flex_attention.py TestFlexAttentionCUDA.test_GQA_causal_mask_cuda` which previously fails w/ triton-lang/triton@cda4229558c5dca7f7c4734bedd3e596ebcae0b8, but now passes. Note: we'll need to apply this for other things in inductor, this just does it for flex attention. ghstack-source-id: ec2a0a1 Pull Request resolved: #154858
Contributor / Author
@pytorchbot merge
Collaborator
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
pytorchmergebot pushed a commit to Eliasj42/pytorch that referenced this pull request on Jun 3, 2025
…h#154858) Triton 3.4 will remove the experimental TMA apis: triton-lang/triton#6488 To allow compatibility across different triton versions, we implement a shim layer which calls the new API if available, and otherwise falls back to the experimental API. Test: `python test/inductor/test_flex_attention.py TestFlexAttentionCUDA.test_GQA_causal_mask_cuda` which previously fails w/ triton-lang/triton@cda4229558c5dca7f7c4734bedd3e596ebcae0b8, but now passes. Note: we'll need to apply this for other things in inductor, this just does it for flex attention. Pull Request resolved: pytorch#154858 Approved by: https://github.com/NikhilAPatel, https://github.com/drisspg
iupaikov-amd pushed a commit to ROCm/pytorch that referenced this pull request on Jun 4, 2025
angelayi pushed a commit to angelayi/pytorch that referenced this pull request on Jun 5, 2025
davidberard98 added a commit that referenced this pull request on Jun 5, 2025
Follow-up to #154858. Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API. This PR updates the TMA usage in mm.py and mm_scaled_grouped.py to use the TMA shim so that they will work with either Triton version. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
davidberard98 added a commit that referenced this pull request on Jun 9, 2025
… mm.py support" Follow-up to #154858. Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API. First, this refactors the TMA shim to drop args that aren't supported from Triton 3.2 to Triton 3.4: in particular, strides (Triton 3.2 version doesn't accept non-contiguous inputs, so we just infer contiguous strides in Triton 3.4) and element_ty (Triton 3.4 doesn't support this arg, so in Triton 3.2 we just infer it from base_ptr). Second, this updates mm.py to use the TMA shim. mm_scaled_grouped.py still needs to be updated, but requires some work around recent changes in #150944. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
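The "infer contiguous strides" step described above amounts to computing row-major strides from a shape. A hypothetical helper (not the actual `triton_helpers` code) might look like:

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for a given shape.

    Sketch of the inference the shim is described as doing when the older
    Triton API only accepts contiguous inputs: the innermost dimension has
    stride 1, and each outer stride is the product of the inner dims.
    """
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides
```

For example, a `[4, 3, 2]` tensor gets strides `[6, 2, 1]`, which is exactly what a contiguous layout implies, so passing explicit strides becomes unnecessary.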
davidberard98 added a commit that referenced this pull request on Jun 9, 2025
… mm, mm_scaled_grouped support" Follow-up to #154858. Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API. First, this refactors the TMA shim to drop args that aren't supported from Triton 3.2 to Triton 3.4: in particular, strides (Triton 3.2 version doesn't accept non-contiguous inputs, so we just infer contiguous strides in Triton 3.4) and element_ty (Triton 3.4 doesn't support this arg, so in Triton 3.2 we just infer it from base_ptr). Second, this updates mm.py & mm_scaled_grouped.py to use the TMA shim. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request on Jun 10, 2025
…ort (#155182) Follow-up to #154858. Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API. First, this refactors the TMA shim to drop args that aren't supported from Triton 3.2 to Triton 3.4: in particular, strides (Triton 3.2 version doesn't accept non-contiguous inputs, so we just infer contiguous strides in Triton 3.4) and element_ty (Triton 3.4 doesn't support this arg, so in Triton 3.2 we just infer it from base_ptr). Second, this updates mm.py & mm_scaled_grouped.py to use the TMA shim. Differential Revision: [D76318784](https://our.internmc.facebook.com/intern/diff/D76318784) Pull Request resolved: #155182 Approved by: https://github.com/drisspg
pytorchmergebot pushed a commit that referenced this pull request on Jun 11, 2025
…#154858)" (#155640) This reverts commit ea7b233. It fails internal tests in fbcode, but those tests weren't running, so we didn't get signal. Reverting with a PR/diff because the original PR has been landed for ~1 week, which is too old to revert directly from internal. Differential Revision: [D76380887](https://our.internmc.facebook.com/intern/diff/D76380887) Pull Request resolved: #155640 Approved by: https://github.com/nmacchioni, https://github.com/danzimm
davidberard98 added a commit that referenced this pull request on Jun 11, 2025
…py templates" Triton 3.4 will remove the experimental TMA APIs: triton-lang/triton#6488 For mm.py templates, this PR adds support for using the new APIs when they are available (and otherwise falls back to the experimental APIs). For flex_attention, we'll remove TMA support for Triton 3.2 and 3.3 (versions of triton that don't have the new API). For mm_scaled_grouped.py, #150944 will remove TMA support for Triton 3.2. Note: we attempted this earlier with #154858, but this broke TMA usage in Triton 3.2. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov Differential Revision: [D76444471](https://our.internmc.facebook.com/intern/diff/D76444471) [ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request on Jun 12, 2025
#155723) Triton 3.4 will remove the experimental TMA APIs: triton-lang/triton#6488 For mm.py templates, this PR adds support for using the new APIs when they are available (and otherwise falls back to the experimental APIs). For flex_attention, we'll remove TMA support for Triton 3.2 and 3.3 (versions of triton that don't have the new API). For mm_scaled_grouped.py, #150944 will remove TMA support for Triton 3.2. Note: we attempted this earlier with #154858, but this broke TMA usage in Triton 3.2. Differential Revision: [D76444471](https://our.internmc.facebook.com/intern/diff/D76444471) Pull Request resolved: #155723 Approved by: https://github.com/NikhilAPatel
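The policy above (use the new API on Triton 3.4+, drop TMA on 3.2/3.3 for flex attention) implies a version cutoff. A hypothetical gate is sketched below; real code would more likely feature-detect the API (as the shim does) than parse version strings, and the `3.4` cutoff here is taken from the discussion above, not from any library constant:

```python
def supports_stable_tma(triton_version: str) -> bool:
    """True when the version string is at least 3.4 (assumed cutoff).

    Illustrative only: compares the leading major.minor components of a
    version string like "3.4.0" against the assumed (3, 4) threshold.
    """
    major, minor = (int(part) for part in triton_version.split(".")[:2])
    return (major, minor) >= (3, 4)
```

Under this gate, `"3.3.1"` would keep the fallback path and `"3.4.0"` or `"4.0"` would take the new-API path.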
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov