[inductor][triton pin] TMA shim refactor & mm, mm_scaled_grouped support#155182
[inductor][triton pin] TMA shim refactor & mm, mm_scaled_grouped support#155182davidberard98 wants to merge 7 commits intogh/davidberard98/366/basefrom
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155182
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (4 Unrelated Failures)As of commit 95b6273 with merge base e125970 ( BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Follow-up to #154858. Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API. This PR updates the TMA usage in mm.py and mm_scaled_grouped.py to use the TMA shim so that they will work with either Triton version. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
Follow-up to #154858. Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API. This PR updates the TMA usage in mm.py and mm_scaled_grouped.py to use the TMA shim so that they will work with either Triton version. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
Follow-up to #154858. Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API. This PR updates the TMA usage in mm.py and mm_scaled_grouped.py to use the TMA shim so that they will work with either Triton version. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
Follow-up to #154858. Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API. This PR updates the TMA usage in mm.py and mm_scaled_grouped.py to use the TMA shim so that they will work with either Triton version. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
Follow-up to #154858. Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API. First, this refactors the TMA shim to drop args that aren't supported from Triton 3.2 to Triton 3.4: in particular, strides (Triton 3.2 version doesn't accept non-contiguous inputs, so we just infer contiguous strides in Triton 3.4) and element_ty (Triton 3.4 doesn't support this arg, so in Triton 3.2 we just infer it from base_ptr). Second, this updates mm.py to use the TMA shim. mm_scaled_grouped.py still needs to be updated, but requires some work around recent changes in #150944. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
…rouped support" Follow-up to #154858. Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API. First, this refactors the TMA shim to drop args that aren't supported from Triton 3.2 to Triton 3.4: in particular, strides (Triton 3.2 version doesn't accept non-contiguous inputs, so we just infer contiguous strides in Triton 3.4) and element_ty (Triton 3.4 doesn't support this arg, so in Triton 3.2 we just infer it from base_ptr). Second, this updates mm.py & mm_scaled_grouped.py to use the TMA shim. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
|
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
2 similar comments
|
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
|
@davidberard98 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
|
@pytorchbot revert -m "fails on triton 3.2 (internally)" -c ghfirst |
|
@pytorchbot successfully started a revert job. Check the current status here. |
…ped support (#155182)" This reverts commit b07725a. Reverted #155182 on behalf of https://github.com/davidberard98 due to fails on triton 3.2 (internally) ([comment](#155182 (comment)))
|
@davidberard98 your PR has been successfully reverted. |
Stack from ghstack (oldest at bottom):
Follow-up to #154858.
Triton 3.4 will provide a different API for TMA compared to Triton 3.3; the TMA shim in triton_helpers dispatches to the correct API.
First, this refactors the TMA shim to drop args that aren't supported from Triton 3.2 to Triton 3.4: in particular, strides (Triton 3.2 version doesn't accept non-contiguous inputs, so we just infer contiguous strides in Triton 3.4) and element_ty (Triton 3.4 doesn't support this arg, so in Triton 3.2 we just infer it from base_ptr).
Second, this updates mm.py & mm_scaled_grouped.py to use the TMA shim.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov
Differential Revision: D76318784