Skip to content

[mxfp8 moe training] add CUDA kernel to quantize 3d tensor colwise#3002

Merged
danielvegamyhre merged 1 commit into
mainfrom
danielvegamyhre/stack/69
Sep 19, 2025
Merged

[mxfp8 moe training] add CUDA kernel to quantize 3d tensor colwise#3002
danielvegamyhre merged 1 commit into
mainfrom
danielvegamyhre/stack/69

Conversation

@danielvegamyhre

@danielvegamyhre danielvegamyhre commented Sep 14, 2025

Copy link
Copy Markdown
Contributor

Stacked PRs:


[mxfp8 moe training] add CUDA kernel to quantize 3d tensor colwise

Summary

  • This PR adds a new CUDA kernel specifically for quantizing 3d expert weights shape (E,N,K) along the N dimension and writing directly to column major format.
    • Design: I create separate input/output TMA descriptors for each expert, and process each 2d expert in parallel using the same method that the 2d dim1 quantization kernel uses. The 2d kernel achieves 85% peak memory bandwidth utilization, so hopefully we can achieve similar perf for 3d.
  • The existing methods for quantizing 3d expert weights both scale very poorly. I have verified this via benchmarking and traces (see previous PR), and hypothesize that it is due to required .contiguous() calls:
    • Using to_mx only scales along the last dim and requires contiguos inputs. So this requires transposing contiguous tensor (E,N,K) -> (E,K,N) then calling .contiguous() to scale along the N dim (needed for backwards)
    • Using the existing CUDA kernel for casting along dim1 is possible, by treating the 3d input tensor as a 2d tensor of shape (E*N, K). However, this produces a 2d output tensor in column major format, and there is no way to reshape and restride the tensor to be 3d again AND preserve the column major format, such that numerics are preserved. Thus, we have to transform the output to column major afterwards, requiring a .contiguous() call.

Test plan

  • Added tests that verify numerical accuracy

Kernel microbenchmarks

Perf is decent for large E and abysmal for small E. Need to investigate this.

Update (9/15): NCU shows 3d kernel operating on (2,8192,5120) tensor is actually compute bound (??)

input_shape         to_mx_us    cuda_2d_us    cuda_3d_us    to_mx_gbps    cuda_2d_gbps    cuda_3d_gbps
----------------  ----------  ------------  ------------  ------------  --------------  --------------
(1, 8192, 5120)      117.92         69.776       242.112      1078.19          1822.11         525.128
(2, 8192, 5120)      431.264       105.536       249.728       589.615         2409.41        1018.23
(4, 8192, 5120)      848.992       195.584       297.376       599.015         2600.21        1710.16
(8, 8192, 5120)     1682.59        379.904       412.992       604.495         2677.3         2462.8
(16, 8192, 5120)    3350.53        775.984       615.536       607.139         2621.49        3304.82
(64, 8192, 5120)   13352          3150.66       1959.71        609.418         2582.62        4152.11

@pytorch-bot

pytorch-bot Bot commented Sep 14, 2025

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3002

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b64758e with merge base f75b251 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

danielvegamyhre added a commit that referenced this pull request Sep 14, 2025
stack-info: PR: #3002, branch: danielvegamyhre/stack/69
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 14, 2025
@danielvegamyhre danielvegamyhre added mx moe module: not user facing Use this tag if you don't want this PR to show up in release notes labels Sep 14, 2025
@danielvegamyhre danielvegamyhre marked this pull request as draft September 14, 2025 23:49
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/68 to main September 14, 2025 23:51
danielvegamyhre added a commit that referenced this pull request Sep 14, 2025
stack-info: PR: #3002, branch: danielvegamyhre/stack/69
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/68 September 14, 2025 23:51
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/68 to main September 15, 2025 00:16
danielvegamyhre added a commit that referenced this pull request Sep 15, 2025
stack-info: PR: #3002, branch: danielvegamyhre/stack/69
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/68 September 15, 2025 00:16
@danielvegamyhre danielvegamyhre marked this pull request as ready for review September 15, 2025 00:17
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/68 to main September 15, 2025 02:18
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/68 September 15, 2025 02:18
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/68 to main September 15, 2025 02:19
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/68 September 15, 2025 02:19
@danielvegamyhre

danielvegamyhre commented Sep 15, 2025

Copy link
Copy Markdown
Contributor Author

@slayton58 @ngimel i would be curious to get your thoughts on ways to improve this kernel for quantizing 3d expert weights (E,N,K) along the N dim, where weights are contiguous. It uses nearly identical logic to the 2d dim1 cast kernel (which achieves ~85% mem bw utilization), yet the perf is much worse (~8% to 62% peak mem bw, depending on input size - see benchmarks in PR description).

I think the culprit might be how i'm allocating all the TMA descriptors and passing them in, the overhead might be too much for small E? NCU has not flagged anything particularly helpful so far. Strangely, for E=2 it shows the kernel is compute bound with 78% compute throughput % and 38% memory bandwidth %.

