
[feat] Trtllm-gen Per-token Nvfp4 MoE #3027

Merged
aleozlx merged 48 commits into flashinfer-ai:main from IwakuraRein:per-token-fp4 on May 1, 2026

Conversation

@IwakuraRein (Collaborator) commented Apr 9, 2026

📌 Description

coauthors: @mxz297, @zianglih

This PR enables per-token quantization for Trtllm-gen MoE.

  • In trtllm_fp4_block_scale_moe and trtllm_fp4_block_scale_routed_moe, added a new optional argument per_token_scale.
  • Optimized the fp4 quantization kernel: use 256-bit vectorized loads, and add cvt_warp_fp16_to_fp4_with_vec_max to reuse the cached local amax.
  • Added an explicit amax and quantization kernel after FC1 to generate the per-token scales for FC2 (see the sketch below).
  • Generated the new cubins.
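
For reference, here is a sketch of the per-token scaling math implied by the runner code quoted later in this thread (the exact convention should be checked against the kernels). For each token (row) $i$ of the FC1 output $x$:

$$\mathrm{amax}_i = \max_j \lvert x_{ij} \rvert, \qquad s_i = \frac{\mathrm{amax}_i}{448 \times 6},$$

where 448 is the FP8 E4M3 max and 6 is the FP4 E2M1 max. $s_i$ is the inverse of the per-token global scale; it is what the new amax kernel writes into the FC2 token-scale workspace, matching the 1.f / 448.f / 6.f factor passed to invokeRowWiseAmax further down this page.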

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Per‑token FP4 quantization with a new nvfp4_quant_and_per_token_scale API exposed to Python.
    • Per‑token scaling integrated end‑to‑end across fused MoE paths; FP4 quantize adds row‑wise and inverse‑global‑scale controls.
    • FP4 API now returns packed FP4 weights, block scales, and per‑token scales.
  • Chores

    • Updated backend artifact reference and checksum for batched GEMM.
  • Tests

    • Added end‑to‑end per‑token routed fused MoE test.

@coderabbitai (Bot) commented Apr 9, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds NVFP4 per‑token quantization producing per‑token scales, extends FP4 quantization with row‑wise and inverse‑scale options, refactors quantization kernels/dispatch and host launchers, threads per‑token scaling through fused MoE runners/launchers/GEMM, and exposes new Python APIs and tests.

Changes

Cohort / File(s) — Summary

• Quant kernels & helpers — csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh, csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh, csrc/nv_internal/cpp/kernels/quantization.cu
  Added packed-vector load utilities and CVT_FP16_TO_FP4_ELTS_PER_THREAD; introduced cvt_warp_fp16_to_fp4_with_vec_max; added NVFP4 per-token kernels (nvfp4QuantAndPerTokenScale*); refactored FP4/MxFP8 kernels to support USE_ROW_WISE_SCALE/USE_INVERSE_SCALE, renamed SFOuput to SFOutput, adjusted block sizing, and updated template/packed-output types.

• Host kernel headers & launchers — csrc/nv_internal/tensorrt_llm/kernels/quantization.h, csrc/nv_internal/cpp/kernels/quantization.cu
  Added invokeNvfp4QuantAndPerTokenScale and invokeRowWiseAmax; extended the invokeFP4Quantization/invokeMxFP8Quantization signatures (renamed the SF output parameter, added use_row_wise_scale and inverse_scale); changed dispatch branches and dynamic shared memory attribute placement.

• FP4 Python ops & bindings — csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp, csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h, flashinfer/quantization/fp4_quantization.py, flashinfer/quantization/__init__.py, flashinfer/fp4_quantization.py, flashinfer/__init__.py
  Added the SM100 custom op nvfp4_quant_and_per_token_scale, exposed the new Python API, added an is_global_scale_inversed flag to FP4 quantize callsites, and re-exported the new function.

• Fused MoE runner / launcher / headers — include/flashinfer/trtllm/fused_moe/runner.h, csrc/trtllm_fused_moe_runner.cu, csrc/trtllm_fused_moe_kernel_launcher.cu
  Threaded per-token scaling through: added usePerTokenScaling and useExplicitQuantization, updated the PermuteGemm1/Gemm2 runner constructors and run() signatures to accept per-token scales, added a token_scales_fc2 workspace, and added an explicit quantization path invoking the NVFP4 per-token kernel.

• Batched GEMM runner — include/flashinfer/trtllm/batched_gemm/KernelRunner.h, csrc/trtllm_batched_gemm_runner.cu
  Added a usePerTokenScaling option and gating in candidate config filtering to skip configs incompatible with per-token scaling.

• Python MoE API & tests — flashinfer/fused_moe/core.py, tests/moe/test_trtllm_gen_per_token_moe.py
  Added a per_token_scale field and use_per_token_scaling plumbing in the MoE inputs/runner/APIs, updated custom-op signatures to accept per-token scales, and added a routed per-token MoE test.

• Build/artifacts/jit & enums — flashinfer/artifacts.py, flashinfer/tllm_enums.py, flashinfer/jit/fused_moe.py
  Updated the TRTLLM_GEN_BMM artifact path/checksum, changed dtype deduction to use element counts (numel()), and included quantization.cu in the fused-MoE JIT compilation units.

• Exports & re-exports — flashinfer/quantization/__init__.py, flashinfer/fp4_quantization.py, flashinfer/__init__.py
  Re-exported nvfp4_quant_and_per_token_scale, added a Python API wrapper, and updated the public FP4 quantization exports.

Sequence Diagram

sequenceDiagram
    participant Host as Host
    participant Launcher as MoE Launcher
    participant NVKernel as NVFP4 Kernel
    participant FP4Kernel as FP4 Quant Kernel
    participant Gemm2 as Gemm2 / FC2

    Host->>Launcher: call fused MoE with input + per_token_scales?
    Launcher->>Launcher: allocate workspace (token_scales_fc2) if needed
    Launcher->>NVKernel: invokeNvfp4QuantAndPerTokenScale(input, globalScaleInv, sfLayout, ...)
    activate NVKernel
    NVKernel->>NVKernel: per-row amax reduce → compute per-token scales
    NVKernel-->>Launcher: write perTokenScaleOutput, weightOutput, scaleOutput
    deactivate NVKernel

    Launcher->>FP4Kernel: invokeFP4Quantization(FC1_output, perTokenScales?, use_row_wise/inverse flags)
    activate FP4Kernel
    FP4Kernel->>FP4Kernel: apply per-token or row-wise/inverse scale, quantize to FP4, emit block scales
    FP4Kernel-->>Launcher: packed FP4 activations + block scales
    deactivate FP4Kernel

    Launcher->>Gemm2: run Gemm2(perTokenScales_fc2, quantized_activation, weights, scales)
    activate Gemm2
    Gemm2-->>Host: final MoE output
    deactivate Gemm2

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes


Suggested labels

run-ci

Suggested reviewers

  • bkryu
  • yzh119
  • aleozlx
  • djmmoss
  • cyx-6
  • jimmyzho
  • nv-yunzheq
  • jiahanc

Poem

🐇 I hopped through kernels, tuned each token's scale,
Packed tiny FP4 whispers down a leafy trail.
Per‑token math in burrows tight,
MoE and quant together—what a night!
A rabbit cheers the merge, fast and hale.

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 21.05%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (2 passed)

  • Title check ✅ Passed — The title '[feat] Trtllm-gen Per-token Nvfp4 MoE' clearly describes the main feature: enabling per-token NVFP4 MoE support for TrtLLM-gen.
  • Description check ✅ Passed — The PR description covers the main objectives and includes a checklist, but the test section is incomplete with tests not added.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces support for per-token scaling in FP4 quantization for MoE models, including updates to the quantization kernels, the MoE runner, and the Python interface. The changes enable row-wise amax calculation and quantization for FP4 and integrate them into the fused MoE pipeline. My review identified several issues: a missing header include for cuda::std::maximum, potential integer overflow in the amax kernel, an efficiency opportunity in the amax kernel, and the need to expand type support for NvFP4. I also pointed out unused variables and magic numbers that should be replaced with named constants.

#include <cudaTypedefs.h>
#include <float.h>

#include <cub/cub.cuh>

Severity: high

The use of cuda::maximum (or cuda::std::maximum) in the rowWiseAmaxKernel requires including <cuda/std/functional>. Without this, compilation might fail depending on the CUDA toolkit version and transitive includes.

#include <cub/cub.cuh>
#include <cuda/std/functional>

Comment on lines +754 to +755
FLASHINFER_CHECK(mGemm2.mDtypeAct == btg::Dtype::E2m1,
"Currently only support NvFP4 when using explicit quantization.");

Severity: high

The check mGemm2.mDtypeAct == btg::Dtype::E2m1 is too restrictive. NvFP4 also supports MxE2m1 (vector size 32). This check should be expanded to allow MxE2m1, and the subsequent call to invokeFP4Quantization should dispatch to the correct template instantiation (16 or 32) based on the actual type of mGemm2.mDtypeAct.
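
A minimal sketch of the suggested fix, assuming a btg::Dtype::MxE2m1 enumerator and that the second template parameter of invokeFP4Quantization is the scaling-vector size (both taken from this review comment, not verified against the PR):

  // Hedged sketch, not the PR's code: accept both NVFP4 (E2m1, vector size 16)
  // and MXFP4 (MxE2m1, vector size 32), then dispatch the quantization template
  // on the matching vector size. The __nv_bfloat16 element type mirrors the
  // invokeRowWiseAmax call quoted later in this review and is an assumption here.
  bool const isNvFp4 = mGemm2.mDtypeAct == btg::Dtype::E2m1;
  bool const isMxFp4 = mGemm2.mDtypeAct == btg::Dtype::MxE2m1;
  FLASHINFER_CHECK(isNvFp4 || isMxFp4,
                   "Explicit quantization supports only E2m1 or MxE2m1 activations.");
  if (isNvFp4) {
    invokeFP4Quantization<__nv_bfloat16, 16>(/* ... unchanged arguments ... */);
  } else {
    invokeFP4Quantization<__nv_bfloat16, 32>(/* ... unchanged arguments ... */);
  }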

Comment on lines +239 to +265
template <typename T, uint32_t BLOCK_SIZE>
__global__ void rowWiseAmaxKernel(uint32_t m, uint32_t n, T const* input, float* amaxOutput,
                                  float scale) {
  uint32_t rowIdx = blockIdx.x;
  if (rowIdx >= m) return;

  float localMax = 0.f;
  for (uint32_t colIdx = threadIdx.x; colIdx < n; colIdx += blockDim.x) {
    T element = input[rowIdx * n + colIdx];
    localMax = fmaxf(localMax, fabsf(static_cast<float>(element) * scale));
  }

  using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
  __shared__ typename BlockReduce::TempStorage tempStorage;
  float blockMax = BlockReduce(tempStorage)
                       .Reduce(localMax,
#if CUDART_VERSION >= 12090
                               cuda::maximum<>{}
#else
                               cub::Max()
#endif
                       );

  if (threadIdx.x == 0) {
    amaxOutput[rowIdx] = blockMax;
  }
}

Severity: medium

There are two issues in rowWiseAmaxKernel:

  1. Efficiency: Multiplying by scale inside the loop performs n multiplications per thread. It is more efficient to compute the maximum of absolute values first and multiply by scale once at the end when writing to global memory.
  2. Correctness (Overflow): The indexing rowIdx * n + colIdx uses uint32_t. If the product exceeds $2^{32}$ (possible with large token counts and intermediate sizes), it will overflow. Using static_cast<size_t>(rowIdx) * n prevents this.
template <typename T, uint32_t BLOCK_SIZE>
__global__ void rowWiseAmaxKernel(uint32_t m, uint32_t n, T const* input, float* amaxOutput, float scale) {
  uint32_t rowIdx = blockIdx.x;
  if (rowIdx >= m) return;

  float localMax = 0.f;
  for (uint32_t colIdx = threadIdx.x; colIdx < n; colIdx += blockDim.x) {
    T element = input[static_cast<size_t>(rowIdx) * n + colIdx];
    localMax = fmaxf(localMax, fabsf(static_cast<float>(element)));
  }

  using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
  __shared__ typename BlockReduce::TempStorage tempStorage;
  float blockMax = BlockReduce(tempStorage)
                       .Reduce(localMax,
#if CUDART_VERSION >= 12090
                               cuda::std::maximum<>{}
#else
                               cub::Max()
#endif
                       );

  if (threadIdx.x == 0) {
    amaxOutput[rowIdx] = blockMax * scale;
  }
}

Comment thread on csrc/trtllm_fused_moe_runner.cu (outdated)
workspace.token_scales_fc2 != nullptr,
"workspace.token_scales_fc2 must be provided When using explicit quantization.");
const int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
int intermediate_size_factor = isGatedActivation(args.activation_type) ? 2 : 1;

Severity: medium

The variable intermediate_size_factor is calculated but never used in the subsequent logic. It should be removed to keep the code clean.

Comment on lines +762 to +763
auto sfLayout = mGemm2.mTileTokensDim >= 128 ? QuantizationSFLayout::SWIZZLED_128x4
: QuantizationSFLayout::SWIZZLED_8x4;

Severity: medium

The selection of sfLayout based on a hardcoded threshold (mTileTokensDim >= 128) is a heuristic. As noted in the FIXME, this should ideally be determined from the actual kernel configuration to ensure compatibility with the GEMM2 kernel's expected layout.

Comment thread on csrc/trtllm_fused_moe_runner.cu (outdated)
invokeRowWiseAmax<__nv_bfloat16>(workspace.total_max_padded_tokens, args.intermediate_size,
reinterpret_cast<__nv_bfloat16*>(workspace.gemm1_output),
reinterpret_cast<float*>(workspace.token_scales_fc2),
1.f / 448.f / 6.f, stream);

Severity: medium

The scale factor 1.f / 448.f / 6.f is a magic number. It should be defined as a named constant (e.g., based on FP8 and FP4 max values) to improve code clarity and maintainability.
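
A minimal sketch of such a constant, assuming the factor decomposes as (FP8 E4M3 max) × (FP4 E2M1 max); the constant names are illustrative, not from the PR:

  // Hedged sketch: name the components of the combined inverse global scale.
  constexpr float kE4m3MaxVal = 448.f;  // max finite value of FP8 E4M3 (block-scale dtype)
  constexpr float kE2m1MaxVal = 6.f;    // max finite value of FP4 E2M1 (element dtype)
  constexpr float kPerTokenGlobalScaleInv = 1.f / (kE4m3MaxVal * kE2m1MaxVal);

  invokeRowWiseAmax<__nv_bfloat16>(workspace.total_max_padded_tokens, args.intermediate_size,
                                   reinterpret_cast<__nv_bfloat16*>(workspace.gemm1_output),
                                   reinterpret_cast<float*>(workspace.token_scales_fc2),
                                   kPerTokenGlobalScaleInv, stream);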

IwakuraRein marked this pull request as ready for review on April 15, 2026 at 21:24.
@aleozlx (Collaborator) commented Apr 24, 2026

there seems to be some relevant bot run errors

@IwakuraRein (Collaborator, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !555 has been updated with latest changes, and the CI pipeline #49357090 is currently running. I'll report back once the pipeline job completes.

@IwakuraRein (Collaborator, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !555 has been updated with latest changes, and the CI pipeline #49427266 is currently running. I'll report back once the pipeline job completes.

@zianglih (Contributor)

Further added a TE-style reference implementation and TE_EXACT_NVFP4. The ["random", "boundary", "zeros", "maxes"] cases in test_nvfp4_per_token_quantize_te_reference are now bitwise exact with the TE reference implementation.

@IwakuraRein (Collaborator, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !555 has been updated with latest changes, and the CI pipeline #49669938 is currently running. I'll report back once the pipeline job completes.

@IwakuraRein (Collaborator, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !555 has been updated with latest changes, and the CI pipeline #49770024 is currently running. I'll report back once the pipeline job completes.

@aleozlx (Collaborator) commented Apr 30, 2026

tests look good so far

restarting CI for merging

aleozlx enabled auto-merge (squash) on April 30, 2026 at 21:00
aleozlx merged commit 537a3b5 into flashinfer-ai:main on May 1, 2026
30 checks passed
aleozlx pushed a commit that referenced this pull request May 8, 2026

## 📌 Description

Optimize the performance of the per-token nvfp4 quantization kernel
introduced by #3027.

1. Default the block size to 128.
2. Default to the fast-math path; rename `TE_EXACT_FP4` to
`TRTLLM_DISABLE_FP4_QUANT_FAST_MATH`, now controlled by an environment
variable.
3. Change the argument lists of `get_sf_out_offset_128x4` and
`get_sf_out_offset_8x4`.

TODOs:
1. Optimize low-latency cases.

## 🔍 Related Issues


## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes




## Summary by CodeRabbit

* **New Features**
* Added environment variable configuration to disable fast-math
optimization in FP4 quantization, enabling behavior alignment with
alternative implementations.

* **Tests**
* Added test fixture to validate FP4 quantization functionality with
fast-math mode disabled.


---------

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Co-authored-by: Ziang Li <ziangli@umich.edu>
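
For illustration, here is a minimal host-side sketch of reading the environment-variable gate named in the commit above; only the variable name comes from the commit message, and the wiring is an assumption:

  #include <cstdlib>

  // Hedged sketch: treat any non-empty value other than "0" as "disable the
  // FP4 quantization fast-math path".
  static bool fp4QuantFastMathDisabled() {
    char const* env = std::getenv("TRTLLM_DISABLE_FP4_QUANT_FAST_MATH");
    return env != nullptr && env[0] != '\0' && env[0] != '0';
  }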