[feat] Trtllm-gen Per-token Nvfp4 MoE #3027
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
📝 Walkthrough

Adds NVFP4 per-token quantization producing per-token scales, extends FP4 quantization with row-wise and inverse-scale options, refactors quantization kernels/dispatch and host launchers, threads per-token scaling through fused MoE runners/launchers/GEMM, and exposes new Python APIs and tests.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Host as Host
    participant Launcher as MoE Launcher
    participant NVKernel as NVFP4 Kernel
    participant FP4Kernel as FP4 Quant Kernel
    participant Gemm2 as Gemm2 / FC2
    Host->>Launcher: call fused MoE with input + per_token_scales?
    Launcher->>Launcher: allocate workspace (token_scales_fc2) if needed
    Launcher->>NVKernel: invokeNvfp4QuantAndPerTokenScale(input, globalScaleInv, sfLayout, ...)
    activate NVKernel
    NVKernel->>NVKernel: per-row amax reduce → compute per-token scales
    NVKernel-->>Launcher: write perTokenScaleOutput, weightOutput, scaleOutput
    deactivate NVKernel
    Launcher->>FP4Kernel: invokeFP4Quantization(FC1_output, perTokenScales?, use_row_wise/inverse flags)
    activate FP4Kernel
    FP4Kernel->>FP4Kernel: apply per-token or row-wise/inverse scale, quantize to FP4, emit block scales
    FP4Kernel-->>Launcher: packed FP4 activations + block scales
    deactivate FP4Kernel
    Launcher->>Gemm2: run Gemm2(perTokenScales_fc2, quantized_activation, weights, scales)
    activate Gemm2
    Gemm2-->>Host: final MoE output
    deactivate Gemm2
```
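The flow in the diagram can be sketched as a small NumPy reference. This is an illustrative toy under assumed constants (`FP4_E2M1_MAX = 6`, `FP8_E4M3_MAX = 448`, 16-element scale blocks), not the actual kernel code:

```python
# Illustrative NumPy reference for the diagram's flow: per-row amax reduce,
# per-token scale, then blockwise e2m1 quantization. Constant names and the
# 16-element block size are assumptions for this sketch, not the kernel's API.
import numpy as np

FP4_E2M1_MAX = 6.0    # largest e2m1 magnitude
FP8_E4M3_MAX = 448.0  # largest e4m3 magnitude (block-scale dtype)
BLOCK = 16            # assumed NVFP4 scale-factor vector size

E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_e2m1(x):
    """Round to the nearest representable e2m1 magnitude, keeping the sign."""
    idx = np.argmin(np.abs(np.abs(x)[..., None] - E2M1_GRID), axis=-1)
    return np.sign(x) * E2M1_GRID[idx]

def nvfp4_per_token_quant(x):
    m, n = x.shape
    # "per-row amax reduce -> compute per-token scales" from the diagram.
    amax = np.abs(x).max(axis=1)                        # (m,)
    token_scale = amax / (FP8_E4M3_MAX * FP4_E2M1_MAX)  # (m,)
    # Per-16-element block scales, stored relative to the token scale.
    blocks = x.reshape(m, n // BLOCK, BLOCK)
    block_amax = np.abs(blocks).max(axis=2)             # (m, n // BLOCK)
    denom = np.maximum(token_scale[:, None], 1e-12)
    block_scale = block_amax / FP4_E2M1_MAX / denom
    # Divide by the combined scale so each block's amax maps to 6.0, then round.
    combined = np.maximum((block_scale * token_scale[:, None])[..., None], 1e-12)
    q = quantize_e2m1(blocks / combined)
    return q.reshape(m, n), block_scale, token_scale

x = np.random.randn(4, 64).astype(np.float32)
q, block_scale, token_scale = nvfp4_per_token_quant(x)
```

Dequantization multiplies each 4-bit value by its block scale times the token scale, mirroring how Gemm2 consumes `perTokenScales_fc2` in the diagram.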
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
Code Review
This pull request introduces support for per-token scaling in FP4 quantization for MoE models, including updates to the quantization kernels, the MoE runner, and the Python interface. The changes enable row-wise Amax calculation and quantization for FP4, and integrate these into the fused MoE pipeline. My review identified several issues: missing header includes for cuda::std::maximum, potential integer overflow in the Amax kernel, efficiency improvements for the Amax kernel, and the need to expand type support for NvFP4. Additionally, I pointed out unused variables and the use of magic numbers that should be replaced with named constants.
```cpp
#include <cudaTypedefs.h>
#include <float.h>

#include <cub/cub.cuh>
```
```cpp
FLASHINFER_CHECK(mGemm2.mDtypeAct == btg::Dtype::E2m1,
                 "Currently only support NvFP4 when using explicit quantization.");
```
The check mGemm2.mDtypeAct == btg::Dtype::E2m1 is too restrictive. NvFP4 also supports MxE2m1 (vector size 32). This check should be expanded to allow MxE2m1, and the subsequent call to invokeFP4Quantization should dispatch to the correct template instantiation (16 or 32) based on the actual type of mGemm2.mDtypeAct.
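To make the suggestion concrete, here is a tiny NumPy toy showing the dispatch. The `SF_VEC_SIZE` mapping and the dtype strings are made up for illustration, not the real enums; the point is that the quantization logic is identical for E2m1 and MxE2m1, only the scale-factor vector size (16 vs 32) changes:

```python
# Toy dispatch on the activation dtype instead of rejecting MxE2m1 outright.
# SF_VEC_SIZE and the dtype strings are illustrative, not the real enums.
import numpy as np

SF_VEC_SIZE = {"E2m1": 16, "MxE2m1": 32}

def block_amax(x, dtype_act):
    """Per-block amax with the vector size picked from the activation dtype."""
    vec = SF_VEC_SIZE[dtype_act]  # 16 for NvFP4, 32 for MxFP4
    return np.abs(x.reshape(-1, vec)).max(axis=1)

x = np.arange(64, dtype=np.float32) - 32.0
amax16 = block_amax(x, "E2m1")    # 4 blocks of 16 elements
amax32 = block_amax(x, "MxE2m1")  # 2 blocks of 32 elements
```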
```cpp
template <typename T, uint32_t BLOCK_SIZE>
__global__ void rowWiseAmaxKernel(uint32_t m, uint32_t n, T const* input, float* amaxOutput, float scale) {
  uint32_t rowIdx = blockIdx.x;
  if (rowIdx >= m) return;

  float localMax = 0.f;
  for (uint32_t colIdx = threadIdx.x; colIdx < n; colIdx += blockDim.x) {
    T element = input[rowIdx * n + colIdx];
    localMax = fmaxf(localMax, fabsf(static_cast<float>(element) * scale));
  }

  using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
  __shared__ typename BlockReduce::TempStorage tempStorage;
  float blockMax = BlockReduce(tempStorage)
                       .Reduce(localMax,
#if CUDART_VERSION >= 12090
                               cuda::maximum<> {}
#else
                               cub::Max(),
#endif
                       );

  if (threadIdx.x == 0) {
    amaxOutput[rowIdx] = blockMax;
  }
}
```
There are two issues in `rowWiseAmaxKernel`:

1. **Efficiency:** Multiplying by `scale` inside the loop performs `n` multiplications per thread. It is more efficient to compute the maximum of absolute values first and multiply by `scale` once at the end when writing to global memory.
2. **Correctness (overflow):** The indexing `rowIdx * n + colIdx` uses `uint32_t`. If the product exceeds $2^{32}$ (possible with large token counts and intermediate sizes), it will overflow. Using `static_cast<size_t>(rowIdx) * n` prevents this.
```cpp
template <typename T, uint32_t BLOCK_SIZE>
__global__ void rowWiseAmaxKernel(uint32_t m, uint32_t n, T const* input, float* amaxOutput, float scale) {
  uint32_t rowIdx = blockIdx.x;
  if (rowIdx >= m) return;

  float localMax = 0.f;
  for (uint32_t colIdx = threadIdx.x; colIdx < n; colIdx += blockDim.x) {
    T element = input[static_cast<size_t>(rowIdx) * n + colIdx];
    localMax = fmaxf(localMax, fabsf(static_cast<float>(element)));
  }

  using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
  __shared__ typename BlockReduce::TempStorage tempStorage;
  float blockMax = BlockReduce(tempStorage)
                       .Reduce(localMax,
#if CUDART_VERSION >= 12090
                               cuda::std::maximum<> {}
#else
                               cub::Max(),
#endif
                       );

  if (threadIdx.x == 0) {
    amaxOutput[rowIdx] = blockMax * scale;
  }
}
```
```cpp
FLASHINFER_CHECK(workspace.token_scales_fc2 != nullptr,
                 "workspace.token_scales_fc2 must be provided when using explicit quantization.");
const int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
int intermediate_size_factor = isGatedActivation(args.activation_type) ? 2 : 1;

auto sfLayout = mGemm2.mTileTokensDim >= 128 ? QuantizationSFLayout::SWIZZLED_128x4
                                             : QuantizationSFLayout::SWIZZLED_8x4;

invokeRowWiseAmax<__nv_bfloat16>(workspace.total_max_padded_tokens, args.intermediate_size,
                                 reinterpret_cast<__nv_bfloat16*>(workspace.gemm1_output),
                                 reinterpret_cast<float*>(workspace.token_scales_fc2),
                                 1.f / 448.f / 6.f, stream);
```
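The `1.f / 448.f / 6.f` factor above reads as one over the product of the e4m3 and e2m1 maxima. A short sketch of that interpretation (an assumption from the constants themselves, not taken from kernel comments):

```python
# 448 is the e4m3 max and 6 the e2m1 max, so scaling a row's amax by
# 1/(448*6) yields a per-token scale under which the row's largest value,
# once divided by that scale, lands exactly at 448 * 6.
FP8_E4M3_MAX = 448.0
FP4_E2M1_MAX = 6.0

def per_token_scale(row_amax: float) -> float:
    return row_amax * (1.0 / FP8_E4M3_MAX / FP4_E2M1_MAX)  # mirrors 1.f/448.f/6.f

s = per_token_scale(13.5)
assert abs(13.5 / s - FP8_E4M3_MAX * FP4_E2M1_MAX) < 1e-9
```

Naming these values (e.g. `FP8_E4M3_MAX`, `FP4_E2M1_MAX`, names assumed here) would also address the magic-number point raised in the review.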
There seem to be some relevant bot run errors.
/bot run |
Force-pushed from ee9e4fb to 8bf05fc.
/bot run |
Further added TE style reference implementation and
/bot run |
/bot run |
Tests look good so far; restarting CI for merging.
## 📌 Description

Optimize the performance of the per-token nvfp4 quantization kernel introduced by #3027.

1. Default block size to 128.
2. Default to the fast-math path. Rename `TE_EXACT_FP4` to `TRTLLM_DISABLE_FP4_QUANT_FAST_MATH`, controlled by an environment variable.
3. Change the argument lists of `get_sf_out_offset_128x4` and `get_sf_out_offset_8x4`.

TODOs:

1. Optimize low-latency cases.

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

- **New Features**
  - Added environment variable configuration to disable fast-math optimization in FP4 quantization, enabling behavior alignment with alternative implementations.
- **Tests**
  - Added test fixture to validate FP4 quantization functionality with fast-math mode disabled.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Co-authored-by: Ziang Li <ziangli@umich.edu>
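The `TRTLLM_DISABLE_FP4_QUANT_FAST_MATH` variable described above reads as an opt-out gate with fast math as the default. A minimal sketch of that behavior (the accepted values are an assumption; the real parsing may differ):

```python
import os

def use_fast_math() -> bool:
    """Fast math is the default; set TRTLLM_DISABLE_FP4_QUANT_FAST_MATH to opt out."""
    return os.environ.get("TRTLLM_DISABLE_FP4_QUANT_FAST_MATH", "0") not in ("1", "true", "True")

os.environ["TRTLLM_DISABLE_FP4_QUANT_FAST_MATH"] = "1"
assert not use_fast_math()  # exact (TE-style) path requested
```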
📌 Description
Co-authors: @mxz297, @zianglih

This PR enables per-token quantization for Trtllm-gen MoE.
- `trtllm_fp4_block_scale_moe` and `trtllm_fp4_block_scale_routed_moe`: added a new optional argument `per_token_scale`.
- `cvt_warp_fp16_to_fp4_with_vec_max`: uses the cached local amax.

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes