[mxfp8 moe training] update 3d quant colwise scaling kernel to use single input/output TMA descriptors #3034
…ngle input/output TMA descriptors stack-info: PR: #3034, branch: danielvegamyhre/stack/74
What are the cuda_2d numbers in the benchmark? Is that running the 2D kernel in a loop for each expert?
The benchmarks are from this script. The "cuda_2d" benchmarks reference this function, which applies the 2D CUDA colwise quantization kernel to the 3D tensor: it reshapes the tensor in PyTorch from (E, N, K) to (E*N, K), quantizes it, then reshapes the output and scales back to 3D to match what the torch._scaled_grouped_mm mxfp8 2d-3d grouped GEMM expects. The key issue with this method is that the quantized (E*N, K) tensor is in column-major format, and I couldn't find a way to reshape/view it back to (E, N, K) with per-expert column-major format by simply mutating the tensor metadata. I had to do a physical memory layout transformation here, which is not ideal. So I called that method
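To make the layout problem concrete, here is a minimal sketch in plain Python (no PyTorch) of the offset arithmetic. It assumes that "per-expert column-major" means each expert owns a contiguous, column-major (N, K) block; that reading, and the helper names `offset_2d_colmajor`, `offset_per_expert_colmajor`, `layouts_match`, and `relayout`, are illustrative and not taken from the PR.

```python
from itertools import product


def offset_2d_colmajor(e, n, k, E, N, K):
    """Offset of element (e, n, k) inside a column-major (E*N, K) buffer."""
    row = e * N + n            # row index after flattening (E, N) -> E*N
    return k * (E * N) + row   # column-major: each column of E*N rows is contiguous


def offset_per_expert_colmajor(e, n, k, E, N, K):
    """Offset when every expert owns a contiguous, column-major (N, K) block (assumed target)."""
    return e * (N * K) + k * N + n


def layouts_match(E, N, K):
    """True if both layouts place every element at the same memory offset."""
    return all(
        offset_2d_colmajor(e, n, k, E, N, K)
        == offset_per_expert_colmajor(e, n, k, E, N, K)
        for e, n, k in product(range(E), range(N), range(K))
    )


def relayout(buf, E, N, K):
    """Physically copy a flat column-major (E*N, K) buffer into per-expert column-major order."""
    out = [None] * (E * N * K)
    for e, n, k in product(range(E), range(N), range(K)):
        src = offset_2d_colmajor(e, n, k, E, N, K)
        dst = offset_per_expert_colmajor(e, n, k, E, N, K)
        out[dst] = buf[src]
    return out


if __name__ == "__main__":
    # Compare the two layouts for a single expert and for two experts.
    print(layouts_match(1, 3, 4))
    print(layouts_match(2, 3, 4))
```

Under these assumptions, the 2D buffer viewed as (E, N, K) has strides (N, 1, E*N), while the target layout needs strides (N*K, 1, N). The two agree only in degenerate cases such as E = 1 or K = 1, which is consistent with needing a copy like `relayout` rather than a metadata-only reshape.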
    uint32_t shmem_k,
    const size_t type_num_bits) {
  // Get function pointer to cuTensorMapEncodeTiled
  static void *driver_ptr = nullptr;
Since both the 2D and 3D tensor map functions use this, I think you should use a function that returns it, so that it is not initialized twice.
Hmm good point. Updated to do something like a singleton pattern, holding driver ptr as a global/static var that starts as null then we only initialize it once whenever the first kernel is called. Let me know if that is what you had in mind / will work. (It builds and tests pass)
  }
}

static void *driver_ptr = nullptr;
I'd prefer something like
void * get_driver_ptr() {
  static void * driver_ptr = nullptr;
  if (!driver_ptr) {
    cudaDriverEntryPointQueryResult result;
    cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr,
                            cudaEnableDefault, &result);
  }
  return driver_ptr;
}
that gets called from both create_3D_tensor_map and create_2D_tensor_map, but this would work too. If you are going with this, you should put it in an anonymous namespace so it doesn't pollute other files that may include this one; driver_ptr is a pretty generic name.
Makes sense - updated.
  static void *driver_ptr = nullptr;
  if (!driver_ptr) {
    cudaDriverEntryPointQueryResult result;
    cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr,
thanks, done
Stacked PRs:
[mxfp8 moe training] update 3d quant colwise scaling kernel to use single input/output TMA descriptors
Summary
Test plan
sanitize pytest test/prototype/moe_training/test_kernels.py
Performance