Skip to content

[mxfp8 moe training] update 3d quant colwise scaling kernel to use single input/output TMA descriptors#3034

Merged
danielvegamyhre merged 1 commit into
mainfrom
danielvegamyhre/stack/74
Sep 20, 2025
Merged

[mxfp8 moe training] update 3d quant colwise scaling kernel to use single input/output TMA descriptors#3034
danielvegamyhre merged 1 commit into
mainfrom
danielvegamyhre/stack/74

Conversation

@danielvegamyhre

@danielvegamyhre danielvegamyhre commented Sep 19, 2025

Copy link
Copy Markdown
Contributor

Stacked PRs:


[mxfp8 moe training] update 3d quant colwise scaling kernel to use single input/output TMA descriptors

Summary

  • CUDA Kernel for 3d quantization across cols added in [mxfp8 moe training] wrap 3d quantize tensor in custom ops and integrate it #3004 has worse perf than other methods for small num_experts, and is only better for large num_experts.
  • We hypothesize this is because cudaMallocManaged of separate TMA descriptors per expert, which is a slow/blocking function. The overhead is constant, and thus more noticeable for small inputs.
  • In this PR, I redesign the kernel to use single input and output TMA descriptor for the whole 3d tensor.
    • For the input, it is in simple row major format, so I can read from specific experts by adjusting the TMA row offset during the async TMA load.
    • For the output, it is a more complex "column major per expert" format, so I use a 3d TMA descriptor with the specific shape and strides needed. I transpose the row major data in SMEM before doing the async TMA store to GMEM to get it in col major per expert format.

Test plan

  • Add unit tests for Llama4 and DeepSeekV3 shapes
  • sanitize pytest test/prototype/moe_training/test_kernels.py

Performance

input_shape         to_mx_us    cuda_2d_us    cuda_3d_us    to_mx_gbps    cuda_2d_gbps    cuda_3d_gbps
----------------  ----------  ------------  ------------  ------------  --------------  --------------
(1, 8192, 5120)      118.656        68.016        34.848      1071.5           1869.26         3648.41
(2, 8192, 5120)      430.064       105.472        61.44        591.26          2410.87         4138.67
(4, 8192, 5120)      847.824       197.632       118.784       599.841         2573.26         4281.38
(8, 8192, 5120)     1693.76        378.848       252.832       600.509         2684.77         4022.9
(16, 8192, 5120)    3449.66        742.368       489.44        589.691         2740.2          4156.26
(64, 8192, 5120)   13354          3145.7        1814.62        609.328         2586.69         4484.1

@pytorch-bot

pytorch-bot Bot commented Sep 19, 2025

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3034

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

danielvegamyhre added a commit that referenced this pull request Sep 19, 2025
…ngle input/output TMA descriptors

stack-info: PR: #3034, branch: danielvegamyhre/stack/74
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 19, 2025
@danielvegamyhre danielvegamyhre marked this pull request as draft September 19, 2025 03:51
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/73 to main September 19, 2025 03:54
danielvegamyhre added a commit that referenced this pull request Sep 19, 2025
…ngle input/output TMA descriptors

stack-info: PR: #3034, branch: danielvegamyhre/stack/74
danielvegamyhre added a commit that referenced this pull request Sep 19, 2025
…ngle input/output TMA descriptors

stack-info: PR: #3034, branch: danielvegamyhre/stack/74
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/73 September 19, 2025 03:55
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/73 to main September 19, 2025 18:00
danielvegamyhre added a commit that referenced this pull request Sep 19, 2025
…ngle input/output TMA descriptors

stack-info: PR: #3034, branch: danielvegamyhre/stack/74
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/73 September 19, 2025 18:00
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/73 to main September 19, 2025 18:10
danielvegamyhre added a commit that referenced this pull request Sep 19, 2025
…ngle input/output TMA descriptors

stack-info: PR: #3034, branch: danielvegamyhre/stack/74
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/73 September 19, 2025 18:10
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/73 to main September 19, 2025 18:29
danielvegamyhre added a commit that referenced this pull request Sep 19, 2025
…ngle input/output TMA descriptors

stack-info: PR: #3034, branch: danielvegamyhre/stack/74
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/73 September 19, 2025 18:29
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/73 to main September 19, 2025 18:36
danielvegamyhre added a commit that referenced this pull request Sep 19, 2025
…ngle input/output TMA descriptors

stack-info: PR: #3034, branch: danielvegamyhre/stack/74
Comment thread torchao/csrc/cuda/mx_kernels/mxfp8_quantize.cuh Outdated
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/73 to main September 19, 2025 20:42
danielvegamyhre added a commit that referenced this pull request Sep 19, 2025
…ngle input/output TMA descriptors

stack-info: PR: #3034, branch: danielvegamyhre/stack/74
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/73 September 19, 2025 20:42
danielvegamyhre added a commit that referenced this pull request Sep 19, 2025
…ngle input/output TMA descriptors

stack-info: PR: #3034, branch: danielvegamyhre/stack/74
danielvegamyhre added a commit that referenced this pull request Sep 19, 2025
…ngle input/output TMA descriptors

stack-info: PR: #3034, branch: danielvegamyhre/stack/74
danielvegamyhre added a commit that referenced this pull request Sep 19, 2025
…ngle input/output TMA descriptors

stack-info: PR: #3034, branch: danielvegamyhre/stack/74
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/73 to main September 19, 2025 20:59
danielvegamyhre added a commit that referenced this pull request Sep 19, 2025
…ngle input/output TMA descriptors

stack-info: PR: #3034, branch: danielvegamyhre/stack/74
@ngimel

ngimel commented Sep 19, 2025

Copy link
Copy Markdown

What are cuda_2d numbers in the benchmark? Running in a loop for each expert?

@danielvegamyhre

danielvegamyhre commented Sep 19, 2025

Copy link
Copy Markdown
Contributor Author

What are cuda_2d numbers in the benchmark? Running in a loop for each expert?

The benchmarks are from this script.

The "cuda_2d" benchmarks are referencing this function which uses the 2d cuda colwise quantization on the 3d tensor, by reshaping it in pytorch from (E*N, K), quantizing, then reshaping the output and scales appropriately back to 3d, to match the expectations of torch._scaled_grouped_mm mxfp8 2d-3d grouped gemm.

The key issue with this method is that the quantized (E*N, K) is in column major format, and I couldn't find a way to reshape/view back to (E,N,K) with per expert column major format by simplying mutating the tensor metadata - I had to do a physical memory layout transformation here, which is not ideal.

So I called that method cuda_2d in the script since it's using the 2d quantization kernel (name could be better, just trying to keep it concise).

uint32_t shmem_k,
const size_t type_num_bits) {
// Get function pointer to cuTensorMapEncodeTiled
static void *driver_ptr = nullptr;

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since both 2d and 3d map are using this, I think you should use a function that would return it (to not initialize it twice)

@danielvegamyhre danielvegamyhre Sep 19, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm good point. Updated to do something like a singleton pattern, holding driver ptr as a global/static var that starts as null then we only initialize it once whenever the first kernel is called. Let me know if that is what you had in mind / will work. (It builds and tests pass)

danielvegamyhre added a commit that referenced this pull request Sep 19, 2025
…ngle input/output TMA descriptors

stack-info: PR: #3034, branch: danielvegamyhre/stack/74
}
}

static void *driver_ptr = nullptr;

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer something like

void * get_driver_ptr() {
static void * driver_ptr = nullptr;
  if (!driver_ptr) {
    cudaDriverEntryPointQueryResult result;
    cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr,
                            cudaEnableDefault, &result);
  }
  return driver_ptr;
}

that gets called from both create_3D_tensor_map and create_2D_tensor_map, but this would work too. If you are going with this you should put it in an anonymous namespace to not pollute other files that may include this, driver_ptr is pretty generic.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense - updated.

danielvegamyhre added a commit that referenced this pull request Sep 19, 2025
…ngle input/output TMA descriptors

stack-info: PR: #3034, branch: danielvegamyhre/stack/74
static void *driver_ptr = nullptr;
if (!driver_ptr) {
cudaDriverEntryPointQueryResult result;
cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, you should error check this call

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, done

…ngle input/output TMA descriptors

stack-info: PR: #3034, branch: danielvegamyhre/stack/74
@danielvegamyhre danielvegamyhre merged commit d2fae7a into main Sep 20, 2025
14 of 17 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: not user facing Use this tag if you don't want this PR to show up in release notes moe mx

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants