Enable blockwise FP8 training kernels on AMD GPUs (MI300/MI350) #3996

Merged
danielvegamyhre merged 3 commits into pytorch:main from brucechanglongxu:rocm-blockwise-fp8-enablement
Mar 10, 2026

Conversation

@brucechanglongxu

@brucechanglongxu brucechanglongxu commented Mar 4, 2026

Contributor

Replace hardcoded FP8 e4m3fn max (448.0) with a parameterized FP8_MAX derived from torch.finfo(dtype).max in all 5 blockwise quantization Triton kernels, their Python wrapper functions, and the 3 PyTorch reference implementations. This allows the kernels to operate with both float8_e4m3fn (NVIDIA, max=448) and float8_e4m3fnuz (AMD MI300, max=240).

kernels.py:

  • Add FP8_MAX as a tl.constexpr parameter to triton_fp8_blockwise_act_quant_lhs_kernel, triton_fp8_blockwise_act_quant_rhs_kernel, triton_fp8_blockwise_act_quant_transposed_lhs_kernel, triton_fp8_blockwise_weight_quant_rhs_kernel, and triton_fp8_blockwise_weight_quant_transposed_rhs_kernel. Each kernel previously had max_fp8_e4m3 = 448.0 / min_fp8_e4m3 = -448.0 inline; these are replaced with the passed-in FP8_MAX and -FP8_MAX.
  • In the 5 wrapper functions, compute fp8_max = torch.finfo(dtype).max and forward it to the kernel call. Widen the dtype assertion from [torch.float8_e4m3fn] to {torch.float8_e4m3fn, torch.float8_e4m3fnuz}. Default dtype parameter changed from torch.float8_e4m3fn to the platform-aware e4m3_dtype (from torchao.float8.config).
  • In the 3 reference implementations (torch_blockwise_scale_act_quant_lhs, torch_blockwise_scale_act_quant_rhs, torch_blockwise_scale_weight_quant), replace hardcoded torch.finfo(torch.float8_e4m3fn) with torch.finfo(dtype) and cast outputs to the passed dtype instead of torch.float8_e4m3fn.
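The effect of parameterizing the clamp range can be illustrated with a small plain-Python sketch. It is not the Triton kernel or the PyTorch reference code: `quantize_block`, `FP8_MAX_BY_DTYPE`, and the `eps` default are hypothetical names chosen for illustration, the scale formula (`fp8_max / amax`) is an assumption about how the blockwise kernels scale each block, and no actual FP8 rounding is performed. The two limits are the values quoted above (448 for float8_e4m3fn, 240 for float8_e4m3fnuz).

```python
# Illustrative only: a plain-Python stand-in for quantizing one block.
# The limits below are the values quoted in the PR description; in the real
# code they would come from torch.finfo(dtype).max.
FP8_MAX_BY_DTYPE = {
    "float8_e4m3fn": 448.0,    # NVIDIA
    "float8_e4m3fnuz": 240.0,  # AMD MI300
}


def quantize_block(block, fp8_max, eps=1e-12):
    """Scale one block into [-fp8_max, fp8_max].

    Returns (scaled_values, reciprocal_scale). Assumed formula, not taken
    from the kernel source: scale = fp8_max / max(|x|).
    """
    amax = max(max(abs(v) for v in block), eps)  # guard against all-zero blocks
    scale = fp8_max / amax
    scaled = [min(max(v * scale, -fp8_max), fp8_max) for v in block]
    return scaled, 1.0 / scale


block = [1.0, -2.0, 4.0]
for name, fp8_max in FP8_MAX_BY_DTYPE.items():
    scaled, reciprocal_scale = quantize_block(block, fp8_max)
    print(name, scaled, reciprocal_scale)
```

With a hardcoded 448.0, the same block would be scaled to a largest value of 448, which exceeds the 240 limit of float8_e4m3fnuz; passing the limit in as a parameter keeps the scaled values inside the target dtype's range.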

test_blockwise_kernels.py:

  • Replace is_sm_at_least_90() capability gate with is_sm_at_least_90() or is_MI300() or is_MI350() across all 7 tests.
  • Replace hardcoded torch.float8_e4m3fn parametrize values with e4m3_dtype.
  • Remove @skip_if_rocm decorators from the 5 quantization kernel tests.

Benchmark Results (AMD Instinct MI300X)

Environment: PyTorch 2.9.1+rocm7.2.0, Triton 3.5.1+rocm7.2.0, single MI300X GPU

Correctness

All 9 test configurations produce bit-identical results between old (per-expert loop) and new (grouped GEMM kernel) paths (max_diff = 0.0).

GEMM Kernel Only: per-expert loop (old) vs grouped kernel (new)

| E | M | K | N | Old (ms) | New (ms) | Speedup | Old TFLOPS | New TFLOPS |
|---|---|---|---|---|---|---|---|---|
| 8 | 2048 | 1024 | 1024 | 2.503 | 0.227 | 11.03x | 13.7 | 151.4 |
| 8 | 4096 | 2048 | 2048 | 2.798 | 0.817 | 3.42x | 98.2 | 336.3 |
| 8 | 4096 | 4096 | 4096 | 7.026 | 4.139 | 1.70x | 156.5 | 265.7 |
| 8 | 8192 | 4096 | 4096 | 11.157 | 9.026 | 1.24x | 197.1 | 243.6 |
| 16 | 4096 | 2048 | 2048 | 5.149 | 0.794 | 6.49x | 106.8 | 692.6 |
| 16 | 8192 | 4096 | 4096 | 13.708 | 7.461 | 1.84x | 320.8 | 589.4 |
| 8 | 16384 | 4096 | 4096 | 21.724 | 18.096 | 1.20x | 202.5 | 243.0 |
| 8 | 4096 | 5120 | 5120 | 12.629 | 5.693 | 2.22x | 136.0 | 301.8 |
| 8 | 16640 | 5120 | 8192 | 55.225 | 40.689 | 1.36x | 202.2 | 274.4 |
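The TFLOPS and speedup columns can be reproduced from the timings. The sketch below assumes a FLOP count of 2·E·M·K·N per grouped GEMM; that convention is inferred from the reported numbers rather than stated in the PR, and `grouped_gemm_tflops` is a hypothetical helper name.

```python
def grouped_gemm_tflops(e, m, k, n, millis):
    """Throughput in TFLOPS, assuming 2*E*M*K*N FLOPs (inferred convention)."""
    flops = 2 * e * m * k * n
    seconds = millis * 1e-3
    return flops / seconds / 1e12


# First row of the table: E=8, M=2048, K=1024, N=1024
old_ms, new_ms = 2.503, 0.227
print(round(grouped_gemm_tflops(8, 2048, 1024, 1024, old_ms), 1))  # old path
print(round(grouped_gemm_tflops(8, 2048, 1024, 1024, new_ms), 1))  # new path
print(round(old_ms / new_ms, 2))  # speedup = old time / new time
```

Under that assumption the first row works out to 13.7 and 151.4 TFLOPS with an 11.03x speedup, matching the reported values.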

Full Forward: old Triton vs new Triton vs BF16 baseline

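| E | M | K | N | Old (ms) | New (ms) | BF16 (ms) | Speedup vs Old | Speedup vs BF16 |
|---|---|---|---|---|---|---|---|---|
| 8 | 2048 | 1024 | 1024 | 2.150 | 0.353 | 0.430 | 6.09x | 1.22x |
| 8 | 4096 | 2048 | 2048 | 2.968 | 1.118 | 0.480 | 2.66x | 0.43x |
| 8 | 4096 | 4096 | 4096 | 5.093 | 4.070 | 0.828 | 1.25x | 0.20x |
| 8 | 8192 | 4096 | 4096 | 8.105 | 7.700 | 2.167 | 1.05x | 0.28x |
| 16 | 4096 | 2048 | 2048 | 10.398 | 1.714 | 1.512 | 6.07x | 0.88x |
| 16 | 8192 | 4096 | 4096 | 19.715 | 8.302 | 1.832 | 2.37x | 0.22x |
| 8 | 16384 | 4096 | 4096 | 19.117 | 15.760 | 5.514 | 1.21x | 0.35x |
| 8 | 4096 | 5120 | 5120 | 15.924 | 8.565 | 2.432 | 1.86x | 0.28x |
| 8 | 16640 | 5120 | 8192 | 47.771 | 32.276 | 4.383 | 1.48x | 0.14x |

Speedup columns are time ratios (Old / New and BF16 / New); values below 1x mean the new Triton path is slower than the baseline.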

Forward+Backward (new Triton path, end-to-end)

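| E | M | K | N | Fwd+Bwd (ms) | TFLOPS |
|---|---|---|---|---|---|
| 8 | 2048 | 1024 | 1024 | 4.149 | 24.8 |
| 8 | 4096 | 2048 | 2048 | 5.932 | 139.0 |
| 8 | 4096 | 4096 | 4096 | 15.226 | 216.6 |
| 8 | 8192 | 4096 | 4096 | 29.026 | 227.3 |
| 16 | 4096 | 2048 | 2048 | 15.334 | 107.6 |
| 16 | 8192 | 4096 | 4096 | 33.078 | 398.9 |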

Key takeaways:

  • The new grouped GEMM kernel provides 1.2x-11x speedup over the old per-expert loop at the kernel level, with the largest gains on workloads with many experts and smaller per-expert M.
  • End-to-end forward speedup is 1.05x-6.1x over the old path (quantization overhead is now the dominant cost at larger sizes).
  • The Triton FP8 GEMM is currently slower than BF16 rocBLAS because it does not yet use hardware FP8 matrix cores; the benefit comes from reduced memory traffic, which will be more impactful at scale.

cc: @BowenBao

Replace hardcoded FP8 e4m3fn max (448.0) with a parameterized FP8_MAX
derived from torch.finfo(dtype).max in all 5 Triton JIT kernels, their
wrapper functions, and the 3 PyTorch reference implementations. This
allows the kernels to work with both float8_e4m3fn (NVIDIA, max=448)
and float8_e4m3fnuz (AMD MI300, max=240).

Update test capability gates from is_sm_at_least_90() to also accept
MI300/MI350, and replace hardcoded torch.float8_e4m3fn test parameters
with the platform-aware e4m3_dtype from torchao.float8.config.
@pytorch-bot

pytorch-bot Bot commented Mar 4, 2026


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3996

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 66086bd with merge base f04500f (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Mar 4, 2026
@danielvegamyhre danielvegamyhre self-requested a review March 4, 2026 23:19
@danielvegamyhre danielvegamyhre added the module: training label Mar 4, 2026
@danielvegamyhre

Contributor

looks good @brucechanglongxu please fix the linter issue though, thanks!

Remove stray blank line between third-party and first-party imports
to satisfy ruff's import block formatting rules.
EPS: tl.constexpr,
FP8_MAX: tl.constexpr,
):
"""

Contributor

can you add this docstring back, looks like it may have been deleted by accident?

@danielvegamyhre danielvegamyhre left a comment

Contributor

LGTM, just one minor comment to address

@brucechanglongxu

Contributor Author

@danielvegamyhre This PR is approved and all CI checks are passing. Could you merge when you get a chance? Thanks!

@danielvegamyhre danielvegamyhre merged commit 629e25d into pytorch:main Mar 10, 2026
19 checks passed

Labels

CLA Signed, module: training


2 participants