
Feature: Support Relu2 activation in fused MoE#1954

Merged
yzh119 merged 4 commits into flashinfer-ai:main from amirkl94:feat/relu2-act
Oct 21, 2025

Conversation

amirkl94 (Contributor) commented Oct 20, 2025

📌 Description

Added support for the Relu2 activation in the cutlass fp8 FusedMoE path.
Relu2(x) = Relu(x)^2.

Validated this works correctly on H100 and B200.
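
For reference, a minimal PyTorch sketch of the activation itself (just the elementwise definition above, not the fused CUTLASS kernel):

import torch
import torch.nn.functional as F

def relu2_reference(x: torch.Tensor) -> torch.Tensor:
    # Relu2(x) = Relu(x)^2, i.e. squared ReLU.
    return F.relu(x).square()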

Summary by CodeRabbit

  • New Features
  • Added Relu2 as a selectable activation across MoE operations and exposed activation_type configuration to public MoE APIs and runner interfaces (Swiglu remains the default).
  • Behavior
    • Certain GEMM execution paths now explicitly reject Relu2 and raise a clear runtime error instead of silently failing.

Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @amirkl94, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new Relu2 activation function to the FP8 fused Mixture of Experts (MoE) implementation, leveraging the CUTLASS backend. This enhancement provides greater flexibility in model design by expanding the available activation options. The changes involve both the low-level C++ kernel definitions and the higher-level Python bindings, ensuring seamless integration and proper handling of the new activation type, including adaptive weight shape checks. The functionality has been verified on recent NVIDIA GPU hardware.

Highlights

  • New Activation Function: Added support for the Relu2 activation function, defined as Relu(x)^2, within the CUTLASS FP8 Fused Mixture of Experts (MoE) path.
  • Kernel Integration: Integrated the Relu2 activation kernel into the doActivation function in cutlass_fused_moe_kernels.cuh and defined its structure in fused_activations.h.
  • Dynamic Weight Shape Validation: Modified the FusedMoeRunner to dynamically adjust the fc1_expert_weights shape validation based on whether the chosen activation type is gated (like Swiglu) or non-gated (like Relu2).
  • API Extension: Extended the C++ FusedMoeRunner and Python cutlass_fused_moe functions to accept an ActivationType parameter, allowing users to specify Relu2 (see the usage sketch after this list).
  • Validation: The new Relu2 activation has been validated to work correctly on H100 and B200 GPU architectures.
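
A minimal usage sketch of the extended Python API, assuming cutlass_fused_moe and ActivationType are importable from flashinfer.fused_moe (all other arguments are the usual MoE inputs and are not spelled out here):

from flashinfer.fused_moe import cutlass_fused_moe, ActivationType  # import path assumed

def run_relu2_moe(moe_kwargs: dict):
    # moe_kwargs carries the usual cutlass_fused_moe arguments (hidden states,
    # expert weights, quantization scales, ...); only the new keyword is shown.
    return cutlass_fused_moe(
        **moe_kwargs,
        activation_type=ActivationType.Relu2,  # default is ActivationType.Swiglu
    )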

coderabbitai (Bot) commented Oct 20, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a Relu2 activation implementation, threads an activation_type parameter through Python and C++ bindings into kernel selection, registers Relu2 in the kernel dispatch, and also adds an explicit runtime error path declaring Relu2 unsupported in the MoeGemm runner dispatch.

Changes

Cohort / File(s) and Summary

  • Relu2 implementation (csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h, csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h): Adds cutlass::epilogue::thread::Relu2 (squared ReLU) and adds the Relu2 enumerator to ActivationType.
  • Kernel dispatch registration (csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh): Adds IdentityAdaptor<cutlass::epilogue::thread::Relu2> to the activation kernel dispatch list.
  • Bindings: activation_type parameter and propagation (csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu): Adds an ActivationType base_activation_type parameter (default Swiglu) to public bindings and threads activation_type into validation, ActivationParams, workspace/profiling, and runtime paths.
  • Python API and wiring (flashinfer/fused_moe/core.py): Adds an ActivationType IntEnum, accepts/stores activation_type in MoERunner, and propagates it through cutlass_fused_moe and internal wrappers (default Swiglu).
  • MoeGemm dispatch explicit rejection (csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h): Adds an explicit case ActivationType::Relu2 that raises a runtime error ("Relu2 is not supported.") in moeGemmBiasAct.

Sequence Diagram

sequenceDiagram
    participant Py as Python (flashinfer/fused_moe/core.py)
    participant Bind as C++ Binding (flashinfer_cutlass_fused_moe_sm100_binding.cu)
    participant Kernel as Kernel Dispatch (cutlass_fused_moe_kernels.cuh)
    participant Act as Relu2 (fused_activations.h)
    participant MoeGemm as MoeGemm Dispatch (moe_gemm_template_dispatch.h)

    Py->>Bind: run_gemm_profile(..., activation_type)
    Bind->>Bind: validate weights & build ActivationParams using activation_type
    Bind->>Kernel: request kernel variant for activation_type
    alt activation_type == Relu2
        Kernel->>Act: use Relu2 adaptor (compute ReLU(value)²)
        Act->>Kernel: return activation result
        Kernel->>Bind: kernel completes
        Bind->>Py: return success
    else other activations
        Kernel->>Bind: selected other activation path
        Bind->>Py: return success
    end
    note right of MoeGemm: Separate check in MoeGemm runner
    Bind->>MoeGemm: invoke moeGemmBiasAct(..., activation_type)
    alt MoeGemm receives Relu2
        MoeGemm-->>Bind: throw runtime error "Relu2 is not supported."
        Bind-->>Py: propagate error
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Poem

🐰 I hopped through bindings, C++ and C,
Found a Relu that squares with glee.
Registered, wired, then loudly cried,
"If you call me where I mustn't stride!"
Still I twitch my nose — new code, new tree.

Pre-merge checks

❌ Failed checks (1 warning)
  • Description Check: ⚠️ Warning. The PR description is significantly incomplete compared to the provided template. While the author provides an adequate Description section explaining the change (Relu2 activation support with definition and validation), the description entirely omits the Related Issues section and fails to include any of the PR Checklist components (Pre-commit Checks, Tests, and Reviewer Notes). The template clearly specifies these sections as required, and their complete absence indicates the author did not follow the repository's documentation standards, even though the substantive content that is present is relevant and clear. Resolution: Add the missing sections to the PR description: include a "Related Issues" section linking any relevant GitHub issues, fill out the Pre-commit Checks checklist to confirm pre-commit hooks were run, add a Tests section confirming tests have been added or updated and are passing, and optionally include any Reviewer Notes. Following the template structure ensures consistency and helps reviewers understand the full context of the changes.
✅ Passed checks (2 passed)
  • Title Check: ✅ Passed. The title "Feature: Support Relu2 activation in fused MoE" directly and accurately describes the main change in the PR. The changeset consistently adds Relu2 activation support across the CUTLASS backend, Python bindings, and core API of the fused MoE implementation. The title is specific, clear, and concise without unnecessary noise, making it easy for teammates scanning history to understand the primary change.
  • Docstring Coverage: ✅ Passed. No functions found in the changes. Docstring coverage check skipped.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5c7da20 and f503837.

📒 Files selected for processing (1)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (2)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h (1)
  • ActivationType (22-34)
flashinfer/fused_moe/core.py (1)
  • ActivationType (77-86)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h (1)

930-932: Remove unreachable break and clarify Relu2 support scope.

Relu2 activation is intentionally unsupported in MoeGemmRunner::moeGemmBiasAct, but it IS supported in the fused MoE path via cutlass_fused_moe_kernels.cuh (which uses IdentityAdaptor<cutlass::epilogue::thread::Relu2>). This is by design—not all activation types are available in all code paths.

Improvements needed:

  1. Remove unreachable code: The break statement after TLLM_THROW is unreachable (assuming TLLM_THROW terminates execution).

  2. Provide clarity in error message: Help users understand the limitation:

 case ActivationType::Relu2:
-  TLLM_THROW("Relu2 is not supported.");
-  break;
+  TLLM_THROW("Relu2 activation is not supported in MoeGemmRunner. Use the fused MoE path (cutlass_fused_moe_kernels) for Relu2 support.");


gemini-code-assist (Bot) left a comment

Code Review

This pull request adds support for the Relu2 activation function in the fused MoE kernels. The changes are well-structured, touching both the C++ backend and the Python frontend to plumb through the new activation type. My review identified a critical compilation error due to a typo, some dead code that should be removed, and a couple of minor style/maintainability issues. Once these are addressed, the PR should be in good shape.

Comment thread csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu Outdated
Comment on lines +298 to +300
TVM_FFI_ICHECK_EQ(fc1_expert_weights->shape[1],
fc2_expert_weights->shape[2] * mInnerDimMultiplier)
<< "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.";

Severity: medium

The indentation for the TVM_FFI_ICHECK_EQ macro in the else block is inconsistent with the if block. For better readability and to maintain a consistent code style, it should be aligned with the check in the if block.

      TVM_FFI_ICHECK_EQ(fc1_expert_weights->shape[1],
                        fc2_expert_weights->shape[2] * mInnerDimMultiplier)
          << "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.";

Comment on lines +470 to +472
TVM_FFI_ICHECK_EQ(fc1_expert_weights->shape[1],
fc2_expert_weights->shape[2] * mInnerDimMultiplier)
<< "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.";

Severity: medium

The indentation for the TVM_FFI_ICHECK_EQ macro in the else block is inconsistent with the if block. For better readability and to maintain a consistent code style, it should be aligned with the check in the if block.

      TVM_FFI_ICHECK_EQ(fc1_expert_weights->shape[1],
                        fc2_expert_weights->shape[2] * mInnerDimMultiplier)
          << "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.";

Comment thread flashinfer/fused_moe/core.py Outdated

coderabbitai (Bot) left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (1)

2375-2383: Fix clang-format violations before merge.

GitHub Actions flagged this file for trailing whitespace/formatting. Please run clang-format (or the project’s formatting hook) over the file so CI can pass.

csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h (1)

96-103: Resolve the clang-format findings.

CI flagged this header for trailing whitespace/formatting. Please run the project’s formatting step (clang-format/pre-commit) so the pipeline passes.

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (2)

589-599: Typo breaks build: AcitvationType → ActivationType in runGemmProfile.

The parameter type name is misspelled and will not compile.

-                      int64_t gemm_idx, int64_t profile_id, bool do_preparation, bool enable_pdl, AcitvationType activation_type) {
+                      int64_t gemm_idx, int64_t profile_id, bool do_preparation, bool enable_pdl, ActivationType activation_type) {

492-497: Missing stream insertion in TVM_FFI_ICHECK_EQ message.

The message for swiglu_beta lacks <<, causing a compile error.

-      TVM_FFI_ICHECK_EQ(swiglu_beta.value()->shape[0], num_experts_on_rank)
-      "swiglu_beta must have num_experts_on_rank elements.";
+      TVM_FFI_ICHECK_EQ(swiglu_beta.value()->shape[0], num_experts_on_rank)
+          << "swiglu_beta must have num_experts_on_rank elements.";
🧹 Nitpick comments (2)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h (1)

64-79: Drop the unused maximum instance.

maximum<T, PropagateNaN> mx; is never referenced; keeping it will trigger unused-variable warnings with stricter builds. You can remove it and rely on ReLu<T> directly.

flashinfer/fused_moe/core.py (1)

75-79: Review comment verified; refactoring to complete enum mirroring is recommended.

The verification confirms that Swiglu=3 and Relu2=6 correctly match the C++ header. However, the Python enum at flashinfer/fused_moe/core.py (lines 76-78) is incomplete—it defines only 2 of 9 members present in the C++ definition. The C++ header itself contains a comment on line 21 stating "Note update flashinfer/fused_moe/core.py to match", flagging this as a known sync issue.

The Python enum is missing: Gelu (0), Relu (1), Silu (2), Geglu (4), SwigluBias (5), Identity (7), and InvalidType (8). The suggestion to mirror all C++ values or add explicit comments is sound for reducing drift and improving maintainability.
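
As a sketch, a fully mirrored Python enum built from the member names and values listed above would look like this (values follow the C++ ordering described in common.h):

from enum import IntEnum

class ActivationType(IntEnum):
    # Mirrors csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h;
    # keep both definitions in sync when either side changes.
    Gelu = 0
    Relu = 1
    Silu = 2
    Swiglu = 3
    Geglu = 4
    SwigluBias = 5
    Relu2 = 6
    Identity = 7
    InvalidType = 8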

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d84e1d5 and bb96c22.

📒 Files selected for processing (6)
  • csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (1 hunks)
  • csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (9 hunks)
  • csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h (1 hunks)
  • csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h (1 hunks)
  • csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h (1 hunks)
  • flashinfer/fused_moe/core.py (10 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h (1)
flashinfer/fused_moe/core.py (1)
  • ActivationType (76-78)
flashinfer/fused_moe/core.py (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h (1)
  • ActivationType (22-34)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • isGatedActivation (253-256)
🪛 GitHub Actions: pre-commit
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h

[error] 1-1: Trailing whitespace detected. Fixes applied by pre-commit.


[error] 1-1: clang-format formatting failed. Run 'clang-format' to fix code style issues in this file.

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

[error] 1-1: Trailing whitespace detected. Fixes applied by pre-commit.


[error] 1-1: clang-format formatting failed. Run 'clang-format' to fix code style issues in this file.

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

[error] 1-1: Trailing whitespace detected. Fixes applied by pre-commit.


[error] 1-1: clang-format formatting failed. Run 'clang-format' to fix code style issues in this file.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (4)
flashinfer/fused_moe/core.py (2)

312-339: Good threading of activation_type through runner and profile.

Constructor storage and forward call propagation look correct. No cache-key change needed since activation only affects call-time behavior.

Also applies to: 392-393


645-652: API default and docs for Relu2.

Defaulting to Swiglu and documenting “Relu2 is non‑gated GEMM1” is clear. Propagation in the wrapper call is correct.

Also applies to: 749-752, 832-833

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (2)

640-649: Profiling path: passing activation_type into mProfiler->init looks correct.

Once the type name typo is fixed, this should wire through cleanly.

Please confirm GemmProfilerBackend::init expects ActivationType at that position in both OSS and non‑OSS builds.


1-10: Commit clang-format changes to the file.

Clang-format found and corrected formatting issues in this file. The changes have been applied and need to be committed.

Comment on lines +441 to 442
activation_type: ActivationType = ActivationType.Swiglu,
) -> List[torch.Tensor]:

⚠️ Potential issue | 🔴 Critical

Keep fake op signature in sync (adds activation_type).

Public API added activation_type with default. The registered fake op flashinfer::cutlass_fused_moe still lacks this arg and will error under meta/fake dispatch. Add it to the fake op signature.

Apply this diff in the fake op to maintain parity:

 @register_fake_op("flashinfer::cutlass_fused_moe")
 def _fake_cutlass_fused_moe(
     output: torch.Tensor,
     input: torch.Tensor,
     token_selected_experts: torch.Tensor,
     token_final_scales: torch.Tensor,
     fc1_expert_weights: torch.Tensor,
     fc1_expert_biases: Optional[torch.Tensor],
     fc2_expert_weights: torch.Tensor,
     fc2_expert_biases: Optional[torch.Tensor],
     output_dtype: torch.dtype,
     quant_scales: List[torch.Tensor],
     input_sf: Optional[torch.Tensor] = None,
     swiglu_alpha: Optional[torch.Tensor] = None,
     swiglu_beta: Optional[torch.Tensor] = None,
     swiglu_limit: Optional[torch.Tensor] = None,
     tp_size: int = 1,
     tp_rank: int = 0,
     ep_size: int = 1,
     ep_rank: int = 0,
     cluster_size: int = 1,
     cluster_rank: int = 0,
     enable_alltoall: bool = False,
     use_deepseek_fp8_block_scale: bool = False,
     use_w4_group_scaling: bool = False,
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
     tune_max_num_tokens: int = 8192,
     enable_pdl: Optional[bool] = None,
+    activation_type: ActivationType = ActivationType.Swiglu,
 ):

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In flashinfer/fused_moe/core.py around lines 441-442, the public API added the
activation_type: ActivationType = ActivationType.Swiglu parameter but the
registered fake op flashinfer::cutlass_fused_moe signature was not updated;
update the fake op registration/signature to include the activation_type
parameter with the same name and default (ActivationType.Swiglu) so the
meta/fake dispatcher sees a matching signature, and ensure any internal
registration/schema and wrapper call sites use the new parameter name and
default.

Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>

coderabbitai (Bot) left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)

495-499: Compile error: missing stream insertion in TVM_FFI_ICHECK_EQ.

The swiglu_beta check lacks the << before the message literal.

Apply this diff:

-      TVM_FFI_ICHECK_EQ(swiglu_beta.value()->shape[0], num_experts_on_rank)
-      "swiglu_beta must have num_experts_on_rank elements.";
+      TVM_FFI_ICHECK_EQ(swiglu_beta.value()->shape[0], num_experts_on_rank)
+          << "swiglu_beta must have num_experts_on_rank elements.";
flashinfer/fused_moe/core.py (1)

561-571: Return contract inverted for min_latency_mode: confirmed bug requires immediate fix.

The code allocates and passes three extra tensors to the kernel ONLY when min_latency_mode=True (line 524-530), yet returns them when False. Lines 561-570 return a single tensor when True (missing the three allocated outputs) and attempts to return four tensors when False (where only one was allocated).

The fake op correctly returns 4 items when min_latency_mode=True and 1 item when False, confirming the real op's logic is inverted. While currently latent since the public wrapper raises NotImplementedError for min_latency_mode, this must be fixed before that mode is implemented.

Apply the provided diffs to lines 561-570 (main function) and 605-617 (fake op) to align return behavior with allocation semantics.
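
A minimal sketch of the corrected return contract, with a placeholder for the three latency-mode tensors (their real names come from the allocation block at lines 524-530, which is outside this excerpt):

def _moe_outputs(output, min_latency_mode, latency_outs):
    # latency_outs: the three extra tensors that are allocated only when
    # min_latency_mode is True (placeholder name for this sketch).
    if min_latency_mode:
        return [output, *latency_outs]  # 4 items, matching the fake op's contract
    return [output]                     # 1 item otherwise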

♻️ Duplicate comments (3)
flashinfer/fused_moe/core.py (2)

76-87: Keep ActivationType in Python in lockstep with C++ source of truth.

Enum values currently match common.h, but this local copy can drift. Consider auto‑generating from the C++ header at build time or adding a cheap import‑time check that compares numeric values returned from a tiny C++ helper. At minimum, document the sync contract next to this enum.
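
One possible shape for that import-time check, as a hedged sketch (the activation_type_values helper on the compiled module is hypothetical; the real binding would need to expose the C++ values somehow):

def _check_activation_enum_sync(module) -> None:
    # Compare the Python-side ActivationType members against the values reported by
    # the compiled module; fail fast instead of silently dispatching the wrong kernel.
    cpp_values = dict(module.activation_type_values())  # hypothetical helper
    for member in ActivationType:
        if cpp_values.get(member.name) != member.value:
            raise RuntimeError(
                f"ActivationType.{member.name} out of sync with C++ "
                f"(python={member.value}, cpp={cpp_values.get(member.name)})"
            )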


572-601: Fake op signature missing activation_type (breaks meta/fake dispatch).

Add the new parameter with same name/default as the real op. Otherwise, fake/meta kernels will error on kwargs mismatch.

Apply this diff:

 @register_fake_op("flashinfer::cutlass_fused_moe")
 def _fake_cutlass_fused_moe(
     output: torch.Tensor,
     input: torch.Tensor,
     token_selected_experts: torch.Tensor,
     token_final_scales: torch.Tensor,
     fc1_expert_weights: torch.Tensor,
     fc1_expert_biases: Optional[torch.Tensor],
     fc2_expert_weights: torch.Tensor,
     fc2_expert_biases: Optional[torch.Tensor],
     output_dtype: torch.dtype,
     quant_scales: List[torch.Tensor],
     input_sf: Optional[torch.Tensor] = None,
     swiglu_alpha: Optional[torch.Tensor] = None,
     swiglu_beta: Optional[torch.Tensor] = None,
     swiglu_limit: Optional[torch.Tensor] = None,
     tp_size: int = 1,
     tp_rank: int = 0,
     ep_size: int = 1,
     ep_rank: int = 0,
     cluster_size: int = 1,
     cluster_rank: int = 0,
     enable_alltoall: bool = False,
     use_deepseek_fp8_block_scale: bool = False,
     use_w4_group_scaling: bool = False,
     use_mxfp8_act_scaling: bool = False,
     min_latency_mode: bool = False,
     tune_max_num_tokens: int = 8192,
     enable_pdl: Optional[bool] = None,
+    activation_type: ActivationType = ActivationType.Swiglu,
 ):
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)

365-367: Quant-scale shape checks still assume gated (×2) for FC1; plumb activation to getQuantParams.

Shape validations for FC1 block scales hardcode inter_size * 2. That rejects valid Relu2 (non‑gated) layouts and mis-sizes workspaces. Pass ActivationType into getQuantParams and gate the multiplier.

Apply these diffs:

  1. Thread activation into calls:
-    auto const quant_params =
-        getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
+    auto const quant_params =
+        getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales, base_activation_type);
-    auto const quant_params =
-        getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
+    auto const quant_params =
+        getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales, base_activation_type);
  2. Update signature and gate multiplier:
-  kernels::QuantParams getQuantParams(int64_t num_experts_on_rank, int64_t hidden_size,
-                                      int64_t inter_size,
-                                      Optional<Array<Tensor>> quant_scales) const {
+  kernels::QuantParams getQuantParams(int64_t num_experts_on_rank, int64_t hidden_size,
+                                      int64_t inter_size, Optional<Array<Tensor>> quant_scales,
+                                      ActivationType base_activation_type) const {
+    bool const is_gated = isGatedActivation(base_activation_type);
  3. Fix FC1 shape checks (examples shown for each block):

W4A8_MXFP4_FP8:

-      TVM_FFI_ICHECK(
+      TVM_FFI_ICHECK(
           fc1_weight_block->shape[0] == num_experts_on_rank &&
           fc1_weight_block->shape[1] ==
               TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
-                  inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) *
-                  2 &&
+                  inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) *
+                  (is_gated ? 2 : 1) &&
           ...
-          << "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 "
+          << "fc1 weight block size must be (num_experts_on_rank, inter_size * (is_gated?2:1), hidden_size // 4 "
              "// block_scale_vector_size)";

W4A8_MXFP4_MXFP8:

-      TVM_FFI_ICHECK(
+      TVM_FFI_ICHECK(
           fc1_weight_block->shape[0] == num_experts_on_rank &&
           fc1_weight_block->shape[1] ==
               TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
-                  inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) *
-                  2 &&
+                  inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) *
+                  (is_gated ? 2 : 1) &&
           ...
-          << "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 "
+          << "fc1 weight block size must be (num_experts_on_rank, inter_size * (is_gated?2:1), hidden_size // 4 "
              "// block_scale_vector_size)";

NVFP4:

-      TVM_FFI_ICHECK(
+      TVM_FFI_ICHECK(
           fc1_weight_block->shape[0] == num_experts_on_rank &&
           fc1_weight_block->shape[1] ==
               TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
-                  inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4) *
-                  2 &&
+                  inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4) *
+                  (is_gated ? 2 : 1) &&
           ...
-          << "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 "
+          << "fc1 weight block size must be (num_experts_on_rank, inter_size * (is_gated?2:1), hidden_size // 4 "
              "// block_scale_vector_size)";

Also applies to: 542-544, 807-813, 871-882, 930-941, 1001-1011

🧹 Nitpick comments (1)
flashinfer/fused_moe/core.py (1)

655-656: Docstring/type hint don’t match actual return shape.

Signature and docs promise torch.Tensor, but the op returns a list when min_latency_mode flips. After fixing the inversion, either:

  • document Union[torch.Tensor, List[torch.Tensor]], or
  • normalize to always return a torch.Tensor from this wrapper.

Pick one and update annotations/docs accordingly.

Also applies to: 759-761

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bb96c22 and 5c7da20.

📒 Files selected for processing (4)
  • csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (1 hunks)
  • csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (9 hunks)
  • csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h (1 hunks)
  • flashinfer/fused_moe/core.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h
  • csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/fused_moe/core.py (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h (1)
  • ActivationType (22-34)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • isGatedActivation (253-256)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (3)
flashinfer/fused_moe/core.py (1)

321-348: Activation type threading looks correct.

End‑to‑end propagation into profiling (run_gemm_profile) and runtime (run_moe) is clean and defaults preserve backward compat.

Also applies to: 401-402, 450-451, 475-476, 841-842

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (2)

294-302: Inter-size checks respect gated vs non‑gated activations.

Good use of isGatedActivation(...) to validate FC1 inter size as 2× for gated and 1× for non‑gated (e.g., Relu2).

Also applies to: 467-475
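
For intuition, a tiny sketch of the inter-size rule being validated here (illustrative only; the real check lives in the C++ binding via isGatedActivation):

def fc1_inter_dim(inter_size: int, gated: bool) -> int:
    # Gated activations (e.g. Swiglu) fuse gate and up projections, so the FC1 weight
    # carries 2 * inter_size output rows; non-gated ones (e.g. Relu2) carry inter_size.
    return inter_size * (2 if gated else 1)

print(fc1_inter_dim(14336, gated=True))   # 28672 rows for Swiglu
print(fc1_inter_dim(14336, gated=False))  # 14336 rows for Relu2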


592-593: ActivationType wiring through FFI/profiler looks solid.

runGemmProfile and dispatchers now take the enum; init() passes it to profiler consistently.

If possible, run a quick smoke profile on both Swiglu and Relu2 to ensure distinct tactic selections are exercised.

Also applies to: 644-651, 680-686, 700-706, 719-727

yzh119 (Collaborator) commented Oct 20, 2025

/bot run

flashinfer-bot (Collaborator) commented:

GitLab MR !87 has been created, and the CI pipeline #36951423 is currently running. I'll report back once the pipeline job completes.

djns99 (Contributor) left a comment

This looks good to me. If you aren't planning to support the fallback Ampere-style path, could you add an explicit check? Otherwise you will just get this error, which will likely confuse end users.

Unspecified = 6


# Copied from csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h

Any chance we could add this to the python bindings instead of duplicating this manually?
Not sure if there is a good way to do this for enums, but it would be nicer if we could keep these in sync

amirkl94 (Author) replied:

I agree with this, but there's an issue I'm not sure how to solve:
For the bindings to be present, we need to load the compiled binaries, and this is done lazily from get_cutlass_fused_moe_module(). This means we won't be able to expose it from the Python package, and then this enum will be redundant.
Do you have a suggestion on how to solve this? If so, I'd try to apply it to RoutingMethodType as well, as I think it has the same issue.

Collaborator replied:

One option could be making these enum classes not just-in-time loaded; we can do that in a future PR.

Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
amirkl94 requested review from djns99 and yzh119 on October 21, 2025 06:33
flashinfer-bot (Collaborator) commented:

[FAILED] Pipeline #36951423: 1/17 passed

yzh119 merged commit 0c8f234 into flashinfer-ai:main Oct 21, 2025
4 checks passed
yzh119 mentioned this pull request Nov 5, 2025
nzmora-nvidia added a commit to NVIDIA/TensorRT-LLM that referenced this pull request Nov 14, 2025
Update TRTLLM Cutlass MoE kernels with ReLU2 activation.

Nemotron-6 requires ReLU2 (i.e. squared ReLU) MoE activation function.
The PR adds this and adds an API to set the activation function, in general.
The ReLU2 changes are based on this FlashInfer PR: flashinfer-ai/flashinfer#1954.

The PR also updates the Auto Deploy MoE backend for 16-bit and FP8 from
Triton (`torch.ops.auto_deploy.triton_moe_fused`, `torch.ops.auto_deploy.triton_quant_fp8_moe`) to TRTLLM/Cutlass (`torch.ops.auto_deploy.trtllm_moe_fused`, `torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused`).

Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Co-authored-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
zheyuf pushed a commit to zheyuf/TensorRT-LLM that referenced this pull request Nov 19, 2025
…DIA#9011)

greg-kwasniewski1 pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Nov 20, 2025
…DIA#9011)
