feat: Add support for TRTLLM MXFP8 non-gated MoE with ReLU2 #2707
aleozlx merged 9 commits into flashinfer-ai:main from
Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the FlashInfer library by integrating support for TRTLLM MXFP8 non-gated Mixture-of-Experts (MoE) layers, specifically tailored for models employing the ReLU2 activation function, such as Nemotron. The changes update core C++ kernels and their Python bindings to correctly manage weight dimensions and activation types, ensuring proper functionality and paving the way for advanced model optimizations and broader compatibility within the vLLM ecosystem.
Activity
Note: Reviews paused. This branch is under active development, so CodeRabbit has automatically paused this review to avoid overwhelming you with comments due to an influx of new commits.
📝 Walkthrough

Threads ActivationType through FP8/MoE launchers and public APIs, enforces gating-aware activation checks (DeepSeek FP8 limited to Swiglu), adds dynamic DeepSeekV3 top_k limits based on expert count, and expands routed FP8/MXFP8 tests and parity checks.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Py as Python API (flashinfer.fused_moe)
    participant Bind as C Bindings
    participant Launcher as Fp8BlockScaleLauncher / FusedMoeLauncher
    participant Kernel as CUDA Kernel Launcher
    participant GPU as Device

    Py->>Bind: trtllm_fp8_*_moe(..., activation_type)
    Bind->>Launcher: validateAndCastActivationType(act_type)
    Launcher->>Launcher: getValidConfigs(..., activation_type)
    Launcher->>Kernel: init(..., activation_type) / launch(configs, weights, inputs)
    Kernel->>GPU: run kernels
    GPU-->>Kernel: results
    Kernel-->>Bind: return outputs
    Bind-->>Py: outputs
```
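To make the diagram concrete, here is a minimal, self-contained C++ sketch of the binding-side step: the integer `activation_type` arriving from Python is validated and cast once, and the typed enum is then handed to config selection and launch. Only `validateAndCastActivationType`, `ActivationType`, and the activation names come from this PR; the numeric codes, the enum ordering, and the helper's exact signature are assumptions for illustration.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Illustrative enum; the real definition and ordering live in the TRT-LLM headers.
enum class ActivationType : int64_t { Swiglu = 0, Geglu = 1, SwigluBias = 2, Relu2 = 3, Gelu = 4 };

// Sketch of the binding-side validation: reject unknown codes early so every later
// layer (getValidConfigs, init, launch) sees the same strongly typed activation.
inline ActivationType validateAndCastActivationTypeSketch(int64_t act_type) {
  switch (act_type) {
    case 0: return ActivationType::Swiglu;
    case 1: return ActivationType::Geglu;
    case 2: return ActivationType::SwigluBias;
    case 3: return ActivationType::Relu2;
    case 4: return ActivationType::Gelu;
    default:
      throw std::invalid_argument("Unsupported activation_type code: " + std::to_string(act_type));
  }
}
```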
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed, ❌ 1 failed (1 warning)
Code Review
This pull request adds support for non-gated MoE with ReLU2 for TRTLLM MXFP8, which is a great feature enhancement. The changes are logical, and new tests provide good coverage for the added functionality. I've found a potential issue where the updated code might not correctly handle quantization modes with different dtypes for activations and weights (like DeepSeekFp8), as it assumes they are the same. I've provided detailed comments and suggestions to make the implementation more robust for all supported quantization modes.
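The point about mismatched dtypes can be illustrated with a small, purely hypothetical helper: instead of assuming activations and weights share one dtype, derive both from the quantization mode. None of the names below (`QuantScheme`, `Dtype`, `dtypesFor`) are FlashInfer APIs; they only sketch the shape of the fix the review is asking for.

```cpp
// Hypothetical sketch: a quantization scheme maps to a *pair* of dtypes, so modes
// whose activation and weight dtypes differ (e.g. FP4 weights with FP8 activations)
// are not silently funneled through a single shared dtype.
enum class QuantScheme { MxFp8, WMxfp4AFp8 };
enum class Dtype { E4m3, MxE4m3, MxE2m1 };

struct ActWeightDtypes {
  Dtype act;
  Dtype weights;
};

inline ActWeightDtypes dtypesFor(QuantScheme scheme) {
  switch (scheme) {
    case QuantScheme::MxFp8:
      return {Dtype::MxE4m3, Dtype::MxE4m3};  // same dtype on both sides
    case QuantScheme::WMxfp4AFp8:
      return {Dtype::E4m3, Dtype::MxE2m1};    // activation and weight dtypes differ
  }
  return {Dtype::E4m3, Dtype::E4m3};
}
```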
Force-pushed d9d8927 to 279d358
Force-pushed d102560 to 329ec04
Force-pushed 62ef99d to 0cd2233
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1138-1142: The DeepSeek FP8 branch currently constructs
tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner using a constructor that has
no ActivationType (and MoERunnerArgs likewise lacks ActivationType), so gated
activations like Geglu/SwigluBias incorrectly share the Swiglu path; fix by
either (A) restricting the DeepSeekFp8 conditional to only the activation(s) the
current constructor supports (e.g., Swiglu) or (B) add ActivationType to
MoERunnerArgs and use the activation-aware MoE::Runner constructor (and
propagate ActivationType through the call sites that construct MoE::Runner), and
apply the same change to the other analogous DeepSeekFp8 checks in the file.
In `@tests/moe/test_trtllm_gen_routed_fused_moe.py`:
- Line 396: The routed parity test sets use_shuffled_weight=True while
gemm1_weights and gemm2_weights are not passed through the shuffle helpers,
causing the routed kernel to misinterpret raw FP8 weight layout; change
use_shuffled_weight to False in this test (or alternatively apply the same
weight/scale shuffling used in the MXFP8 test) so the routed kernel and the
reference use the same weight layout for gemm1_weights/gemm2_weights.
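A minimal sketch of option (A) from the first comment above, i.e. failing fast instead of letting gated variants silently share the Swiglu path. `TVM_FFI_ICHECK_EQ` and `ActivationType::Swiglu` follow the pattern used elsewhere in this file; the surrounding variable names (`use_deepseek_fp8`, `activation_type`) are assumptions.

```cpp
// Option (A) sketch: keep the existing activation-unaware DeepSeekFp8 runner, but
// reject anything other than Swiglu until ActivationType is threaded through
// MoERunnerArgs and the runner constructor. Variable names are illustrative.
if (use_deepseek_fp8) {
  TVM_FFI_ICHECK_EQ(activation_type, ActivationType::Swiglu)
      << "DeepSeek FP8 MoE currently supports only ActivationType::Swiglu; "
      << "Geglu/SwigluBias and non-gated activations are not wired through this runner yet.";
  // ... construct the existing MoE::Runner (no ActivationType argument) as before ...
}
```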
📒 Files selected for processing (5)
- csrc/trtllm_fused_moe_kernel_launcher.cu
- flashinfer/fused_moe/core.py
- tests/moe/test_trtllm_gen_fused_moe.py
- tests/moe/test_trtllm_gen_routed_fused_moe.py
- tests/moe/utils.py
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
559-575: ⚠️ Potential issue | 🟠 Major

Don't make BF16 tactic discovery activation-aware before BF16 execution is.

`getValidConfigs()` now builds BF16 runners with the caller's `act_type`, but `Bf16MoeLauncher::init()` on Line 468 still hard-codes `ActivationType::Swiglu`, and `trtllm_bf16_moe()` still has no activation parameter. That means tactic lookup can return/cache configs for `Relu2`/`Gelu`/etc. that the BF16 runtime will never execute. Either thread `ActivationType` through the BF16 runtime or reject non-Swiglu here.

Suggested guard until the BF16 runtime is activation-aware:

```diff
 static Array<Array<int64_t>> getValidConfigs(int64_t top_k, int64_t hidden_size,
                                              int64_t intermediate_size, int64_t num_local_experts,
                                              int64_t num_tokens, int64_t act_type,
                                              bool use_shuffled_weight, int64_t weight_layout) {
   Array<Array<int64_t>> valid_configs;
+  auto activation_type = validateAndCastActivationType(act_type);
+  TVM_FFI_ICHECK_EQ(activation_type, ActivationType::Swiglu)
+      << "BF16 valid-config query only supports ActivationType::Swiglu.";
   std::vector<int32_t> supported_tile_nums(mSupportedTileNums.begin(), mSupportedTileNums.end());
   std::set<int32_t> selected_tile_nums =
       computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts);
   for (int32_t tile_N : selected_tile_nums) {
     auto moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
         btg::Dtype::Bfloat16,  // dtype_act
         btg::Dtype::Bfloat16,  // dtype_weights
         false,                 // useDeepSeekFp8
-        tile_N, static_cast<ActivationType>(act_type), use_shuffled_weight,
+        tile_N, activation_type, use_shuffled_weight,
         static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 559 - 575, getValidConfigs() constructs BF16 MoE runners using the caller's act_type which can lead to tactic entries for activations the BF16 runtime cannot run (Bf16MoeLauncher::init() currently hard-codes ActivationType::Swiglu and trtllm_bf16_moe() has no activation parameter); fix by making getValidConfigs() use the runtime-supported activation or reject mismatched activations: either always pass ActivationType::Swiglu when creating the tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner for BF16 paths, or add a guard that checks act_type == ActivationType::Swiglu and skip/return empty configs for other activations, and document/update Bf16MoeLauncher::init()/trtllm_bf16_moe() to thread activation through later when BF16 runtime becomes activation-aware.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 2215-2221: The per-tensor-FP8 branch returns configs for arbitrary
activations even though Fp8PerTensorLauncher still requires gated MOE paths (see
Fp8PerTensorLauncher::check_moe expecting output1_scales_gate_scalar and
prepare_moe allocating 2 * intermediate_size GEMM1); guard the branch that calls
Fp8PerTensorLauncher::getValidConfigs (the branch using act_type /
activation_type) by checking isGatedActivation(activation_type) so it only
returns gated-activation configs until the launcher is updated to honor
intermediate_size_factor and nongated activations.
---
Outside diff comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 559-575: getValidConfigs() constructs BF16 MoE runners using the
caller's act_type which can lead to tactic entries for activations the BF16
runtime cannot run (Bf16MoeLauncher::init() currently hard-codes
ActivationType::Swiglu and trtllm_bf16_moe() has no activation parameter); fix
by making getValidConfigs() use the runtime-supported activation or reject
mismatched activations: either always pass ActivationType::Swiglu when creating
the tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner for BF16 paths, or add a
guard that checks act_type == ActivationType::Swiglu and skip/return empty
configs for other activations, and document/update
Bf16MoeLauncher::init()/trtllm_bf16_moe() to thread activation through later
when BF16 runtime becomes activation-aware.
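For the inline comment above, a sketch of the proposed guard in the per-tensor-FP8 branch of the config query. `isGatedActivation` and `validateAndCastActivationType` are named by the review; the check macro usage and the surrounding branch condition are assumptions modeled on the BF16 snippet earlier in this thread.

```cpp
// Sketch: only report per-tensor-FP8 tactic configs for gated activations, since
// Fp8PerTensorLauncher still assumes a gated GEMM1 layout (2 * intermediate_size)
// and output1_scales_gate_scalar. Branch condition and macro usage are illustrative.
auto activation_type = validateAndCastActivationType(act_type);
if (is_per_tensor_fp8_branch) {
  TVM_FFI_ICHECK(isGatedActivation(activation_type))
      << "Per-tensor FP8 MoE config query currently supports gated activations only.";
  // ... fall through to Fp8PerTensorLauncher::getValidConfigs(...) as before ...
}
```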
📒 Files selected for processing (1)
csrc/trtllm_fused_moe_kernel_launcher.cu
/bot run

[SUCCESS] Pipeline #45677790: 10/20 passed
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 2216-2227: Add the same activation validation used in
trtllm_get_valid_moe_configs to the runtime entry
trtllm_fp8_per_tensor_scale_moe(): call validateAndCastActivationType on the
incoming activation_type, then check isGatedActivation(...) and if false raise
the same NotImplementedError message so nongated per-tensor FP8 paths are
rejected before Fp8PerTensorLauncher::check_moe() or prepare_moe() run; this
prevents the code in Fp8PerTensorLauncher that assumes gated outputs
(output1_scales_gate_scalar and 2 * intermediate_size GEMM1 buffers) from being
executed for unsupported activations.
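The same validation can be shared between the config query and the runtime entry so the two paths cannot drift apart. The helper below is hypothetical (its name and placement are mine), but it composes only calls the review already references.

```cpp
// Hypothetical shared helper: called at the top of both trtllm_get_valid_moe_configs()
// (per-tensor FP8 branch) and trtllm_fp8_per_tensor_scale_moe(), before check_moe() or
// prepare_moe() can touch gated-only buffers such as output1_scales_gate_scalar.
inline ActivationType requireGatedActivation(int64_t act_type, const char* where) {
  auto activation_type = validateAndCastActivationType(act_type);
  TVM_FFI_ICHECK(isGatedActivation(activation_type))
      << where << ": per-tensor FP8 MoE currently supports gated activations only "
      << "(Swiglu/Geglu/SwigluBias).";
  return activation_type;
}
```

Factoring the check out keeps the error message identical at both entry points, which is what the comment asks for.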
📒 Files selected for processing (1)
csrc/trtllm_fused_moe_kernel_launcher.cu
I fixed a few things; we can trigger CI again.

/bot run

[FAILED] Pipeline #45752288: 8/20 passed

Tests seem good. CI also passed, please merge. cc @yzh119
Fixes #2731.

## What's broken?

When using the CUTLASS fused MoE backend with **non-gated activations** (e.g., Relu2, Gelu, Silu) and MXFP8 quantization, the fc1 weight shape validation unconditionally rejects the input, even when the shape is correct.

## Who is affected?

Anyone using the **CUTLASS fused MoE** path with:

- **Quantization**: `WMxfp8AMxfp8`, `WMxfp4AFp8`, or `WMxfp4AMxfp8`
- **Activation**: any non-gated type (Relu2, Gelu, Silu, etc.)

Not affected: gated activations (Swiglu, Geglu, SwigluBias), or other quant modes (NVFP4 already handles this correctly).

## Where is the bug?

`csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_binding.cu`, inside `getQuantParams()`: the fc1 weight block N-dimension check hardcodes `* 2` at three MXFP8 branches (~L898, ~L1004, ~L1063).

## Why does it happen?

PR #2581 introduced MXFP8 support when only gated activations (Swiglu) existed, so `inter_size * 2` was correct. Later, non-gated activation support was added to the trtllm-gen backend (PR #2707), but the CUTLASS backend's validation was never updated. The NVFP4 path in the same file (line ~1131) already handles this correctly with an `if (isGatedActivation(...))` guard.

## How did we fix it?

For each of the 3 MXFP8 quant branches:

1. Extract `int const fc1_n_mult = isGatedActivation(base_activation_type) ? 2 : 1;`
2. Replace the hardcoded `* 2` with `* fc1_n_mult`
3. Update error messages: gated shows `"inter_size * 2"`, non-gated shows `"inter_size"`

**Before:**

```cpp
fc1_weight_block.size(1) == alignToSfDim(inter_size, ...) * 2
```

**After:**

```cpp
int const fc1_n_mult = isGatedActivation(base_activation_type) ? 2 : 1;
fc1_weight_block.size(1) == alignToSfDim(inter_size, ...) * fc1_n_mult
```

## How do we know it works?

- `pre-commit run` passes (clang-format, lint, etc.)
- Gated activations (default Swiglu): `fc1_n_mult = 2`, identical to old behavior, no regression
- Non-gated activations: `fc1_n_mult = 1`, the shape check now accepts the correct `inter_size` dimension
- Full GPU test suite requires CI (`@flashinfer-bot run`)

## Related

- Builds on the approach identified in #2753 (stale ~27 days, CI unresolved).
- Addresses the Gemini review feedback from #2753 by extracting the multiplier to a local variable before the validation checks.

cc @aleozlx @nv-yunzheq

## Summary by CodeRabbit

- **Bug Fixes**
  - Fixed weight block size validation for Mixture of Experts (MoE) to correctly handle both gated and non-gated activation types, ensuring proper support across different activation configurations.

Signed-off-by: Yiyang Liu <37043548+ianliuy@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
📌 Description
This PR adds support for TRTLLM MXFP8 non-gated MoE with ReLU2 (for Nemotron models).
A PR for TRTLLM MXFP8 gated MoE is open in vLLM:
vllm-project/vllm#35986
After this PR is merged and a new flashinfer version is released -
support for non-gated MoE will be added in vLLM.
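For readers less familiar with the terminology: ReLU2 here is the squared-ReLU activation used by Nemotron-style experts, and a non-gated expert needs only one GEMM1 projection of width `intermediate_size`, whereas a gated (SwiGLU-style) expert needs two, hence the `2 * intermediate_size` layouts discussed in the reviews. The sketch below is plain reference math, not the fused MXFP8 kernel.

```cpp
#include <algorithm>

// ReLU2 ("squared ReLU"): the non-gated activation this PR enables for MXFP8 MoE.
inline float relu2(float x) {
  float r = std::max(x, 0.0f);
  return r * r;
}

// Illustrative per-expert shapes (unquantized reference, not the kernel):
//   non-gated: h = relu2(x @ W1),              W1: [hidden_size, intermediate_size]
//   gated:     h = silu(x @ W1a) * (x @ W1b),  GEMM1 width = 2 * intermediate_size
//   output:    y = h @ W2,                     W2: [intermediate_size, hidden_size]
```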
New tests were added and all tests passed.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit

- New Features
- Enhancements
- Tests