Additional context: torch.compile and handwritten triton kernels were both slow for mxfp8 quant for RHS operands where we scale colwise (32x1 granularity) e.g., (triton hit 56% peak mem bw). So I added a CUDA kernel here which I derived from a TE kernel which achieves ~85% peak mem bw (#2513). Basically we stripped out internal TE types, added support for different scale calculation modes (floor, rceil) to align with torchao numerics, then resolved some perf issues resulting from those changes to get reasonable perf.

Now, I'm finding quantizing 3d expert weights along dim1 is scaling extremely poorly as number of experts increases (see this PR's description for details, and see #2999 for benchmarks). So I added a similar CUDA kernel to our mxfp8_cuda extension specifically for quantizing 3d expert weights colwise and writing directly to col major format we need it in.

The first approach I tried was just updating the 2d kernel to handle 3d tensors by treating it as a 2d tensor of shape (E*N, K) but the coordinate mapping / pointer arithemetic became a complicated mess that wasn't working. So I made a new kernel, that is similar to the 2d kernel but passes in separate input/output TMA descriptors for each expert, then the kernel operates on each 2d expert with logical separation, in parallel.

@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/68 to main September 15, 2025 20:38
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/68 September 15, 2025 20:38
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/68 to main September 15, 2025 21:02
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/68 September 17, 2025 15:28
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/68 to main September 17, 2025 15:47
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/68 September 17, 2025 15:48
danielvegamyhre added a commit that referenced this pull request Sep 17, 2025
stack-info: PR: #3002, branch: danielvegamyhre/stack/69
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/68 to main September 17, 2025 16:19
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/68 September 17, 2025 16:19
danielvegamyhre added a commit that referenced this pull request Sep 17, 2025
stack-info: PR: #3002, branch: danielvegamyhre/stack/69
danielvegamyhre added a commit that referenced this pull request Sep 17, 2025
stack-info: PR: #3002, branch: danielvegamyhre/stack/69
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/68 to main September 17, 2025 16:25
stack-info: PR: #3002, branch: danielvegamyhre/stack/69
@ngimel

ngimel commented Sep 19, 2025

Copy link
Copy Markdown

Discussed offline, likely culprit is cudaMallocManaged calls on the hot path https://github.com/pytorch/ao/pull/3002/files#diff-7ddc6623d9efea4ee4f4bdb3cdd7ef16ec3d3bc8bc974be85311125e464efb3dR1326, we should be able to create just a single descriptor and do the loads using the correct offsets.

@danielvegamyhre

danielvegamyhre commented Sep 19, 2025

Copy link
Copy Markdown
Contributor Author

Discussed offline, likely culprit is cudaMallocManaged calls on the hot path https://github.com/pytorch/ao/pull/3002/files#diff-7ddc6623d9efea4ee4f4bdb3cdd7ef16ec3d3bc8bc974be85311125e464efb3dR1326, we should be able to create just a single descriptor and do the loads using the correct offsets.

I've been trying this today, progress so far is:

  • Scales for all experts are correct
  • Quantized data is correct for the expert=0 subtensor, but is all 0s for all other experts

This indicates the new input TMA descriptor and async loads are working properly, but the output TMA descriptor and/or async stores are not.

I think this is because the input tensor is in simple row major format, which can easily be represented in a TMA descriptor with shape (E*N, K) with stride K. However, the output data needs to be in "column major PER expert" format, so strides (N*K, 1, N). The 2d output TMA descriptor + ptx::cp_async_bulk_tensor_2d_global_to_shared do not seem capable of representing this layout so far (or could be a skill issue on my part, haha)

I will push a WIP PR on top of this stack to show the difference between this multiple-tma-descriptor approach, which is at least functionally correct, versus the single tma descriptor approach.

@slayton58

Copy link
Copy Markdown

@danielvegamyhre I think there's a couple of options, can we do something like:

# out : [E, N, K], stride: [N*K, 1, N]#
out.transpose(2,1) # Now [E, K, N], stride: [N*K, N, 1], representable by row-major TMA descriptor (2d or otherwise)
# compute
out.transpose(1, 2) # Back to original form

or, we can try using 3d TMAs directly - ptx::cp_async_bulk_tensor_3d_global_to_shared is the relevant read invocation, and a 3d descriptor would have to be created for this.
or (finally) we can ignore TMA writes, and use regular global stores (so STG) - we shouldn't need the async, pipeline-able part of the TMA instruction.

@danielvegamyhre

Copy link
Copy Markdown
Contributor Author

@slayton58 row major -> transpose to per expert col major is a great idea! Trying it now.

or, we can try using 3d TMAs directly

Yeah I've considered this, it is probably the "proper" way to do it but would require a larger refactor, hopefully the transpose method works.

@drisspg

drisspg commented Sep 19, 2025

Copy link
Copy Markdown
Contributor

can we just merge the top of stack w/ this pR

size_t SCALE_DIM_X, ScaleCalculationMode ScalingMode>
__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
mxfp8_quantize_kernel_3d(
const CUtensorMap* tensor_maps_input,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to link against the cuda drivers in order to build this?

I think we fixed this at the top of the stack?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean in the create tensor map functions? the other examples I saw do reference the cuda driver to get the cuTensorMapEncodeTiled function pointer, so I did the same for 3d tensor map creation.

@danielvegamyhre danielvegamyhre merged commit f210443 into main Sep 19, 2025
18 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: not user facing Use this tag if you don't want this PR to show up in release notes moe mx

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants