[moe training] Optimize triton_fp8_per_group_colwise_scales for AMDGPU #4113
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4113
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 6564114 with merge base 2a8fa55. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@wenchenvincent thanks for your work on this! I want to hold off on landing more fp8 rowwise ROCm changes until the CI is fixed (see #4061, cc @brucechanglongxu) |
|
@danielvegamyhre I just found that I had benchmarked the kernel with an old version of pytorch/triton. With newer Triton, there is a perf regression. I filed an issue with Triton: triton-lang/triton#9834. I will wait for the issue to be fixed or look for other means. |
Force-pushed from 00cd499 to 78a93c0
|
@danielvegamyhre I fixed the perf regression with Triton 3.6. Could you review? |
|
@danielvegamyhre I saw there were CI failures. It seemed that they were not related to the changes from this PR though. |
|
@wenchenvincent can you rebase, our CI is green / fixed now |
Key optimizations:
1. Use tl.constexpr for unit strides (STRIDE_INPUT_COL, STRIDE_OUTPUT_ROW) to restore vectorized loads on Triton 3.6. Without this, the Coalesce pass cannot infer stride=1 from runtime values, causing loads to devectorize from 4x buffer_load_dwordx4 to 32x buffer_load_ushort.
2. Add fused single-pass kernel for group sizes 256-2048. Loads all rows at once, computing amax and scaling from registers without a second HBM read. Reduces memory traffic from 5 to 3 bytes/elem.
3. Single fixed config per platform to avoid per-key autotuning D2H sync overhead (upstream convention).

Benchmark on MI300X (Triton 3.6, DSv3 671B EP=8, row-major bf16):
Original upstream: ~7290 us per MoE layer
After optimization: ~1100 us per MoE layer (6.6x speedup)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
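As a rough illustration of the traffic figure in item 2, the pure-Python sketch below counts simulated HBM bytes for a two-pass and a fused single-pass per-column scaling. It assumes 2-byte bf16 reads, 1-byte fp8 writes, and a scale of 448 / amax (448 being the float8 e4m3fn maximum), and it ignores traffic for the scale tensor. `CountingTensor` and the function names are hypothetical and are not taken from the PR.

```python
BF16_BYTES = 2          # bytes read per bf16 input element
FP8_BYTES = 1           # bytes written per fp8 output element
FP8_E4M3_MAX = 448.0    # largest finite float8 e4m3fn value


class CountingTensor:
    """Row-major 2D list that counts bytes 'read from HBM' (hypothetical helper)."""

    def __init__(self, rows):
        self.rows = rows
        self.bytes_read = 0

    def load(self, r, c):
        self.bytes_read += BF16_BYTES
        return self.rows[r][c]


def scale_two_pass(x, start, end, col):
    # Pass 1: read the group's column once to find amax.
    amax = max(abs(x.load(r, col)) for r in range(start, end))
    scale = FP8_E4M3_MAX / max(amax, 1e-12)
    # Pass 2: read the same elements again to apply the scale.
    return scale, [x.load(r, col) * scale for r in range(start, end)]


def scale_fused(x, start, end, col):
    # Single pass: keep the loaded values (the "registers") and reuse them.
    vals = [x.load(r, col) for r in range(start, end)]
    amax = max(abs(v) for v in vals)
    scale = FP8_E4M3_MAX / max(amax, 1e-12)
    return scale, [v * scale for v in vals]


def bytes_per_element(kernel, rows):
    # Treat the whole input as one group and process it column by column.
    x = CountingTensor(rows)
    n_rows, n_cols = len(rows), len(rows[0])
    written = 0
    for col in range(n_cols):
        _, out = kernel(x, 0, n_rows, col)
        written += len(out) * FP8_BYTES
    return (x.bytes_read + written) / (n_rows * n_cols)


rows = [[1.0, -2.0], [-4.0, 0.5], [3.0, 8.0]]
print(bytes_per_element(scale_two_pass, rows))  # 5.0 (2 + 2 read, 1 written)
print(bytes_per_element(scale_fused, rows))     # 3.0 (2 read, 1 written)
```

Both functions return the same scale and scaled values; only the number of simulated reads differs.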
Reflects actual usage after pytorch#3972. Also add dual kernel benchmark script from upstream.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Force-pushed from 78a93c0 to 6564114
|
@danielvegamyhre I rebased upon main. |
Summary
triton_fp8_per_group_colwise_scales kernel with coalesced column-major writes using tl.trans and expanded autotune search space
Benchmark results (MI300X, row-major bf16 input)
Triton kernel speedup vs torch.compile reference (bench_triton_fp8_per_group_colwise_scales.py):
Local benchmarking of DeepSeek V3 671B EP=8 per-call breakdown (M=32768, E_local=32, tok/exp=1024):
Key optimizations
tl.trans for coalesced writes: The column-major output layout caused non-coalesced global_store_byte where consecutive SIMD lanes wrote to addresses K bytes apart. Transposing the tile through LDS before storing makes consecutive lanes write consecutive rows (stride 1), eliminating write amplification.
Fused single-pass kernel: When the group size fits in registers (<=2048 rows), the kernel loads all rows at once and computes amax and scales from registers without a second HBM read. This reduces memory traffic from 5 to 3 bytes/element.
Expanded autotune configs: Search across tile sizes [32, 64, 128] x [64, 128, 256], warps [4, 8], and stages [2, 3] instead of a single fixed config.
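The write-coalescing argument and the size of the autotune search space can be sketched in plain Python. This is an illustration only: the 64-byte segment size and K = 4096 are assumed values, the config key names (`block_a`, `block_b`) are hypothetical because the description does not say which tile dimension is which, and nothing here is the kernel's actual code.

```python
from itertools import product

WAVEFRONT_LANES = 64  # MI300X executes 64-lane wavefronts
SEGMENT_BYTES = 64    # assumed memory-segment size, for illustration only


def lane_addresses(lane_stride_bytes, num_lanes=WAVEFRONT_LANES, base=0):
    """Byte address written by each lane for a 1-byte (fp8) store."""
    return [base + lane * lane_stride_bytes for lane in range(num_lanes)]


def segments_touched(addresses, segment_bytes=SEGMENT_BYTES):
    """Number of distinct memory segments hit by one wavefront-wide store."""
    return len({addr // segment_bytes for addr in addresses})


K = 4096  # hypothetical stride in bytes between consecutive lanes before the fix

strided = segments_touched(lane_addresses(K))    # lanes K bytes apart
coalesced = segments_touched(lane_addresses(1))  # lanes 1 byte apart
print(strided, coalesced)  # 64 1

# Search space from the description:
# tile sizes [32, 64, 128] x [64, 128, 256], warps [4, 8], stages [2, 3].
configs = [
    {"block_a": a, "block_b": b, "num_warps": w, "num_stages": s}
    for a, b, w, s in product((32, 64, 128), (64, 128, 256), (4, 8), (2, 3))
]
print(len(configs))  # 36
```

Under these assumptions, the strided store touches one segment per lane, while the stride-1 store touches a single segment for the whole wavefront. The described search space has 3 x 3 x 2 x 2 = 36 candidate configs.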