[moe training] Optimize triton_fp8_per_group_colwise_scales for AMDGPU #4113

Merged: danielvegamyhre merged 2 commits into pytorch:main from wenchenvincent:feat/optimize_colwise_scales on Apr 13, 2026

@wenchenvincent (Contributor) commented Mar 19, 2026

Summary

  • Optimize the triton_fp8_per_group_colwise_scales kernel: use tl.trans to make the column-major writes coalesced, and expand the autotune search space
  • Add a fused single-pass kernel variant that keeps the data in registers and so eliminates the second HBM read, reducing memory traffic from 5 to 3 bytes/element
  • Kernel selection logic automatically picks the fused kernel when group sizes are between 256 and 2048 rows, and falls back to the two-pass kernel otherwise
  • Change the benchmark input layout to row-major to reflect actual usage after "Optimize FP8 colwise scales kernel for AMD GPUs in MoE backward pass" (#3972)
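
The selection rule in the third bullet can be sketched as follows. This is a minimal illustration, not the PR's code: `pick_kernel`, the constant names, the inclusive bounds, and the use of the average rows per group (`total_rows // n_groups`) as the size estimate are all assumptions made for the example.

```python
# Hypothetical sketch of the fused vs. two-pass selection described above.
# Names, inclusive bounds, and the average-group-size heuristic are assumptions.
FUSED_MIN_ROWS = 256
FUSED_MAX_ROWS = 2048


def pick_kernel(total_rows: int, n_groups: int) -> str:
    """Return "fused" when the estimated group size is in range, else "two_pass"."""
    if n_groups <= 0:
        raise ValueError("n_groups must be positive")
    rows_per_group = total_rows // n_groups  # estimate; real groups may be uneven
    if FUSED_MIN_ROWS <= rows_per_group <= FUSED_MAX_ROWS:
        return "fused"
    return "two_pass"


# Shapes taken from the benchmark tables in this PR description.
for total_rows, n_groups in [(16640, 1), (16640, 16), (16640, 64), (32768, 32)]:
    print(total_rows, n_groups, pick_kernel(total_rows, n_groups))
```

Under these assumptions, the single-group (16640-row) case falls back to the two-pass kernel, while the 16-group, 64-group, and DeepSeek V3 (1024 tokens per expert) cases land in the fused range.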

Benchmark results (MI300X, row-major bf16 input)

Triton kernel speedup vs torch.compile reference (bench_triton_fp8_per_group_colwise_scales.py):

| Shape | n_groups | Before | After | Improvement |
|---|---|---|---|---|
| (16640, 5120) | 1 | 0.03x | 0.62x | 20x |
| (16640, 5120) | 16 | 0.84x | 6.07x | 7x |
| (16640, 5120) | 64 | 5.61x | 8.80x | 1.6x |

Local benchmarking of DeepSeek V3 671B EP=8 per-call breakdown (M=32768, E_local=32, tok/exp=1024):

| pass/proj | tensor | shape | Before (us) | After (us) | Speedup |
|---|---|---|---|---|---|
| wgrad/gate_up | grad_output | (32768, 2048) | 785 | 136 | 5.8x |
| wgrad/gate_up | A | (32768, 7168) | 2860 | 449 | 6.4x |
| wgrad/down | grad_output | (32768, 7168) | 2860 | 449 | 6.4x |
| wgrad/down | A | (32768, 2048) | 785 | 136 | 5.8x |
| Total per MoE layer | | | 7290 | 1170 | 6.2x |
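
As a rough sanity check on the "After" timings, the effective HBM bandwidth can be estimated from the table. The sketch below assumes the 3 bytes/element traffic model quoted in the summary (a 2-byte bf16 read plus a 1-byte fp8 write) and ignores the small scale tensor; `effective_bandwidth_gb_per_s` is a hypothetical helper, not part of the PR.

```python
# Rough effective-bandwidth estimate from the "After" timings in the table above.
# Assumes 3 bytes/element of HBM traffic (2-byte bf16 read + 1-byte fp8 write);
# the small per-group scale tensor is ignored.

def effective_bandwidth_gb_per_s(rows: int, cols: int, time_us: float,
                                 bytes_per_elem: int = 3) -> float:
    """Bytes moved divided by elapsed time, in GB/s (1 GB = 1e9 bytes)."""
    total_bytes = rows * cols * bytes_per_elem
    return total_bytes / (time_us * 1e-6) / 1e9


for shape, time_us in [((32768, 2048), 136.0), ((32768, 7168), 449.0)]:
    bw = effective_bandwidth_gb_per_s(*shape, time_us)
    print(f"{shape}: {bw:.0f} GB/s")
```

Under this model both shapes come out at roughly 1.5 TB/s.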

Key optimizations

  1. tl.trans for coalesced writes: The column-major output layout caused non-coalesced global_store_byte where consecutive SIMD lanes wrote to addresses K bytes apart. Transposing the tile through LDS before storing makes consecutive lanes write consecutive rows (stride 1), eliminating write amplification.

  2. Fused single-pass kernel: When the group fits in registers (<=2048 rows), the kernel loads all rows at once, then computes amax and applies the scaling from registers, without a second HBM read. This reduces memory traffic from 5 to 3 bytes/element.

  3. Expanded autotune configs: Search across tile sizes [32, 64, 128] x [64, 128, 256], warps [4, 8], and stages [2, 3] instead of a single fixed config.
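
The 5 versus 3 bytes/element figure in item 2 can be reproduced with a small pure-Python model of the two kernel structures. This is a sketch under stated assumptions, not the Triton kernels themselves: it assumes bf16 input (2 bytes), fp8 output (1 byte), a scale defined as `fp8_max / amax`, and 448.0 (the float8_e4m3fn maximum) as the default `fp8_max`. The function names and the `eps` clamp are hypothetical.

```python
# Pure-Python model of the two kernel structures; it counts simulated HBM bytes.
# Assumptions: bf16 input (2 bytes), fp8 output (1 byte), scale = fp8_max / amax,
# and 448.0 (the float8_e4m3fn maximum) as the default fp8_max. No real fp8
# rounding is performed, and traffic for the scale tensor is ignored.
BF16_BYTES, FP8_BYTES = 2, 1


def two_pass(group, fp8_max=448.0, eps=1e-12):
    n_cols, traffic = len(group[0]), 0
    amax = [0.0] * n_cols
    for row in group:                      # pass 1: read every element for amax
        for c, v in enumerate(row):
            amax[c] = max(amax[c], abs(v))
            traffic += BF16_BYTES
    scales = [fp8_max / max(a, eps) for a in amax]
    out = []
    for row in group:                      # pass 2: re-read, scale, write
        out.append([v * scales[c] for c, v in enumerate(row)])
        traffic += len(row) * (BF16_BYTES + FP8_BYTES)
    return out, scales, traffic


def fused(group, fp8_max=448.0, eps=1e-12):
    n_cols = len(group[0])
    regs = [list(row) for row in group]    # single read; tile stays "in registers"
    traffic = len(regs) * n_cols * BF16_BYTES
    amax = [max(abs(row[c]) for row in regs) for c in range(n_cols)]
    scales = [fp8_max / max(a, eps) for a in amax]
    out = [[v * scales[c] for c, v in enumerate(row)] for row in regs]
    traffic += len(regs) * n_cols * FP8_BYTES   # single write of the fp8 output
    return out, scales, traffic


group = [[0.5, -2.0], [1.0, 4.0], [-0.25, 1.0]]   # one group: 3 rows x 2 columns
n_elems = len(group) * len(group[0])
out_a, scales_a, bytes_a = two_pass(group)
out_b, scales_b, bytes_b = fused(group)
print(bytes_a / n_elems, bytes_b / n_elems)
```

Both variants produce the same scaled values; only the simulated traffic differs, which matches the per-element figures quoted above.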

pytorch-bot Bot commented Mar 19, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4113

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6564114 with merge base 2a8fa55 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 19, 2026
@danielvegamyhre danielvegamyhre self-requested a review March 19, 2026 06:30
@danielvegamyhre danielvegamyhre added module: training quantize_ api training flow ciflow/rocm labels Mar 19, 2026
pytorch-bot Bot commented Mar 19, 2026

To add the ciflow label ciflow/rocm please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@danielvegamyhre (Contributor)

@wenchenvincent thanks for your work on this! I want to hold off on landing more fp8 rowwise ROCm changes until the CI is fixed (see #4061, cc @brucechanglongxu)

@danielvegamyhre danielvegamyhre added this to the FP8 Rowwise Training milestone Mar 20, 2026
@wenchenvincent (Contributor Author)

@danielvegamyhre I just found that I benchmarked the kernel with an old version of pytorch/triton. With newer Triton, there was a perf regression. I filed an issue with Triton: triton-lang/triton#9834

I will wait for the issue to be fixed or look for other means.

@wenchenvincent (Contributor Author)

@danielvegamyhre I fixed the perf regression with Triton 3.6. Could you review?

@wenchenvincent (Contributor Author)

@danielvegamyhre I saw there were CI failures. It seemed that they were not related to the changes from this PR though.

@danielvegamyhre (Contributor)

@wenchenvincent can you rebase? Our CI is green / fixed now.

wenchenvincent and others added 2 commits April 7, 2026 18:57
Key optimizations:
1. Use tl.constexpr for unit strides (STRIDE_INPUT_COL, STRIDE_OUTPUT_ROW)
   to restore vectorized loads on Triton 3.6. Without this, the Coalesce
   pass cannot infer stride=1 from runtime values, causing loads to
   devectorize from 4x buffer_load_dwordx4 to 32x buffer_load_ushort.

2. Add fused single-pass kernel for group sizes 256-2048. Loads all rows
   at once, computing amax and scaling from registers without a second
   HBM read. Reduces memory traffic from 5 to 3 bytes/elem.

3. Single fixed config per platform to avoid per-key autotuning D2H sync
   overhead (upstream convention).

Benchmark on MI300X (Triton 3.6, DSv3 671B EP=8, row-major bf16):
  Original upstream:  ~7290 us per MoE layer
  After optimization: ~1100 us per MoE layer (6.6x speedup)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Reflects actual usage after pytorch#3972. Also add dual kernel benchmark
script from upstream.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
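
The devectorization in item 1 of the first commit message can be checked with simple arithmetic. The sketch below uses only standard sizes (a dword is 4 bytes; a ushort and a bf16 value are 2 bytes each); the 32-element span is inferred from the instruction counts in the message and is not stated in the source.

```python
# Instruction-count arithmetic for the devectorization described in the commit
# message. Sizes are standard: a dword is 4 bytes, a ushort and a bf16 value are
# 2 bytes each. The 32-element span is inferred from the instruction counts.
DWORD_BYTES = 4
USHORT_BYTES = 2
BF16_BYTES = 2

vectorized_instrs, vectorized_bytes_each = 4, 4 * DWORD_BYTES   # buffer_load_dwordx4
scalar_instrs, scalar_bytes_each = 32, USHORT_BYTES             # buffer_load_ushort

vectorized_total = vectorized_instrs * vectorized_bytes_each
scalar_total = scalar_instrs * scalar_bytes_each

print("bytes covered:", vectorized_total, scalar_total)
print("bf16 elements:", vectorized_total // BF16_BYTES)
print("instruction ratio:", scalar_instrs // vectorized_instrs)
```

Both sequences move the same 64 bytes (32 bf16 values); the devectorized form needs 8 times as many load instructions to do so.
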
@wenchenvincent wenchenvincent force-pushed the feat/optimize_colwise_scales branch from 78a93c0 to 6564114 on April 7, 2026 19:31
@wenchenvincent (Contributor Author)

@danielvegamyhre I rebased upon main.

@danielvegamyhre danielvegamyhre merged commit 10af862 into pytorch:main Apr 13, 2026
19 checks passed