[mxfp8 moe training] add cuda kernel for per group padding#3998
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3998
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New FailureAs of commit 17fd81e with merge base d6d423e ( NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
stack-info: PR: #3998, branch: danielvegamyhre/stack/147
208bc6f to
42c5c30
Compare
| def _groups_aligned( | ||
| group_offsets: torch.Tensor, alignment_size: int = 32 | ||
| ) -> torch.Tensor: | ||
| group_sizes = torch.diff( |
There was a problem hiding this comment.
thats a fun one claude, instead of cumusm no weird sync things here right?
There was a problem hiding this comment.
deleted this and made more compile friendly by just using a pad_token_groups: bool arg in the autograd func, for the caller to specify if they need their token groups aligned or not.
was having a bad time with data dependent control flow issues, etc
| current_offset = 0 | ||
| group_start = 0 | ||
|
|
||
| for group_end in group_offsets.tolist(): |
There was a problem hiding this comment.
use @howardzhang-cv skill and see if we can also try helion here
There was a problem hiding this comment.
could be a great starter task for new team member who starts Monday, they're going to work on some training stuff and apparently they are interested in helion
stack-info: PR: #3998, branch: danielvegamyhre/stack/147
42c5c30 to
9f437f0
Compare
stack-info: PR: #3998, branch: danielvegamyhre/stack/147
9f437f0 to
2e861c9
Compare
stack-info: PR: #3998, branch: danielvegamyhre/stack/147
2e861c9 to
73cc113
Compare
stack-info: PR: #3998, branch: danielvegamyhre/stack/147
73cc113 to
dd304f4
Compare
stack-info: PR: #3998, branch: danielvegamyhre/stack/147
dd304f4 to
ab4acb0
Compare
Fix three categories of ROCm CI failures: 1. float8_tensor.py: Fix IndexError in view_as/reshape handler where range(3) was hardcoded, causing crashes on 2D tensors during DTensor.from_local(). Changed to range(len(size)). 2. blockwise FP8 kernel tests: The kernel is correct, but e4m3fnuz (ROCm) has lower dynamic range (±240) vs e4m3fn (CUDA, ±448), causing worse quantization SQNR for small-M shapes. Relaxed the SQNR threshold on ROCm (verified kernel matches reference impl). 3. MoE training: Temporarily skip expert training tests on ROCm due to per-group padding shape mismatch introduced in pytorch#3998.
* [ROCm] Fix ROCm CI failures: float8_tensor bug, SQNR threshold, MoE skip Fix three categories of ROCm CI failures: 1. float8_tensor.py: Fix IndexError in view_as/reshape handler where range(3) was hardcoded, causing crashes on 2D tensors during DTensor.from_local(). Changed to range(len(size)). 2. blockwise FP8 kernel tests: The kernel is correct, but e4m3fnuz (ROCm) has lower dynamic range (±240) vs e4m3fn (CUDA, ±448), causing worse quantization SQNR for small-M shapes. Relaxed the SQNR threshold on ROCm (verified kernel matches reference impl). 3. MoE training: Temporarily skip expert training tests on ROCm due to per-group padding shape mismatch introduced in #3998. * Skip blockwise FP8 GEMM tests on ROCm due to numerical issues Per reviewer feedback, skip the two GEMM tests on ROCm rather than using a heavily relaxed SQNR threshold (0.5 vs 28.0). The blockwise quantization kernel tests remain enabled on ROCm. * Fix is_ROCM() to return bool instead of string is_ROCM() returned `torch.version.hip` (a version string like "7.0.51831") instead of True/False. Python's `and` returns the last truthy operand, so `True and "7.0.51831"` evaluates to the string itself. This caused pytest's @pytest.mark.skipif to interpret the string as a Python expression to compile/eval, resulting in SyntaxError (Python parses "7.0" as a float literal, then ".51831" as an invalid attribute access).
Stacked PRs:
Motivation
Summary
inputs[padded_indexes, :]where the -1 index will select a padding row every time it appears.Tests
pytest test/prototype/moe_training/test_kernels.py -k padBenchmarks
CUDA is the best across all shapes:
Limitations
MXFP8 grouped mm autograd func fwd + bwd new benchmarks
We lose about 25% of the speedup for Llama4 and 50% of the speedup for DSV3 671b. Need to do the all2all dispatch padding approach to avoid this.