[mxfp8 moe training] add CUDA kernel to quantize 3d tensor colwise #3002
stack-info: PR: #3002, branch: danielvegamyhre/stack/69
@slayton58 @ngimel I would be curious to get your thoughts on ways to improve this kernel for quantizing 3d expert weights (E,N,K) along the N dim, where weights are contiguous. It uses nearly identical logic to the 2d dim1 cast kernel (which achieves ~85% mem bw utilization), yet the perf is much worse (~8% to 62% of peak mem bw, depending on input size; see benchmarks in the PR description). I think the culprit might be how I'm allocating all the TMA descriptors and passing them in; the overhead might be too much for small E? NCU has not flagged anything particularly helpful so far. Strangely, for E=2 it shows the kernel is compute bound, with 78% compute throughput and 38% memory bandwidth.

Additional context: torch.compile and handwritten triton kernels were both slow for mxfp8 quant for RHS operands where we scale colwise (32x1 granularity), e.g., triton hit 56% of peak mem bw. So I added a CUDA kernel here, derived from a TE kernel, which achieves ~85% of peak mem bw (#2513). Basically we stripped out internal TE types, added support for different scale calculation modes (floor, rceil) to align with torchao numerics, then resolved some perf issues resulting from those changes to get reasonable perf.

Now I'm finding that quantizing 3d expert weights along dim1 scales extremely poorly as the number of experts increases (see this PR's description for details, and see #2999 for benchmarks). So I added a similar CUDA kernel to our mxfp8_cuda extension, specifically for quantizing 3d expert weights colwise and writing directly to the col major format we need.

The first approach I tried was just updating the 2d kernel to handle 3d tensors by treating the input as a 2d tensor of shape (E*N, K), but the coordinate mapping / pointer arithmetic became a complicated mess that wasn't working.

So I made a new kernel that is similar to the 2d kernel but takes separate input/output TMA descriptors for each expert. The kernel then operates on each 2d expert with logical separation, in parallel.
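For readers unfamiliar with the two scale calculation modes mentioned above, here is a minimal pure-Python sketch of how a 32x1 (colwise) block scale exponent could be derived for each block along N. The formulas for `floor` and `rceil`, the handling of all-zero blocks, and every function name are assumptions for illustration; they are not copied from the kernel or from torchao.

```python
import math

F8E4M3_MAX = 448.0       # largest finite float8_e4m3fn value
F8E4M3_MAX_POW2 = 8      # floor(log2(448))
BLOCK = 32               # 32x1 scaling granularity along N


def scale_exponent(amax, mode):
    """Unbiased power-of-two exponent for one 32x1 block (assumed formulas)."""
    if amax == 0.0:
        return -127  # assumed: smallest E8M0 exponent for an all-zero block
    if mode == "floor":
        return math.floor(math.log2(amax)) - F8E4M3_MAX_POW2
    if mode == "rceil":
        return math.ceil(math.log2(amax / F8E4M3_MAX))
    raise ValueError(f"unknown mode: {mode}")


def colwise_scale_exponents(w, mode):
    """w is a nested list shaped (E, N, K); returns exponents shaped (E, N // 32, K)."""
    out = []
    for expert in w:
        n, k = len(expert), len(expert[0])
        assert n % BLOCK == 0, "N must be a multiple of the block size"
        blocks = []
        for n0 in range(0, n, BLOCK):
            row = []
            for col in range(k):
                # amax over 32 consecutive rows of a single column
                amax = max(abs(expert[r][col]) for r in range(n0, n0 + BLOCK))
                row.append(scale_exponent(amax, mode))
            blocks.append(row)
        out.append(blocks)
    return out


# One expert with N=32, K=2: column 0 peaks at 448, column 1 peaks at 500.
w = [[[448.0 if r == 0 else 1.0, 500.0 if r == 0 else 1.0] for r in range(BLOCK)]]
floor_exps = colwise_scale_exponents(w, "floor")
rceil_exps = colwise_scale_exponents(w, "rceil")
```

Under these assumed formulas the two modes agree when the block maximum fits in e4m3 and differ when it does not: `rceil` rounds the scale up so the scaled maximum stays within range, while `floor` relies on saturation.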
Discussed offline: the likely culprit is the cudaMallocManaged calls on the hot path (https://github.com/pytorch/ao/pull/3002/files#diff-7ddc6623d9efea4ee4f4bdb3cdd7ef16ec3d3bc8bc974be85311125e464efb3dR1326). We should be able to create just a single descriptor and do the loads using the correct offsets.
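To make "a single descriptor with the correct offsets" concrete, here is a small sketch of the input-side coordinate mapping, assuming the contiguous (E, N, K) input is described once as a 2d (E*N, K) tensor. The function names and tile sizes are hypothetical, and this only models the index arithmetic, not the TMA calls themselves.

```python
def flat_row(e, n, N):
    """Row of element (e, n, :) when a contiguous (E, N, K) tensor is viewed as (E*N, K)."""
    return e * N + n


def tile_origins(E, N, K, tile_n, tile_k):
    """Yield (expert, row, col) tile origins in the flattened 2d view.

    A tile never straddles two experts as long as N is a multiple of tile_n.
    """
    assert N % tile_n == 0 and K % tile_k == 0
    for e in range(E):
        for n0 in range(0, N, tile_n):
            for k0 in range(0, K, tile_k):
                yield e, flat_row(e, n0, N), k0


origins = list(tile_origins(E=2, N=64, K=64, tile_n=32, tile_k=64))
```

Because the input is row major and contiguous, the flattened row index is the only thing that changes per expert, which is why one input descriptor is enough for the loads.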
I've been trying this today, progress so far is:
This indicates the new input TMA descriptor and async loads are working properly, but the output TMA descriptor and/or async stores are not. I think this is because the input tensor is in simple row major format, which can easily be represented in a TMA descriptor with shape

I will push a WIP PR on top of this stack to show the difference between this multiple-tma-descriptor approach, which is at least functionally correct, versus the single tma descriptor approach.
@danielvegamyhre I think there are a couple of options. Can we do something like: or, we can try using 3d TMAs directly -
@slayton58 row major -> transpose to per expert col major is a great idea! Trying it now.
Yeah, I've considered this. It is probably the "proper" way to do it but would require a larger refactor; hopefully the transpose method works.
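For reference, the per-expert col major target layout discussed here can be written as plain offset arithmetic. This is a sketch under the assumption that "per expert col major" means element (e, n, k) lives at offset `e*N*K + k*N + n`; the helper names are made up for illustration.

```python
def row_major_offset(e, n, k, N, K):
    """Offset of (e, n, k) in a contiguous (E, N, K) tensor."""
    return (e * N + n) * K + k


def per_expert_col_major_offset(e, n, k, N, K):
    """Offset of (e, n, k) when each expert's (N, K) matrix is stored column major."""
    return e * N * K + k * N + n


def to_per_expert_col_major(flat, E, N, K):
    """Reorder a flat row-major buffer into per-expert col major order."""
    out = [None] * (E * N * K)
    for e in range(E):
        for n in range(N):
            for k in range(K):
                dst = per_expert_col_major_offset(e, n, k, N, K)
                out[dst] = flat[row_major_offset(e, n, k, N, K)]
    return out


E, N, K = 2, 3, 2
flat = list(range(E * N * K))
col_major = to_per_expert_col_major(flat, E, N, K)
```

The output has three distinct strides (N*K per expert, N per column, 1 per row), which is one way to read why it does not collapse into a single flattened 2d view the way the row-major input does.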
Can we just merge the top of the stack with this PR?
size_t SCALE_DIM_X, ScaleCalculationMode ScalingMode>
__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
mxfp8_quantize_kernel_3d(
    const CUtensorMap* tensor_maps_input,
Do we need to link against the CUDA driver in order to build this?
I think we fixed this at the top of the stack?
Do you mean in the create tensor map functions? The other examples I saw do reference the CUDA driver to get the cuTensorMapEncodeTiled function pointer, so I did the same for 3d tensor map creation.
Stacked PRs:
[mxfp8 moe training] add CUDA kernel to quantize 3d tensor colwise
Summary
.contiguous() calls: to_mx only scales along the last dim and requires contiguous inputs. So this requires transposing the contiguous tensor (E,N,K) -> (E,K,N), then calling .contiguous() to scale along the N dim (needed for backward).
Test plan
Kernel microbenchmarks
Perf is decent for large E and abysmal for small E. Need to investigate this.
Update (9/15): NCU shows 3d kernel operating on (2,8192,5120) tensor is actually compute bound (??)
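As a sanity check when reading the "% of peak mem bw" numbers, here is a rough sketch of how the bytes moved by a colwise mxfp8 quantize could be estimated. It assumes a 2-byte (bf16) input, 1-byte fp8 output, and one 1-byte e8m0 scale per 32x1 block; the function names are hypothetical, and the elapsed time and peak bandwidth must come from your own measurements and hardware spec.

```python
def mxfp8_colwise_bytes(E, N, K, in_bytes=2, block=32):
    """Estimated bytes read plus written for quantizing (E, N, K) along N."""
    read = E * N * K * in_bytes          # assumed bf16 input
    write_data = E * N * K               # 1 byte per fp8 element
    write_scales = E * (N // block) * K  # 1 byte per 32x1 block scale
    return read + write_data + write_scales


def pct_peak_bw(total_bytes, seconds, peak_bytes_per_s):
    """Achieved bandwidth as a percentage of a given peak."""
    return 100.0 * total_bytes / seconds / peak_bytes_per_s


# Shape from the NCU note above; no timing is implied here.
total_bytes = mxfp8_colwise_bytes(2, 8192, 5120)
```

For the (2,8192,5120) case this comes to roughly 254 MB of traffic, so the kernel runtime is very short and fixed per-launch costs weigh more heavily than they do for large E.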