Skip to content

[inductor][triton pin] TMA shim refactor & mm, mm_scaled_grouped support#155182

Closed
davidberard98 wants to merge 7 commits intogh/davidberard98/366/basefrom
gh/davidberard98/366/head
Closed

[inductor][triton pin] TMA shim refactor & mm, mm_scaled_grouped support#155182
davidberard98 wants to merge 7 commits intogh/davidberard98/366/basefrom
gh/davidberard98/366/head

Conversation

@davidberard98
Copy link
Contributor

@davidberard98 davidberard98 commented Jun 5, 2025

Stack from ghstack (oldest at bottom):

Follow-up to #154858.

Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API.

First, this refactors the TMA shim to drop args that aren't supported from Triton 3.2 to Triton 3.4: in particular, strides (Triton 3.2 version doesn't accept non-contiguous inputs, so we just infer contiguous strides in Triton 3.4) and element_ty (Triton 3.4 doesn't support this arg, so in Triton 3.2 we just infer it from base_ptr).

Second, this updates mm.py & mm_scaled_grouped.py to use the TMA shim.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

Differential Revision: D76318784

@pytorch-bot
Copy link

pytorch-bot bot commented Jun 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155182

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit 95b6273 with merge base e125970 (image):

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Follow-up to #154858.

Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API.

This PR updates the TMA usage in mm.py and mm_scaled_grouped.py to use the TMA shim so that they will work with either Triton version.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Jun 5, 2025
Follow-up to #154858.

Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API.

This PR updates the TMA usage in mm.py and mm_scaled_grouped.py to use the TMA shim so that they will work with either Triton version.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
Follow-up to #154858.

Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API.

This PR updates the TMA usage in mm.py and mm_scaled_grouped.py to use the TMA shim so that they will work with either Triton version.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
Follow-up to #154858.

Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API.

This PR updates the TMA usage in mm.py and mm_scaled_grouped.py to use the TMA shim so that they will work with either Triton version.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Jun 9, 2025
* fix flex_attention usage
* infer element_ty, strides instead of passing them explicitly

ghstack-source-id: 01d9b14
Pull Request resolved: #155182
@davidberard98 davidberard98 changed the title [inductor][triton pin] TMA shim for mm, mm_scaled_grouped [inductor][triton pin] TMA shim refactor & mm.py support Jun 9, 2025
@davidberard98 davidberard98 added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 9, 2025
Follow-up to #154858.

Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API.

First, this refactors the TMA shim to drop args that aren't supported from Triton 3.2 to Triton 3.4: in particular, strides (Triton 3.2 version doesn't accept non-contiguous inputs, so we just infer contiguous strides in Triton 3.4) and element_ty (Triton 3.4 doesn't support this arg, so in Triton 3.2 we just infer it from base_ptr).

Second, this updates mm.py to use the TMA shim.

mm_scaled_grouped.py still needs to be updated, but requires some work around recent changes in #150944.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
@davidberard98 davidberard98 changed the title [inductor][triton pin] TMA shim refactor & mm.py support [inductor][triton pin] TMA shim refactor & mm, mm_scaled_grouped support Jun 9, 2025
…rouped support"


Follow-up to #154858.

Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API.

First, this refactors the TMA shim to drop args that aren't supported from Triton 3.2 to Triton 3.4: in particular, strides (Triton 3.2 version doesn't accept non-contiguous inputs, so we just infer contiguous strides in Triton 3.4) and element_ty (Triton 3.4 doesn't support this arg, so in Triton 3.2 we just infer it from base_ptr).

Second, this updates mm.py & mm_scaled_grouped.py to use the TMA shim.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
davidberard98 added a commit that referenced this pull request Jun 9, 2025
* fix flex_attention usage
* infer element_ty, strides instead of passing them explicitly

ghstack-source-id: de50376
Pull Request resolved: #155182
@davidberard98 davidberard98 marked this pull request as ready for review June 10, 2025 00:37
Copy link
Contributor

@drisspg drisspg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems okay, can you make a tracking issue. Since we want this for GroupedGemm, tbh I am just not sure when we are going to able to get rid of 3.2

@davidberard98
Copy link
Contributor Author

#155519 for 3.2 deprecation stuff (cc @drisspg, making sure this is what you want a tracking task for)

@davidberard98
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@davidberard98
Copy link
Contributor Author

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

2 similar comments
@davidberard98
Copy link
Contributor Author

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@davidberard98
Copy link
Contributor Author

@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@davidberard98
Copy link
Contributor Author

@pytorchbot revert -m "fails on triton 3.2 (internally)" -c ghfirst

@pytorchmergebot
Copy link
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Jun 10, 2025
…ped support (#155182)"

This reverts commit b07725a.

Reverted #155182 on behalf of https://github.com/davidberard98 due to fails on triton 3.2 (internally) ([comment](#155182 (comment)))
@pytorchmergebot
Copy link
Collaborator

@davidberard98 your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Jun 10, 2025
@github-actions github-actions bot deleted the gh/davidberard98/366/head branch July 13, 2025 02:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants