Expand Triton autotune configs for MoE FP8 kernels to improve AMD GPU performance #3952
…kernels
The existing autotune configs for the MoE training FP8 kernels use a single configuration each (e.g., num_warps=4, num_stages=4, one block size), which prevents Triton's autotuner from finding better configs for different hardware targets. Expand the search space to cover:
- Multiple num_warps values (4, 8) to better saturate both NVIDIA (warp size 32) and AMD (wavefront size 64) GPU compute units
- Multiple num_stages values for software pipelining flexibility across different cache hierarchies
- Multiple block sizes to adapt to varying matrix dimensions
This is complementary to PR pytorch#3945 (relaxed atomics on AMDGPU) and targets the same kernels.
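For readers unfamiliar with how such a search space is built, here is a minimal sketch of enumerating candidates as a Cartesian product. The helper name `expanded_search_space` and the specific value grids are assumptions for illustration, not the exact lists in this PR; each resulting dict would then be turned into a `triton.Config({...}, num_warps=..., num_stages=...)` in the kernel file.

```python
from itertools import product


def expanded_search_space(
    block_sizes=((128, 128), (256, 64)),  # (BLOCK_SIZE_N, BLOCK_SIZE_K), illustrative
    num_warps=(4, 8),
    num_stages=(2, 4),
):
    """Enumerate every combination of tile shape, warp count and pipeline depth."""
    configs = []
    for (block_n, block_k), warps, stages in product(block_sizes, num_warps, num_stages):
        configs.append(
            {
                "BLOCK_SIZE_N": block_n,
                "BLOCK_SIZE_K": block_k,
                "num_warps": warps,
                "num_stages": stages,
            }
        )
    return configs


candidates = expanded_search_space()
for cfg in candidates:
    print(cfg)
```

The candidate count grows multiplicatively with each added dimension, which is why compile time and cold-cache tuning time grow with the search space.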
Benchmark Results on AMD MI250X (gfx90a)
Tested with PyTorch 2.11.0.dev20260206+rocm7.0, Triton 3.6.0, using
Note: MI250X (gfx90a) does not have native FP8 hardware; FP8 operations use software emulation. The MI300X results below are more representative of real-world performance.
Before (single hardcoded config per kernel):
After (expanded autotune search space):
Summary
The atomic kernel shows consistent 4-8% improvement on most shapes. The reduction kernel shows large gains at batch=1 (up to 2x) but some regressions at larger batches, likely due to MI250X lacking native FP8 hardware, causing the autotuner to select configs that don't translate well to software-emulated FP8. See MI300X results below for native FP8 performance.
Benchmark Results on AMD MI300X (gfx942)
Tested with PyTorch 2.11.0.dev20260206+rocm7.0, ROCm 7.0.2, triton-rocm 3.6.0.

| Kernel | Shape | Before (us) | After (us) | Speedup |
|---|---|---|---|---|
| Atomic | (1, 8192, 5120) | 246.6 | 161.6 | 1.53x |
| Atomic | (1, 5120, 8192) | 245.9 | 164.2 | 1.50x |
| Atomic | (16, 8192, 5120) | 4604.8 | 2940.5 | 1.57x |
| Atomic | (16, 5120, 8192) | 4569.0 | 2932.9 | 1.56x |
| Atomic | (128, 8192, 5120) | 37499.0 | 23637.8 | 1.59x |
| Atomic | (128, 5120, 8192) | 37569.7 | 23643.9 | 1.59x |
| Reduction | (1, 8192, 5120) | 666.7 | 302.9 | 2.20x |
| Reduction | (1, 5120, 8192) | 415.5 | 194.6 | 2.14x |
| Reduction | (16, 8192, 5120) | 1762.8 | 1453.5 | 1.21x |
| Reduction | (16, 5120, 8192) | 1558.2 | 1382.6 | 1.13x |
| Reduction | (128, 8192, 5120) | 11155.4 | 10643.1 | 1.05x |
| Reduction | (128, 5120, 8192) | 11091.3 | 10682.2 | 1.04x |
1.50-1.59x faster on the atomic kernel and 1.04-2.20x on the reduction kernel, across all Llama4 shapes. The main win comes from BLOCK_SIZE_K=64 and BLOCK_SIZE_N=256, which the autotuner picks over the original (128, 128) on MI300X.
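For reference, the Speedup column is just before/after latency. A small sketch recomputing it for three rows copied from the table above (the helper name `speedup` is my own):

```python
def speedup(before_us: float, after_us: float) -> float:
    """Ratio of old latency to new latency; values above 1.0 mean faster."""
    return before_us / after_us


# (kernel, shape, before_us, after_us) copied from the MI300X table above
rows = [
    ("Atomic", (1, 8192, 5120), 246.6, 161.6),
    ("Reduction", (1, 8192, 5120), 666.7, 302.9),
    ("Reduction", (128, 5120, 8192), 11091.3, 10682.2),
]

for kernel, shape, before_us, after_us in rows:
    print(f"{kernel:<9} {shape} {speedup(before_us, after_us):.2f}x")
```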
bench_triton_fp8_per_group_rowwise_scales.py
| Shape (Mg, N) | n_groups | Method | Before (us) | After (us) | Speedup |
|---|---|---|---|---|---|
| (16640, 8192) | 1 | triton | 737.7 | 734.6 | ~1.00x |
| (16640, 8192) | 1 | triton_transpose | 700.4 | 618.5 | 1.13x |
| (16640, 8192) | 16 | triton | 650.4 | 617.6 | 1.05x |
| (16640, 8192) | 16 | triton_transpose | 565.8 | 529.9 | 1.07x |
| (16640, 8192) | 64 | triton | 574.0 | 568.2 | ~1.01x |
| (16640, 8192) | 64 | triton_transpose | 619.7 | 537.3 | 1.15x |
bench_triton_fp8_per_group_colwise_scales.py
| Shape (Mg, K) | n_groups | Before (us) | After (us) | Speedup |
|---|---|---|---|---|
| (16640, 5120) | 1 | 497.9 | 414.8 | 1.20x |
| (16640, 5120) | 16 | 292.4 | 248.4 | 1.18x |
| (16640, 5120) | 64 | 291.2 | 245.9 | 1.18x |
MI300X notes
The original expanded configs caused regressions on the scales kernels at certain n_groups values (e.g. n_groups=64 was 19% slower) because the autotuner cached a single config per M or K value, and the optimal tile size differs by n_groups. Two fixes on top of the original config expansion:
- Kept `BLOCK_SIZE_ITER` fixed at 128: varying it caused the autotuner to pick a config tuned for long inner loops (n_groups=1) that performed poorly on short loops (n_groups=64).
- Added `N_GROUPS` to the autotuning key so each group count gets independently tuned, eliminating cross-n_groups interference.
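To make the key issue concrete, here is a toy model of an autotune cache in plain Python. It is not Triton's implementation: `ToyAutotuner`, `toy_cost`, and the candidate block sizes are all made up for illustration. It only shows why a cache keyed on M alone reuses a stale choice when n_groups changes, which is what adding `N_GROUPS` to the `key` list of `@triton.autotune` avoids.

```python
class ToyAutotuner:
    """Toy stand-in for an autotune cache: tune once per key, then reuse."""

    def __init__(self, configs, key_names, cost_fn):
        self.configs = configs
        self.key_names = key_names
        self.cost_fn = cost_fn
        self.cache = {}

    def pick(self, **args):
        # Only the arguments named in key_names distinguish cache entries.
        key = tuple(args[name] for name in self.key_names)
        if key not in self.cache:
            self.cache[key] = min(self.configs, key=lambda cfg: self.cost_fn(cfg, **args))
        return self.cache[key]


def toy_cost(block_rows, M, N_GROUPS):
    # Made-up cost model: the best tile roughly matches the rows per group.
    return abs(block_rows - M // N_GROUPS)


configs = [32, 128, 512]  # hypothetical candidate tile sizes

keyed_on_m = ToyAutotuner(configs, ("M",), toy_cost)
keyed_on_m.pick(M=16640, N_GROUPS=1)           # tunes for long inner loops
stale = keyed_on_m.pick(M=16640, N_GROUPS=64)  # same key, cached choice reused

keyed_on_both = ToyAutotuner(configs, ("M", "N_GROUPS"), toy_cost)
keyed_on_both.pick(M=16640, N_GROUPS=1)
fresh = keyed_on_both.pick(M=16640, N_GROUPS=64)  # new key, tuned separately

print("key=(M):", stale, "| key=(M, N_GROUPS):", fresh)
```

The trade-off is more cache entries, and therefore more tuning runs, one per distinct (M, N_GROUPS) pair.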
H100 Benchmark Results (before gating fix)
Tested with PyTorch 2.7.1+cu126 + CUDA 12.6 on NVIDIA H100 80GB HBM3.
Baseline (upstream/main, single config):
Expanded configs (ungated, all platforms):
Summary
This motivated gating the expanded configs to AMD only (see follow-up comments).
…configs
H100 benchmarks showed ~18% regression on the atomic kernel with the expanded search space. The autotuner appears to pick suboptimal configs from the larger candidate set on NVIDIA. Gate the expanded configs behind torch.version.hip so AMD gets the broader search (4-7% faster on MI250X) while NVIDIA keeps the original tuned configs.
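A rough sketch of what this gating looks like, written so it runs without torch installed. `ORIGINAL_CONFIG` uses the values quoted in this thread (num_warps=4, num_stages=4, 128x128 blocks); `is_rocm_build`, `select_configs`, and the expanded grid are illustrative assumptions, not the code in the PR.

```python
ORIGINAL_CONFIG = {
    "BLOCK_SIZE_N": 128,
    "BLOCK_SIZE_K": 128,
    "num_warps": 4,
    "num_stages": 4,
}


def is_rocm_build() -> bool:
    """True when torch is a ROCm build; torch.version.hip is None otherwise."""
    try:
        import torch
    except ImportError:
        return False
    return getattr(torch.version, "hip", None) is not None


def select_configs(on_rocm: bool):
    """NVIDIA keeps the single original config; AMD gets a wider grid."""
    if not on_rocm:
        return [ORIGINAL_CONFIG]
    # Illustrative grid; it deliberately contains the original values so the
    # AMD search space is a superset of the NVIDIA one.
    return [
        dict(ORIGINAL_CONFIG, num_warps=warps, num_stages=stages)
        for warps in (4, 8)
        for stages in (2, 4)
    ]


print(len(select_configs(is_rocm_build())), "candidate config(s) on this build")
```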
Updated to gate expanded configs behind
H100 Benchmark (after gating expanded configs to AMD only)
Tested with PyTorch 2.7.1+cu126 + CUDA 12.6 on NVIDIA H100 80GB HBM3.
Baseline (upstream/main, original single config):
PR branch (gated: NVIDIA uses original config, AMD gets expanded):
Summary
All shapes are within noise for both
thanks for the PR, this is great! will review this soon.
Yeah, we did this to reduce compile time and speed up test duration: we just hard coded the configs the autotuner was picking from the options we tested. If you want to add autotuner configs for AMD that is fine. AMD CI only runs when the ROCM label is added to the PR, so it won't slow down CI across all of torchao.
@brucechanglongxu i'm curious, are you planning to only support fp8 rowwise training for MoE/grouped_mm, or any plans to expand to other recipes like fp8 blockwise? We only have a prototype for FP8 blockwise training for linear layers in TorchAO, but I am eager to land one for MoE/GroupedGEMM, so if you folks want to add support for it as part of expanding AMD coverage of popular fp8 training recipes, I will happily review and land the PRs!
Great question! We are definitely interested in expanding beyond FP8 row-wise. FP8 blockwise for MoE/GroupedGEMM is on our radar and we are actively working on that. We plan to open follow-up PRs as we make progress. Really appreciate the timely feedback and openness to reviewing!
…ing key
Two improvements based on MI300X (gfx942) benchmarking:
1. float8_rowwise.py: Widen block size search space for AMD GPUs.
- Atomic configs: add BLOCK_SIZE_N=256 and BLOCK_SIZE_K=64
- Reduction configs: add BLOCK_SIZE_N=128, BLOCK_SIZE_K=64, and num_stages=2,4
- Yields 1.5-2.2x speedup on MI300X for the atomic kernel and
1.05-1.25x for the reduction kernel across Llama4 MoE shapes.
2. jagged_float8_scales.py: Add N_GROUPS to autotuning key for both
rowwise and colwise scales kernels. The previous key (M or K only)
caused the autotuner to cache a single config across all n_groups
values, but optimal tile sizes differ significantly by n_groups.
This eliminates cross-n_groups interference and allows each n_groups
value to independently find its best config.
@brucechanglongxu looks like you need to sign the CLA, see comments from the bot. once you do, the CI will pass
@brucechanglongxu ok awesome, we have an early prototype of fp8 blockwise training here you can use as a starting point when you're ready |
…pes (#4024)
PR #3952 expanded Triton autotune configurations for MoE FP8 rowwise kernels on AMD GPUs (24-36 configs gated behind torch.version.hip). Benchmarking on MI300X reveals this causes:
1. ~15% kernel regression on DeepSeek V3 shapes due to the autotuner selecting suboptimal configs from the noisy microbenchmark results
2. Non-deterministic config selection across runs
3. No measurable improvement on Llama4 shapes vs the original single config (the PR's reported gains were vs torch.compile, not vs the original Triton config)

Revert to the original single config for both atomic and reduction kernels, which is near-optimal across all tested shape families.

This does NOT revert other valuable changes from #3952:
- N_GROUPS added to autotune key in jagged_float8_scales.py
- N_GROUPS: tl.int64 type annotation fixes

The jagged_float8_scales.py configs (from PR #3972) are also preserved, as they were carefully benchmarked and provide 4.3x improvement.

Benchmark results on MI300X (atomic kernel, representative shapes):
| Shape | Expanded (#3952) | Single (this PR) |
|---|---|---|
| (128, 8192, 5120) | 10.56 ms | 10.43 ms |
| (128, 5120, 8192) | 10.50 ms | 10.40 ms |
| (8, 2048, 1408) | 0.068 ms | 0.072 ms |
| (8, 1408, 2048) | 0.069 ms | 0.078 ms |
| Cold-cache overhead | 4.4s | 1.9s |
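The non-determinism point can be illustrated with a small seeded simulation. This is a toy model, not a measurement: the config names, "true" timings, and noise level below are invented. It only shows that when several candidates are within measurement noise of each other, taking the minimum of one noisy run per candidate does not pick the same winner every time.

```python
import random


def autotune_once(true_ms, noise_ms, rng):
    """Benchmark each config once with Gaussian noise and return the fastest."""
    measured = {name: t + rng.gauss(0.0, noise_ms) for name, t in true_ms.items()}
    return min(measured, key=measured.get)


# Hypothetical configs whose true latencies differ by only a few percent.
true_ms = {"cfg_a": 1.00, "cfg_b": 1.01, "cfg_c": 1.03}

rng = random.Random(0)  # seeded so the simulation itself is repeatable
noisy_winners = [autotune_once(true_ms, 0.02, rng) for _ in range(200)]
clean_winners = [autotune_once(true_ms, 0.0, rng) for _ in range(200)]

print({name: noisy_winners.count(name) for name in true_ms})
```

With a single candidate there is nothing to choose between, which is one reason a known-good hardcoded config gives repeatable behaviour.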
The MoE FP8 Triton kernels in `float8_rowwise.py` and `jagged_float8_scales.py` each have a single hardcoded autotune config. This means the autotuner never actually tunes anything: it just uses the one config regardless of hardware or problem size.

This PR gates an expanded autotune search space behind `torch.version.hip`, so AMD gets 8-16 candidate configs per kernel while NVIDIA keeps the original single config unchanged. The original values are always included in the AMD search space, so the autotuner can only do better (or equal).

This matters for AMD GPUs in particular: AMD wavefronts are 64 threads (vs 32 on NVIDIA), so the best `num_warps` and pipelining depth tend to differ. Benchmarks on MI250X show 4-15% improvement on Llama4 shapes (see comments). Initial H100 testing with an ungated expanded search space showed ~18% regression on the atomic kernel for larger shapes, which motivated the gating approach.

Complements #3945 (relaxed atomics on AMDGPU), same kernel files.

Test plan
- `ruff check` / `ruff format` clean

cc: @BowenBao