Skip to content

[mxfp8 moe training] add cuda kernel for per group padding#3998

Merged
danielvegamyhre merged 1 commit into
mainfrom
danielvegamyhre/stack/147
Mar 10, 2026
Merged

[mxfp8 moe training] add cuda kernel for per group padding#3998
danielvegamyhre merged 1 commit into
mainfrom
danielvegamyhre/stack/147

Conversation

@danielvegamyhre

@danielvegamyhre danielvegamyhre commented Mar 5, 2026

Copy link
Copy Markdown
Contributor

Stacked PRs:


Motivation

  • Torchtitan is removing the token group padding logic since it is now only needed for fp8/mxfp8, not bf16 grouped mm
  • We have also received user feedback that torchao should handle this token group padding / data movement as it is tricky to do efficiently and many users won't be able to get a speedup using TorchAO MXFP8 without an optimized implementation of it.

Summary

  • For emulated mode, use a torch native token group padding impl with d2h sync
  • For non-emulated mode, I prototyped the following:
    • Add 3 prototype kernels to do this padding with different approaches, with tests and benchmarks
      • Approach 1: Triton two stage approach;
        • Add padding row to inputs with torch.vstack
        • Triton kernel generates indexes with padding index of -1, avoiding d2h sync
        • Torch code uses those indexes like inputs[padded_indexes, :] where the -1 index will select a padding row every time it appears.
      • Approach 2: Triton fused approach (do copy tokens and write padding data in one kernel )
      • Approach 3: CUDA kernel with 2 stage approach (precompute padded offsets in tiny kernel instead of several torch ops, then dispatch padding kernel)

Tests

  • pytest test/prototype/moe_training/test_kernels.py -k pad

Benchmarks

CUDA is the best across all shapes:

  num_tokens    dim    num_groups    torch_us    triton_us    fused_triton_us    cuda_us    torch_mem_bw_gbps    triton_mem_bw_gbps    triton_fused_mem_bw_gbps    cuda_mem_bw_gbps  triton_vs_torch    fused_triton_vs_torch    cuda_vs_torch
------------  -----  ------------  ----------  -----------  -----------------  ---------  -------------------  --------------------  --------------------------  ------------------  -----------------  -----------------------  ---------------
       16384   1536             1     436.458      439.953            417.303    320.444               346.18                343.43                      362.07              471.51  0.99x              1.05x                    1.36x
       16384   1536             4     517.7        452.393            383.254    292.423               292.42                334.64                      395.01              517.7   1.14x              1.35x                    1.77x
       16384   1536             8     536.612      466.062            368.165    286.548               282.85                325.67                      412.27              529.69  1.15x              1.46x                    1.87x
       16384   1536            16     598          465.822            367.509    286.343               255.13                327.52                      415.14              532.81  1.28x              1.63x                    2.09x
       16384   2048             1     441.902      449.814            382.631    292.373               455.89                447.87                      526.51              689.04  0.98x              1.15x                    1.51x
       16384   2048             4     484.005      461.893            401.875    309.171               417.04                437.01                      502.27              652.88  1.05x              1.20x                    1.57x
       16384   2048             8     541.357      473.953            392.233    298.03                373.83                426.99                      515.96              679.04  1.14x              1.38x                    1.82x
       16384   2048            16     585.292      460.7              390.033    293.395               347.56                441.55                      521.56              693.34  1.27x              1.50x                    1.99x
       16384   5120             1     460.036      464.683            567.816    299.987              1094.79               1083.85                      886.98             1678.89  0.99x              0.81x                    1.53x
       16384   5120             4     497.396      449.034            575.571    286.81               1014.54               1123.81                      876.74             1759.45  1.11x              0.86x                    1.73x
       16384   5120             8     526.011      455.096            612.969    300.776               961.84               1111.72                      825.39             1682.11  1.16x              0.86x                    1.75x
       16384   5120            16     607.179      484.409            630.304    296.039               837.58               1049.86                      806.85             1717.88  1.25x              0.96x                    2.05x
       16384   7168             1     465.052      459.34             613.543    295.84               1516.18               1535.03                     1149.23             2383.39  1.01x              0.76x                    1.57x
       16384   7168             4     492.211      456.083            623.67     293.7                1435.31               1549.01                     1132.77             2405.44  1.08x              0.79x                    1.68x
       16384   7168             8     533.49       465.685            665.016    293.333              1327.7                1521.01                     1065.11             2414.71  1.15x              0.80x                    1.82x
       16384   7168            16     609.1        469.449            677.495    298.956              1168.91               1516.64                     1050.91             2381.56  1.30x              0.90x                    2.04x

Limitations

  • We should update the docs to recommend Torchtitan users use HybridEP which fuses the padding into the dispatch step to avoid this extra copy (which is expensive since input activations are huge).
  • For non-torchtitan users, we can modify our Triton+Symmetric memory all2all impl to fuse padding this way as well, but it is a larger project.

MXFP8 grouped mm autograd func fwd + bwd new benchmarks

We lose about 25% of the speedup for Llama4 and 50% of the speedup for DSV3 671b. Need to do the all2all dispatch padding approach to avoid this.

M,N,K,G                  recipe                             bf16_fwd_bwd_us    scaled_fwd_bwd_us  scaled_fwd_bwd_speedup      bf16_fwd_us    scaled_fwd_us  scaled_fwd_speedup
-----------------------  -------------------------------  -----------------  -------------------  ------------------------  -------------  ---------------  --------------------
(32768, 8192, 5120, 1)   MXFP8TrainingRecipe.MXFP8_RCEIL            7274.69              4548.62  1.599x                         2033.6           1249.18   1.628x
(32768, 8192, 5120, 2)   MXFP8TrainingRecipe.MXFP8_RCEIL            7264.42              4645.89  1.564x                         2171.97          1318.75   1.647x
(128000, 8192, 5120, 1)  MXFP8TrainingRecipe.MXFP8_RCEIL           28025.1              17680.4   1.585x                         9683.04          4871.17   1.988x
(128000, 8192, 5120, 2)  MXFP8TrainingRecipe.MXFP8_RCEIL           27889.7              17515.6   1.592x                         9126.48          4927.33   1.852x
(32768, 2048, 7168, 4)   MXFP8TrainingRecipe.MXFP8_RCEIL            2493.47              2234.4   1.116x                          733.216          758.752  0.966x
(32768, 2048, 7168, 8)   MXFP8TrainingRecipe.MXFP8_RCEIL            2519.04              2359.2   1.068x                          708.544          807.872  0.877x
(128000, 2048, 7168, 4)  MXFP8TrainingRecipe.MXFP8_RCEIL           10091.2               8602.66  1.173x                         2898.9           2812.06   1.031x
(128000, 2048, 7168, 8)  MXFP8TrainingRecipe.MXFP8_RCEIL            9958.98              8097.26  1.23x                          3020.8           2861.15   1.056x

@pytorch-bot

