Revert expanded MoE FP8 autotune configs that regress DeepSeek V3 shapes#4024
Merged
danielvegamyhre merged 1 commit intoMar 7, 2026
Merged
Conversation
PR pytorch#3952 expanded Triton autotune configurations for MoE FP8 rowwise kernels on AMD GPUs (24-36 configs gated behind torch.version.hip). Benchmarking on MI300X reveals this causes: 1. ~15% kernel regression on DeepSeek V3 shapes due to the autotuner selecting suboptimal configs from the noisy microbenchmark results 2. Non-deterministic config selection across runs 3. No measurable improvement on Llama4 shapes vs the original single config (the PR's reported gains were vs torch.compile, not vs the original Triton config) Revert to the original single config for both atomic and reduction kernels, which is near-optimal across all tested shape families. This does NOT revert other valuable changes from pytorch#3952: - N_GROUPS added to autotune key in jagged_float8_scales.py - N_GROUPS: tl.int64 type annotation fixes The jagged_float8_scales.py configs (from PR pytorch#3972) are also preserved, as they were carefully benchmarked and provide 4.3x improvement. Benchmark results on MI300X (atomic kernel, representative shapes): | Shape | Expanded (pytorch#3952) | Single (this PR) | |-------------------|------------------|-------------------| | (128, 8192, 5120) | 10.56 ms | 10.43 ms | | (128, 5120, 8192) | 10.50 ms | 10.40 ms | | (8, 2048, 1408) | 0.068 ms | 0.072 ms | | (8, 1408, 2048) | 0.069 ms | 0.078 ms | | Cold-cache overhead| 4.4s | 1.9s |
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4024
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 12cd4e2 with merge base 5045d76 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Contributor
|
thanks @brucechanglongxu for validating this and reverting based on your findings. also for future reference please go ahead and add me as reviewer to any MoE training PRs directly, so i don't miss any |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
#3952 expanded the Triton autotune search space for MoE FP8 rowwise kernels on AMD GPUs (24 configs for atomic, 36 for reduction, gated behind
torch.version.hip). The reported gains were measured againsttorch.compilebaseline, but when comparing the autotuner-selected configs against the original single Triton config on MI300X, there's no measurable improvement on Llama4 shapes -- and the noisy autotuner microbenchmarks can select suboptimal configs that regress DeepSeek V3 shapes by ~15%.The autotuner is also non-deterministic (picks different "best" configs across runs for the same shape), and the large search space adds unnecessary cold-cache compile overhead (4.4s vs 1.9s).
This PR reverts to the original single hardcoded config for both atomic and reduction kernels in
float8_rowwise.py. The config works well across all tested shape families (Llama4 and DeepSeek V3).Other changes from #3952 and later PRs are intentionally preserved:
N_GROUPSautotune key addition injagged_float8_scales.pyN_GROUPS: tl.int64type fixes from [ROCM] Float8 deepseekv3_671b IntOverflow in triton kernels during training #4016jagged_float8_scales.pyconfigs from Optimize FP8 colwise scales kernel for AMD GPUs in MoE backward pass #3972 (carefully benchmarked, 4.3x improvement)Benchmark on MI300X (atomic kernel):
All within noise; no regression on either Llama4 or DeepSeek V3 shapes.
Test plan: