[inductor][triton pin] add support for new TMA API for mm.py templates#155723

Closed
davidberard98 wants to merge 2 commits into gh/davidberard98/372/base from gh/davidberard98/372/head

Conversation

@davidberard98 (Contributor) commented Jun 11, 2025

Stack from ghstack (oldest at bottom):

Triton 3.4 will remove the experimental TMA APIs: triton-lang/triton#6488

For mm.py templates, this PR adds support for using the new APIs when they are available (and otherwise falls back to the experimental APIs).

For flex_attention, we'll remove TMA support for Triton 3.2 and 3.3 (versions of triton that don't have the new API).

For mm_scaled_grouped.py, #150944 will remove TMA support for Triton 3.2.

Note: we attempted this earlier with #154858, but this broke TMA usage in Triton 3.2.
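The capability check this approach depends on can be sketched as follows. This is an illustrative standalone version, not the actual PyTorch helper: it probes for `tl.make_tensor_descriptor` (the stable TMA descriptor entry point in newer Triton) and for the old `triton.tools.experimental_descriptor` module, and the exact attribute probed is an assumption.

```python
# Illustrative sketch only -- not the actual PyTorch implementation.

def has_triton_stable_tma_api() -> bool:
    """Return True if Triton exposes the stable TMA descriptor API."""
    try:
        import triton.language as tl
    except ImportError:
        return False  # Triton not installed at all
    # Assumption: the stable API is detectable via make_tensor_descriptor.
    return hasattr(tl, "make_tensor_descriptor")


def has_triton_experimental_tma_api() -> bool:
    """Return True if the deprecated experimental TMA module still exists."""
    try:
        import triton.tools.experimental_descriptor  # noqa: F401
    except ImportError:
        return False
    return True
```

With both probes available, a template can prefer the stable API and otherwise fall back to the experimental one, which is the selection strategy this PR describes for the mm.py templates.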

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

Differential Revision: D76444471

@pytorch-bot (bot) commented Jun 11, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155723

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit cd4f19f with merge base 717a099:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@davidberard98 davidberard98 changed the title [inductor][triton pin] use experimental or stable TMA API for mm.py templates [inductor][triton pin] add support for new TMA API for mm.py templates Jun 11, 2025
@davidberard98 (Contributor Author)

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 11, 2025
@NikhilAPatel (Contributor)

I'm not entirely familiar with the mechanics of AOTI, so this may be wrong. But since AOTI doesn't support the new TMA API (due to trouble setting the allocator function), would this cause failures if our Triton version is high enough that has_triton_stable_tma_api() returns True but we're using AOTI? In that case we'd choose the new API codepath without an allocator function set.

Maybe for now, has_triton_stable_tma_api() should also make sure we're not using AOTI?
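The guard being proposed here reduces to a simple predicate. A hedged sketch (function and parameter names are assumptions for illustration, not the actual inductor code):

```python
def allow_tma(has_stable_api: bool, cpp_wrapper: bool) -> bool:
    # Experimental TMA API: works under AOTI (cpp_wrapper), keep TMA enabled.
    # Stable TMA API: needs a host-side allocator function that AOTI cannot
    # currently set, so disable TMA when both conditions hold.
    return not (has_stable_api and cpp_wrapper)
```

The key asymmetry is that only the combination of the new API and AOTI is unsafe; either one alone is fine.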

davidberard98 added a commit that referenced this pull request Jun 11, 2025
@davidberard98 (Contributor Author)

@NikhilAPatel good catch - I added a check for that (and ran some quick sanity tests to confirm that AOTI works with the new Triton, using `python test/inductor/test_aot_inductor.py -k mm`)

@davidberard98 (Contributor Author)

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

inner_bytes = inner_dim * dtype.itemsize
return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT)

if has_triton_stable_tma_api() and config.cpp_wrapper:
Contributor:
Will this disable TMA for AOTI always? Because we still want TMA for AOTI as long as its using the old API right?

Contributor Author:
has_triton_stable_tma_api() means "we'll be using the new API"

Contributor:
Oh true my bad
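The alignment check in the excerpt above boils down to a byte-multiple test. A runtime analogue, assuming a TMA_ALIGNMENT of 16 bytes (the actual inductor code reasons about this symbolically via sizevars rather than at runtime):

```python
TMA_ALIGNMENT = 16  # assumption: TMA requires 16-byte-aligned inner strides

def inner_dim_tma_aligned(inner_dim: int, itemsize: int) -> bool:
    # Runtime analogue of statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT):
    # the innermost dimension, measured in bytes, must be a multiple of the
    # alignment for a TMA descriptor to be legal.
    inner_bytes = inner_dim * itemsize
    return inner_bytes % TMA_ALIGNMENT == 0
```

For example, an inner dimension of 64 fp16 elements (128 bytes) passes, while 30 fp16 elements (60 bytes) does not.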

@davidberard98 (Contributor Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud


Failing merge rule: Core Maintainers

@davidberard98 (Contributor Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud


Failing merge rule: Core Maintainers

@davidberard98 (Contributor Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@github-actions github-actions bot deleted the gh/davidberard98/372/head branch July 14, 2025 02:20
facebook-github-bot pushed a commit to meta-pytorch/tritonbench that referenced this pull request Sep 16, 2025
Summary: Add a utils function that verifies that `triton.tools.experimental_descriptor` is available in the current Triton version. This diff extends the work in this [GitHub PR](pytorch/pytorch#155723), which found that Triton 3.4.0+ deprecates the experimental descriptor.

Differential Revision: D82551608
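A sketch of such a utility (hypothetical; the actual tritonbench helper and its name may differ). It checks whether the deprecated module is importable without actually importing it:

```python
import importlib.util

def has_experimental_descriptor() -> bool:
    # True if triton.tools.experimental_descriptor is importable, i.e.
    # the installed Triton still ships the deprecated experimental TMA API.
    try:
        return importlib.util.find_spec(
            "triton.tools.experimental_descriptor"
        ) is not None
    except ModuleNotFoundError:
        # Parent package (triton) is not installed at all.
        return False
```

The try/except matters because `find_spec` on a dotted name imports the parent packages, which raises ModuleNotFoundError when Triton itself is absent.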
Six further commits referencing this pull request were pushed to meta-pytorch/tritonbench on Sep 16, 2025 (by facebook-github-bot and jananisriram), each with the same summary and Differential Revision D82551608.