pytorch-bot Bot commented Mar 5, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3998

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 17fd81e with merge base d6d423e (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

danielvegamyhre added a commit that referenced this pull request Mar 5, 2026
stack-info: PR: #3998, branch: danielvegamyhre/stack/147
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/147 branch from 208bc6f to 42c5c30 Compare March 5, 2026 03:18
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 5, 2026
@danielvegamyhre danielvegamyhre requested a review from drisspg March 5, 2026 03:36
def _groups_aligned(
group_offsets: torch.Tensor, alignment_size: int = 32
) -> torch.Tensor:
group_sizes = torch.diff(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thats a fun one claude, instead of cumusm no weird sync things here right?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deleted this and made more compile friendly by just using a pad_token_groups: bool arg in the autograd func, for the caller to specify if they need their token groups aligned or not.

was having a bad time with data dependent control flow issues, etc

Comment thread torchao/prototype/moe_training/mxfp8_grouped_mm.py Outdated
current_offset = 0
group_start = 0

for group_end in group_offsets.tolist():

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use @howardzhang-cv skill and see if we can also try helion here

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could be a great starter task for new team member who starts Monday, they're going to work on some training stuff and apparently they are interested in helion

@danielvegamyhre danielvegamyhre marked this pull request as draft March 5, 2026 04:44
danielvegamyhre added a commit that referenced this pull request Mar 5, 2026
stack-info: PR: #3998, branch: danielvegamyhre/stack/147
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/147 branch from 42c5c30 to 9f437f0 Compare March 5, 2026 04:44
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 5, 2026 04:44
@danielvegamyhre danielvegamyhre marked this pull request as draft March 5, 2026 04:46
danielvegamyhre added a commit that referenced this pull request Mar 5, 2026
stack-info: PR: #3998, branch: danielvegamyhre/stack/147
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/147 branch from 9f437f0 to 2e861c9 Compare March 5, 2026 04:46
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 5, 2026 04:46
@danielvegamyhre danielvegamyhre marked this pull request as draft March 6, 2026 05:32
danielvegamyhre added a commit that referenced this pull request Mar 6, 2026
stack-info: PR: #3998, branch: danielvegamyhre/stack/147
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/147 branch from 2e861c9 to 73cc113 Compare March 6, 2026 05:32
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 6, 2026 05:32
@danielvegamyhre danielvegamyhre marked this pull request as draft March 6, 2026 07:50
danielvegamyhre added a commit that referenced this pull request Mar 6, 2026
stack-info: PR: #3998, branch: danielvegamyhre/stack/147
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/147 branch from 73cc113 to dd304f4 Compare March 6, 2026 07:50
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 6, 2026 07:50
@danielvegamyhre danielvegamyhre added module: training quantize_ api training flow mx labels Mar 6, 2026
@danielvegamyhre danielvegamyhre marked this pull request as draft March 6, 2026 18:20
danielvegamyhre added a commit that referenced this pull request Mar 6, 2026
stack-info: PR: #3998, branch: danielvegamyhre/stack/147
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/147 branch from dd304f4 to ab4acb0 Compare March 6, 2026 18:20
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 6, 2026 18:21
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 6, 2026 22:43
@danielvegamyhre danielvegamyhre marked this pull request as draft March 7, 2026 00:48
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 7, 2026 00:49
@danielvegamyhre danielvegamyhre marked this pull request as draft March 7, 2026 00:58
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 7, 2026 00:58
@danielvegamyhre danielvegamyhre marked this pull request as draft March 7, 2026 01:00
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 7, 2026 01:01
@danielvegamyhre danielvegamyhre marked this pull request as draft March 7, 2026 01:18
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 7, 2026 01:18
@danielvegamyhre danielvegamyhre marked this pull request as draft March 7, 2026 02:23
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 7, 2026 02:23
@danielvegamyhre danielvegamyhre marked this pull request as draft March 7, 2026 03:02
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 7, 2026 03:02
@danielvegamyhre danielvegamyhre marked this pull request as draft March 7, 2026 04:35
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 7, 2026 04:35
@danielvegamyhre danielvegamyhre marked this pull request as draft March 7, 2026 04:40
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 7, 2026 04:40
@danielvegamyhre danielvegamyhre marked this pull request as draft March 7, 2026 05:46
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 7, 2026 05:46
@danielvegamyhre danielvegamyhre marked this pull request as draft March 7, 2026 06:00
@danielvegamyhre danielvegamyhre marked this pull request as ready for review March 7, 2026 06:00
@danielvegamyhre danielvegamyhre merged commit f0d0deb into main Mar 10, 2026
44 of 46 checks passed
@danielvegamyhre danielvegamyhre changed the title [mxfp8 moe training] add triton kernel for per group padding [mxfp8 moe training] add cuda kernel for per group padding Mar 10, 2026
@danielvegamyhre danielvegamyhre added this to the MXFP8 Training milestone Mar 11, 2026
brucechanglongxu added a commit to brucechanglongxu/ao that referenced this pull request Mar 11, 2026
Fix three categories of ROCm CI failures:

1. float8_tensor.py: Fix IndexError in view_as/reshape handler where
   range(3) was hardcoded, causing crashes on 2D tensors during
   DTensor.from_local(). Changed to range(len(size)).

2. blockwise FP8 kernel tests: The kernel is correct, but e4m3fnuz
   (ROCm) has lower dynamic range (±240) vs e4m3fn (CUDA, ±448),
   causing worse quantization SQNR for small-M shapes. Relaxed the
   SQNR threshold on ROCm (verified kernel matches reference impl).

3. MoE training: Temporarily skip expert training tests on ROCm due
   to per-group padding shape mismatch introduced in pytorch#3998.
danielvegamyhre pushed a commit that referenced this pull request Mar 20, 2026
* [ROCm] Fix ROCm CI failures: float8_tensor bug, SQNR threshold, MoE skip

Fix three categories of ROCm CI failures:

1. float8_tensor.py: Fix IndexError in view_as/reshape handler where
   range(3) was hardcoded, causing crashes on 2D tensors during
   DTensor.from_local(). Changed to range(len(size)).

2. blockwise FP8 kernel tests: The kernel is correct, but e4m3fnuz
   (ROCm) has lower dynamic range (±240) vs e4m3fn (CUDA, ±448),
   causing worse quantization SQNR for small-M shapes. Relaxed the
   SQNR threshold on ROCm (verified kernel matches reference impl).

3. MoE training: Temporarily skip expert training tests on ROCm due
   to per-group padding shape mismatch introduced in #3998.

* Skip blockwise FP8 GEMM tests on ROCm due to numerical issues

Per reviewer feedback, skip the two GEMM tests on ROCm rather than
using a heavily relaxed SQNR threshold (0.5 vs 28.0). The blockwise
quantization kernel tests remain enabled on ROCm.

* Fix is_ROCM() to return bool instead of string

is_ROCM() returned `torch.version.hip` (a version string like
"7.0.51831") instead of True/False. Python's `and` returns the last
truthy operand, so `True and "7.0.51831"` evaluates to the string
itself. This caused pytest's @pytest.mark.skipif to interpret the
string as a Python expression to compile/eval, resulting in
SyntaxError (Python parses "7.0" as a float literal, then ".51831"
as an invalid attribute access).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: training quantize_ api training flow mx

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants