Enable blockwise FP8 training kernels on AMD GPUs (MI300/MI350)#3996
Merged
danielvegamyhre merged 3 commits intoMar 10, 2026
Merged
Conversation
Replace hardcoded FP8 e4m3fn max (448.0) with a parameterized FP8_MAX derived from torch.finfo(dtype).max in all 5 Triton JIT kernels, their wrapper functions, and the 3 PyTorch reference implementations. This allows the kernels to work with both float8_e4m3fn (NVIDIA, max=448) and float8_e4m3fnuz (AMD MI300, max=240). Update test capability gates from is_sm_at_least_90() to also accept MI300/MI350, and replace hardcoded torch.float8_e4m3fn test parameters with the platform-aware e4m3_dtype from torchao.float8.config.
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3996
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 66086bd with merge base f04500f ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Contributor
|
looks good @brucechanglongxu please fix the linter issue though, thanks! |
Remove stray blank line between third-party and first-party imports to satisfy ruff's import block formatting rules.
| EPS: tl.constexpr, | ||
| FP8_MAX: tl.constexpr, | ||
| ): | ||
| """ |
Contributor
There was a problem hiding this comment.
can you add this docstring back, looks like it may have been deleted by accident?
danielvegamyhre
approved these changes
Mar 7, 2026
danielvegamyhre
left a comment
Contributor
There was a problem hiding this comment.
LGTM, just one minor comment to address
Contributor
Author
|
@danielvegamyhre This PR is approved and all CI checks are passing. Could you merge when you get a chance? Thanks! |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Replace hardcoded FP8 e4m3fn max (448.0) with a parameterized FP8_MAX derived from torch.finfo(dtype).max in all 5 blockwise quantization Triton kernels, their Python wrapper functions, and the 3 PyTorch reference implementations. This allows the kernels to operate with both float8_e4m3fn (NVIDIA, max=448) and float8_e4m3fnuz (AMD MI300, max=240).
kernels.py:
max_fp8_e4m3 = 448.0/min_fp8_e4m3 = -448.0inline; these are replaced with the passed-in FP8_MAX and -FP8_MAX.[torch.float8_e4m3fn]to{torch.float8_e4m3fn, torch.float8_e4m3fnuz}. Default dtype parameter changed from torch.float8_e4m3fn to the platform-aware e4m3_dtype (from torchao.float8.config).test_blockwise_kernels.py:
Benchmark Results (AMD Instinct MI300X)
Environment: PyTorch 2.9.1+rocm7.2.0, Triton 3.5.1+rocm7.2.0, single MI300X GPU
Correctness
All 9 test configurations produce bit-identical results between old (per-expert loop) and new (grouped GEMM kernel) paths (max_diff = 0.0).
GEMM Kernel Only: per-expert loop (old) vs grouped kernel (new)
Full Forward: old Triton vs new Triton vs BF16 baseline
Forward+Backward (new Triton path, end-to-end)
Key takeaways:
cc: @BowenBao