[mxfp8 moe training] use dim1 cast cuda kernel for 3d weights by reshaping to 2d #2998

Merged
danielvegamyhre merged 1 commit into main from danielvegamyhre/stack/67
Sep 17, 2025

@danielvegamyhre danielvegamyhre commented Sep 13, 2025

Contributor

Stacked PRs:


[mxfp8 moe training] use dim1 cast cuda kernel for 3d weights by reshaping to 2d

  • Doing to_mx(B_t.contiguous()) is extremely slow (see perf analysis in previous PR in stack)
  • As a workaround, we can use the faster dim1 cast CUDA kernel by reshaping the 3d weights to 2d, casting, then reshaping back to 3d. I wasn't able to find a way to reshape the 2d column-major quantized tensor into a 3d column-major tensor, so I was forced to use the .t().contiguous().t() pattern, which is not ideal for perf but is still faster than doing to_mx(B_t.contiguous()).
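
A minimal sketch of the reshape workaround described above, in plain PyTorch. The `block_amax_along_dim0` helper is a hypothetical stand-in for the dim1 cast CUDA kernel: it only computes one absmax per 32x1 block rather than real MXFP8 data and scales, and the shapes are illustrative. It assumes N is divisible by the block size, so that no block straddles two experts, and it uses `transpose(-2, -1)` as the 3d analogue of `.t()`.

```python
import torch

BLOCK_SIZE = 32  # MX block size along the scaled dimension


def block_amax_along_dim0(x2d: torch.Tensor, block: int = BLOCK_SIZE) -> torch.Tensor:
    """Hypothetical stand-in for a 2d dim1 cast: one absmax per (block x 1) column block.

    A real MXFP8 cast would turn these maxima into scales and fp8 data;
    only the block indexing matters for this sketch.
    """
    rows, cols = x2d.shape
    assert rows % block == 0, "rows must be divisible by the block size"
    return x2d.reshape(rows // block, block, cols).abs().amax(dim=1)


E, N, K = 4, 64, 96  # small illustrative sizes, not the benchmark shapes
B = torch.randn(E, N, K)

# One 2d call over all experts: (E, N, K) -> (E*N, K) is a view of contiguous data.
amax_2d = block_amax_along_dim0(B.reshape(E * N, K))   # (E*N // 32, K)
amax_from_2d = amax_2d.reshape(E, N // BLOCK_SIZE, K)  # back to 3d

# Reference: handle each expert separately.
amax_ref = torch.stack([block_amax_along_dim0(B[e]) for e in range(E)])

# Per-expert column-major layout via the transpose/contiguous/transpose pattern.
# The .contiguous() call is the extra full copy the description refers to.
B_col_major = B.transpose(-2, -1).contiguous().transpose(-2, -1)
```

Both paths give the same per-block maxima, which is why a single 2d call can stand in for E per-expert calls. The last line keeps the logical shape (E, N, K) while changing the strides to column major within each expert.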

Next steps

  • Based on benchmarks and traces, quantizing 3d expert weights scales poorly as number of experts increases, both with the to_mx method and the 2d dim1 cast CUDA kernel method. This is likely due to the .contiguous() call required for both methods.
  • We should update the dim1 cast CUDA kernel to handle 3d inputs, writing directly to col major format, so we can avoid this expensive transformation to column major.
    • We could also update the CUTLASS grouped gemm to handle NT/TN/TT/NN layouts but I think passing in args with different memory layouts could affect kernel perf, need to think about this more.
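
To make the scaling concern above concrete, here is a rough sketch of how much data a transposed `.contiguous()` has to move as the number of experts grows. The helper name is hypothetical, and 2 bytes per element (bf16 weights) is an assumption; the N and K defaults are taken from the benchmark B_shape.

```python
import torch


def transposed_copy_bytes(num_experts: int, n: int = 8192, k: int = 5120,
                          bytes_per_element: int = 2) -> int:
    """Bytes written when an (E, N, K) tensor is materialized in a new layout.

    bytes_per_element=2 assumes bf16 weights (an assumption, not stated in the PR).
    """
    return num_experts * n * k * bytes_per_element


# Small demo: the transpose is a free view, the .contiguous() is a full copy.
B = torch.randn(4, 64, 32)
B_t = B.transpose(-2, -1)    # same storage, only the strides change
B_t_copy = B_t.contiguous()  # new storage, every element is rewritten

# Copy volume in GiB for the expert counts used in the benchmarks.
copy_gib = {e: transposed_copy_bytes(e) / 2**30 for e in (1, 4, 16, 64)}
```

The copy volume grows linearly with E, which is consistent with the observation that both methods scale poorly with the number of experts, though it does not by itself prove that the copy is the bottleneck.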

Test plan

  • pytest test/prototype/moe_training/test_training.py

Benchmarks

Before:

A_shape        B_shape           recipe                  bf16_e2e_us    scaled_e2e_us  scaled_e2e_speedup      bf16_fwd_us    scaled_fwd_us  scaled_fwd_speedup
-------------  ----------------  --------------------  -------------  ---------------  --------------------  -------------  ---------------  --------------------
(16640, 5120)  (1, 8192, 5120)   MoEScalingType.MXFP8        4268.5           3402.75  1.254x                      1513.76          1675.81  0.903x
(16640, 5120)  (4, 8192, 5120)   MoEScalingType.MXFP8        3968.88          4282.53  0.927x                      1126.21          2222.37  0.507x
(16640, 5120)  (16, 8192, 5120)  MoEScalingType.MXFP8        4900.77          8091.55  0.606x                      1262.66          9047.7   0.14x
(16640, 5120)  (64, 8192, 5120)  MoEScalingType.MXFP8        8432.61         21453.3   0.393x                      1788.94         14476.4   0.124x

After:

A_shape        B_shape           recipe                  bf16_e2e_us    scaled_e2e_us  scaled_e2e_speedup      bf16_fwd_us    scaled_fwd_us  scaled_fwd_speedup
-------------  ----------------  --------------------  -------------  ---------------  --------------------  -------------  ---------------  --------------------
(16640, 5120)  (1, 8192, 5120)   MoEScalingType.MXFP8        4920.32          3057.79  1.609x                      1299.92          1091.58  1.191x
(16640, 5120)  (4, 8192, 5120)   MoEScalingType.MXFP8        3886.21          3402.82  1.142x                      1087.39           931.36  1.168x
(16640, 5120)  (16, 8192, 5120)  MoEScalingType.MXFP8        5769.02          5384.91  1.071x                      1411.1           1222.7   1.154x
(16640, 5120)  (64, 8192, 5120)  MoEScalingType.MXFP8        8455.23         12846.2   0.658x                      1796.1           2968.21  0.605x
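
For reading the tables: the speedup columns appear to be the bf16 time divided by the scaled (MXFP8) time, so values above 1 mean MXFP8 is faster. A short check against the "After" rows, with the numbers copied from the table and the variable names chosen only for illustration:

```python
# (num_experts, bf16_e2e_us, scaled_e2e_us, bf16_fwd_us, scaled_fwd_us) from the "After" table
after_rows = [
    (1, 4920.32, 3057.79, 1299.92, 1091.58),
    (4, 3886.21, 3402.82, 1087.39, 931.36),
    (16, 5769.02, 5384.91, 1411.1, 1222.7),
    (64, 8455.23, 12846.2, 1796.1, 2968.21),
]


def speedup(bf16_us: float, scaled_us: float) -> float:
    """Ratio of baseline time to scaled time; > 1 means the scaled path is faster."""
    return bf16_us / scaled_us


e2e_speedups = {e: round(speedup(b, s), 3) for e, b, s, _, _ in after_rows}
fwd_speedups = {e: round(speedup(b, s), 3) for e, _, _, b, s in after_rows}
```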

pytorch-bot Bot commented Sep 13, 2025


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2998

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 10 Pending

As of commit d29104e with merge base afe5cab (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

danielvegamyhre added a commit that referenced this pull request Sep 13, 2025
…aping to 2d

stack-info: PR: #2998, branch: danielvegamyhre/stack/67
@meta-cla meta-cla Bot added the CLA Signed label Sep 13, 2025
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/66 to main September 13, 2025 05:24
danielvegamyhre added a commit that referenced this pull request Sep 13, 2025
…aping to 2d

stack-info: PR: #2998, branch: danielvegamyhre/stack/67
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/66 September 13, 2025 05:24
@danielvegamyhre danielvegamyhre added the mx, moe, and module: not user facing labels Sep 13, 2025
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/66 to main September 13, 2025 17:09
danielvegamyhre added a commit that referenced this pull request Sep 13, 2025
…aping to 2d

stack-info: PR: #2998, branch: danielvegamyhre/stack/67
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/66 September 13, 2025 17:10
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/66 to main September 13, 2025 19:20
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/66 September 13, 2025 19:20
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/66 to main September 13, 2025 19:46
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/66 September 13, 2025 19:46
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/66 to main September 13, 2025 21:06
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/66 September 13, 2025 21:06
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/66 to main September 14, 2025 23:28
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/66 September 14, 2025 23:28
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/66 to main September 14, 2025 23:51
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/66 September 14, 2025 23:51
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/66 to main September 15, 2025 00:16
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/66 September 15, 2025 00:16
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/66 to main September 16, 2025 02:51
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/66 September 16, 2025 02:51
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/66 to main September 16, 2025 02:59
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/66 September 16, 2025 02:59
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/66 to main September 16, 2025 05:06
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/66 September 16, 2025 05:07
Comment thread benchmarks/utils.py
def bench_fwd_microseconds(fn, *args, use_compile=False, fullgraph=True, **kwargs):
    fn_compiled = torch.compile(fn, fullgraph=fullgraph) if use_compile else fn

    def inference_fn(*args, **kwargs):

Contributor

why do we need this? forward in training is not run with torch.inference_mode()

@danielvegamyhre danielvegamyhre Sep 16, 2025

Contributor Author

I know. I wanted to measure only the quantization ops + grouped gemm needed to produce the forward output, without pre-computing and saving things for backward, so that I can more easily see whether the quant ops for fwd or bwd are the slow part without generating a trace to look at. IIRC, if I didn't use inference mode, the forward graph still precomputed stuff.

Contributor

just use no_grad
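
The suggestion above could look roughly like the following. This is a sketch, not the helper in benchmarks/utils.py: the function name, the `warmup` and `iters` parameters, and the wall-clock timing are all assumptions made for illustration.

```python
import time

import torch


def bench_fwd_microseconds_no_grad(fn, *args, warmup=2, iters=5, **kwargs):
    """Hypothetical helper: median wall-clock time (us) of fn's forward pass under no_grad."""

    def sync():
        # Wait for queued GPU work so the timer covers the whole forward pass.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    times_us = []
    with torch.no_grad():  # autograd records nothing for backward in this block
        for _ in range(warmup):
            fn(*args, **kwargs)
        for _ in range(iters):
            sync()
            start = time.perf_counter()
            out = fn(*args, **kwargs)
            sync()
            times_us.append((time.perf_counter() - start) * 1e6)
    return sorted(times_us)[len(times_us) // 2], out


# Usage with a plain matmul standing in for the scaled grouped GEMM.
w = torch.randn(64, 64, requires_grad=True)
x = torch.randn(8, 64)
median_us, out = bench_fwd_microseconds_no_grad(torch.matmul, x, w)
```

Under `torch.no_grad()` the output carries no autograd graph even though `w` requires grad. Whether a custom autograd function still precomputes tensors for backward in its forward depends on how that function is written, so this is worth confirming with a trace.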

@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/66 to main September 16, 2025 16:05
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/66 September 16, 2025 16:05
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/66 to main September 17, 2025 03:15
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/66 September 17, 2025 03:15
drisspg commented Sep 17, 2025

Contributor

Crawling up the stack but do we need this if we are adding a kernel higher up?

danielvegamyhre commented Sep 17, 2025

Contributor Author

> Crawling up the stack but do we need this if we are adding a kernel higher up?

Yes, for now, because it is currently the fastest method when E < 8. The new CUDA kernel for 3d tensors is only faster for E > 8 (see relevant PR). I want to change this, though, and just use one kernel everywhere; I would appreciate your thoughts on the kernel design. I checked NCU and it flags some minor (~15%) potential speedups from resolving bank conflicts and uncoalesced global accesses, but I am not confident that is really the issue, since those are present for the 2d kernel as well.

@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/66 to main September 17, 2025 15:28
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/66 September 17, 2025 15:28
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/66 to main September 17, 2025 15:47
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/66 September 17, 2025 15:48
danielvegamyhre added a commit that referenced this pull request Sep 17, 2025
…aping to 2d

stack-info: PR: #2998, branch: danielvegamyhre/stack/67
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/66 to main September 17, 2025 16:14
…aping to 2d

stack-info: PR: #2998, branch: danielvegamyhre/stack/67
@danielvegamyhre danielvegamyhre merged commit ff3ba31 into main Sep 17, 2025
9 of 18 checks passed

Labels

CLA Signed, module: not user facing, moe, mx


3 participants