Feature: Support Relu2 activation in fused MoE #1954
yzh119 merged 4 commits into flashinfer-ai:main from
Conversation
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
Walkthrough
Adds a Relu2 activation implementation, threads an activation_type parameter through the Python and C++ bindings into kernel selection, registers Relu2 in the kernel dispatch, and adds an explicit runtime error path declaring Relu2 unsupported in the MoeGemm runner dispatch.
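For reference, a minimal PyTorch sketch of the math the new activation computes; this is the reference formula only, not the fused CUTLASS epilogue:

import torch

def relu2_reference(x: torch.Tensor) -> torch.Tensor:
    # Relu2(x) = ReLU(x)^2, applied elementwise.
    r = torch.relu(x)
    return r * r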
Sequence Diagram
sequenceDiagram
participant Py as Python (flashinfer/fused_moe/core.py)
participant Bind as C++ Binding (flashinfer_cutlass_fused_moe_sm100_binding.cu)
participant Kernel as Kernel Dispatch (cutlass_fused_moe_kernels.cuh)
participant Act as Relu2 (fused_activations.h)
participant MoeGemm as MoeGemm Dispatch (moe_gemm_template_dispatch.h)
Py->>Bind: run_gemm_profile(..., activation_type)
Bind->>Bind: validate weights & build ActivationParams using activation_type
Bind->>Kernel: request kernel variant for activation_type
alt activation_type == Relu2
Kernel->>Act: use Relu2 adaptor (compute ReLU(value)²)
Act->>Kernel: return activation result
Kernel->>Bind: kernel completes
Bind->>Py: return success
else other activations
Kernel->>Bind: selected other activation path
Bind->>Py: return success
end
note right of MoeGemm: Separate check in MoeGemm runner
Bind->>MoeGemm: invoke moeGemmBiasAct(..., activation_type)
alt MoeGemm receives Relu2
MoeGemm-->>Bind: throw runtime error "Relu2 is not supported."
Bind-->>Py: propagate error
end
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
Code Review
This pull request adds support for the Relu2 activation function in the fused MoE kernels. The changes are well-structured, touching both the C++ backend and the Python frontend to plumb through the new activation type. My review identified a critical compilation error due to a typo, some dead code that should be removed, and a couple of minor style/maintainability issues. Once these are addressed, the PR should be in good shape.
TVM_FFI_ICHECK_EQ(fc1_expert_weights->shape[1],
fc2_expert_weights->shape[2] * mInnerDimMultiplier)
<< "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.";
The indentation for the TVM_FFI_ICHECK_EQ macro in the else block is inconsistent with the if block. For better readability and to maintain a consistent code style, it should be aligned with the check in the if block.
TVM_FFI_ICHECK_EQ(fc1_expert_weights->shape[1],
fc2_expert_weights->shape[2] * mInnerDimMultiplier)
<< "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.";
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (1)

2375-2383: Fix clang-format violations before merge. GitHub Actions flagged this file for trailing whitespace/formatting. Please run clang-format (or the project's formatting hook) over the file so CI can pass.

csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h (1)

96-103: Resolve the clang-format findings. CI flagged this header for trailing whitespace/formatting. Please run the project's formatting step (clang-format/pre-commit) so the pipeline passes.

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (2)

589-599: Typo breaks build: AcitvationType → ActivationType in runGemmProfile. The parameter type name is misspelled and will not compile.

- int64_t gemm_idx, int64_t profile_id, bool do_preparation, bool enable_pdl, AcitvationType activation_type) {
+ int64_t gemm_idx, int64_t profile_id, bool do_preparation, bool enable_pdl, ActivationType activation_type) {

492-497: Missing stream insertion in TVM_FFI_ICHECK_EQ message. The message for swiglu_beta lacks <<, causing a compile error.

- TVM_FFI_ICHECK_EQ(swiglu_beta.value()->shape[0], num_experts_on_rank)
-     "swiglu_beta must have num_experts_on_rank elements.";
+ TVM_FFI_ICHECK_EQ(swiglu_beta.value()->shape[0], num_experts_on_rank)
+     << "swiglu_beta must have num_experts_on_rank elements.";
🧹 Nitpick comments (2)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h (1)

64-79: Drop the unused maximum instance. maximum<T, PropagateNaN> mx; is never referenced; keeping it will trigger unused-variable warnings with stricter builds. You can remove it and rely on ReLu<T> directly.

flashinfer/fused_moe/core.py (1)

75-79: Review comment verified; refactoring to complete enum mirroring is recommended. The verification confirms that Swiglu=3 and Relu2=6 correctly match the C++ header. However, the Python enum at flashinfer/fused_moe/core.py (lines 76-78) is incomplete: it defines only 2 of the 9 members present in the C++ definition. The C++ header itself contains a comment on line 21 stating "Note update flashinfer/fused_moe/core.py to match", flagging this as a known sync issue. The Python enum is missing: Gelu (0), Relu (1), Silu (2), Geglu (4), SwigluBias (5), Identity (7), and InvalidType (8). The suggestion to mirror all C++ values or add explicit comments is sound for reducing drift and improving maintainability.
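For concreteness, a sketch of what the fully mirrored enum could look like, using only the names and values quoted in this comment (Swiglu=3 and Relu2=6 are confirmed above; the rest are the C++ values listed here):

import enum

# Sketch of a fully mirrored ActivationType; values are the ones quoted in this
# review comment from the C++ header, not a new definition.
class ActivationType(enum.IntEnum):
    Gelu = 0
    Relu = 1
    Silu = 2
    Swiglu = 3
    Geglu = 4
    SwigluBias = 5
    Relu2 = 6
    Identity = 7
    InvalidType = 8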
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (1 hunks)
- csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (9 hunks)
- csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h (1 hunks)
- csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h (1 hunks)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h (1 hunks)
- flashinfer/fused_moe/core.py (10 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h (1)
flashinfer/fused_moe/core.py (1)
ActivationType(76-78)
flashinfer/fused_moe/core.py (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h (1)
ActivationType(22-34)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
isGatedActivation(253-256)
🪛 GitHub Actions: pre-commit
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h
[error] 1-1: Trailing whitespace detected. Fixes applied by pre-commit.
[error] 1-1: clang-format formatting failed. Run 'clang-format' to fix code style issues in this file.
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
[error] 1-1: Trailing whitespace detected. Fixes applied by pre-commit.
[error] 1-1: clang-format formatting failed. Run 'clang-format' to fix code style issues in this file.
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu
[error] 1-1: Trailing whitespace detected. Fixes applied by pre-commit.
[error] 1-1: clang-format formatting failed. Run 'clang-format' to fix code style issues in this file.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (4)
flashinfer/fused_moe/core.py (2)
312-339: Good threading of activation_type through runner and profile. Constructor storage and forward call propagation look correct. No cache-key change needed since activation only affects call-time behavior.
Also applies to: 392-393
645-652: API default and docs for Relu2. Defaulting to Swiglu and documenting "Relu2 is non-gated GEMM1" is clear. Propagation in the wrapper call is correct.
Also applies to: 749-752, 832-833
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (2)
640-649: Profiling path: passing activation_type into mProfiler->init looks correct. Once the type name typo is fixed, this should wire through cleanly.
Please confirm GemmProfilerBackend::init expects ActivationType at that position in both OSS and non‑OSS builds.
1-10: Commit clang-format changes to the file. Clang-format found and corrected formatting issues in this file. The changes have been applied and need to be committed.
    activation_type: ActivationType = ActivationType.Swiglu,
) -> List[torch.Tensor]:
Keep fake op signature in sync (adds activation_type).
Public API added activation_type with default. The registered fake op flashinfer::cutlass_fused_moe still lacks this arg and will error under meta/fake dispatch. Add it to the fake op signature.
Apply this diff in the fake op to maintain parity:
@register_fake_op("flashinfer::cutlass_fused_moe")
def _fake_cutlass_fused_moe(
output: torch.Tensor,
input: torch.Tensor,
token_selected_experts: torch.Tensor,
token_final_scales: torch.Tensor,
fc1_expert_weights: torch.Tensor,
fc1_expert_biases: Optional[torch.Tensor],
fc2_expert_weights: torch.Tensor,
fc2_expert_biases: Optional[torch.Tensor],
output_dtype: torch.dtype,
quant_scales: List[torch.Tensor],
input_sf: Optional[torch.Tensor] = None,
swiglu_alpha: Optional[torch.Tensor] = None,
swiglu_beta: Optional[torch.Tensor] = None,
swiglu_limit: Optional[torch.Tensor] = None,
tp_size: int = 1,
tp_rank: int = 0,
ep_size: int = 1,
ep_rank: int = 0,
cluster_size: int = 1,
cluster_rank: int = 0,
enable_alltoall: bool = False,
use_deepseek_fp8_block_scale: bool = False,
use_w4_group_scaling: bool = False,
use_mxfp8_act_scaling: bool = False,
min_latency_mode: bool = False,
tune_max_num_tokens: int = 8192,
enable_pdl: Optional[bool] = None,
+ activation_type: ActivationType = ActivationType.Swiglu,
):

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In flashinfer/fused_moe/core.py around lines 441-442, the public API added the
activation_type: ActivationType = ActivationType.Swiglu parameter but the
registered fake op flashinfer::cutlass_fused_moe signature was not updated;
update the fake op registration/signature to include the activation_type
parameter with the same name and default (ActivationType.Swiglu) so the
meta/fake dispatcher sees a matching signature, and ensure any internal
registration/schema and wrapper call sites use the new parameter name and
default.
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)
495-499: Compile error: missing stream insertion in TVM_FFI_ICHECK_EQ. The swiglu_beta check lacks the << before the message literal. Apply this diff:

- TVM_FFI_ICHECK_EQ(swiglu_beta.value()->shape[0], num_experts_on_rank)
-     "swiglu_beta must have num_experts_on_rank elements.";
+ TVM_FFI_ICHECK_EQ(swiglu_beta.value()->shape[0], num_experts_on_rank)
+     << "swiglu_beta must have num_experts_on_rank elements.";

flashinfer/fused_moe/core.py (1)
561-571: Return contract inverted for min_latency_mode: confirmed bug, requires an immediate fix. The code allocates and passes three extra tensors to the kernel ONLY when min_latency_mode=True (lines 524-530), yet returns them when False. Lines 561-570 return a single tensor when True (missing the three allocated outputs) and attempt to return four tensors when False (where only one was allocated). The fake op correctly returns 4 items when min_latency_mode=True and 1 item when False, confirming the real op's logic is inverted. While currently latent, since the public wrapper raises NotImplementedError for min_latency_mode, this must be fixed before that mode is implemented. Apply the provided diffs to lines 561-570 (main function) and 605-617 (fake op) to align return behavior with allocation semantics.
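A minimal Python sketch of the intended (non-inverted) contract; the local names for the regular output and the three latency-mode tensors are hypothetical stand-ins, not the exact variable names in core.py:

from typing import List
import torch

def _select_moe_outputs(min_latency_mode: bool, output: torch.Tensor,
                        extra_outputs: List[torch.Tensor]) -> List[torch.Tensor]:
    # extra_outputs stands in for the three tensors that are only allocated when
    # min_latency_mode is True; they can therefore only be returned in that branch.
    if min_latency_mode:
        return [output, *extra_outputs]
    return [output]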
♻️ Duplicate comments (3)
flashinfer/fused_moe/core.py (2)
76-87: Keep ActivationType in Python in lockstep with C++ source of truth. Enum values currently match common.h, but this local copy can drift. Consider auto-generating from the C++ header at build time or adding a cheap import-time check that compares numeric values returned from a tiny C++ helper. At minimum, document the sync contract next to this enum.
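A rough sketch of such an import-time check; the accessor on the compiled module is hypothetical (nothing like it exists today) and is only meant to illustrate the idea:

def _check_activation_enum_in_sync(compiled_module) -> None:
    # Hypothetical: assumes the compiled extension exposes the C++ enum values,
    # e.g. as a dict {name: value}; the accessor name below is made up.
    cpp_values = compiled_module.get_activation_type_values()
    for member in ActivationType:
        cpp_value = cpp_values.get(member.name)
        if cpp_value is not None and cpp_value != member.value:
            raise RuntimeError(
                f"ActivationType.{member.name} = {member.value} in Python but "
                f"{cpp_value} in C++; update flashinfer/fused_moe/core.py to match"
            )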
572-601: Fake op signature missing activation_type (breaks meta/fake dispatch). Add the new parameter with the same name/default as the real op. Otherwise, fake/meta kernels will error on a kwargs mismatch.

Apply this diff (the same change as the suggestion above: append the parameter to the end of the _fake_cutlass_fused_moe signature):

+ activation_type: ActivationType = ActivationType.Swiglu,

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)
365-367: Quant-scale shape checks still assume gated (×2) for FC1; plumb activation to getQuantParams. Shape validations for FC1 block scales hardcode inter_size * 2. That rejects valid Relu2 (non-gated) layouts and mis-sizes workspaces. Pass ActivationType into getQuantParams and gate the multiplier. Apply these diffs:

- Thread activation into calls (at both call sites):

-    auto const quant_params =
-        getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
+    auto const quant_params =
+        getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales, base_activation_type);

- Update signature and gate multiplier:

-  kernels::QuantParams getQuantParams(int64_t num_experts_on_rank, int64_t hidden_size,
-                                      int64_t inter_size,
-                                      Optional<Array<Tensor>> quant_scales) const {
+  kernels::QuantParams getQuantParams(int64_t num_experts_on_rank, int64_t hidden_size,
+                                      int64_t inter_size, Optional<Array<Tensor>> quant_scales,
+                                      ActivationType base_activation_type) const {
+    bool const is_gated = isGatedActivation(base_activation_type);

- Fix FC1 shape checks (examples shown for each block):

W4A8_MXFP4_FP8:

-    TVM_FFI_ICHECK(
+    TVM_FFI_ICHECK(
         fc1_weight_block->shape[0] == num_experts_on_rank &&
         fc1_weight_block->shape[1] ==
             TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
-                inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) *
-                2 &&
+                inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) *
+                (is_gated ? 2 : 1) &&
         ...
-        << "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 "
+        << "fc1 weight block size must be (num_experts_on_rank, inter_size * (is_gated?2:1), hidden_size // 4 "
           "// block_scale_vector_size)";

W4A8_MXFP4_MXFP8:

(same diff as W4A8_MXFP4_FP8 above, with MinNDimAlignmentMXFPX)

NVFP4:

(same diff, with MinKDimAlignmentNVFP4 in place of MinNDimAlignmentMXFPX)

Also applies to: 542-544, 807-813, 871-882, 930-941, 1001-1011
🧹 Nitpick comments (1)
flashinfer/fused_moe/core.py (1)
655-656: Docstring/type hint don't match actual return shape. Signature and docs promise torch.Tensor, but the op returns a list when min_latency_mode flips. After fixing the inversion, either:
- document Union[torch.Tensor, List[torch.Tensor]], or
- normalize to always return a torch.Tensor from this wrapper.
Pick one and update annotations/docs accordingly.
Also applies to: 759-761
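A small sketch of the second option above (normalizing the wrapper so callers always get a single tensor); the helper name is illustrative, not part of the existing API:

from typing import List, Union
import torch

def _normalize_moe_output(result: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
    # If the underlying op handed back a list (latency-mode outputs), surface
    # only the primary output tensor; otherwise pass the tensor through.
    return result[0] if isinstance(result, list) else result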
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh (1 hunks)
- csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (9 hunks)
- csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h (1 hunks)
- flashinfer/fused_moe/core.py (10 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h
- csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/fused_moe/core.py (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h (1)
ActivationType(22-34)
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
isGatedActivation(253-256)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (3)
flashinfer/fused_moe/core.py (1)
321-348: Activation type threading looks correct. End-to-end propagation into profiling (run_gemm_profile) and runtime (run_moe) is clean, and defaults preserve backward compat.
Also applies to: 401-402, 450-451, 475-476, 841-842
csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu (2)
294-302: Inter-size checks respect gated vs non-gated activations. Good use of isGatedActivation(...) to validate FC1 inter size as 2× for gated and 1× for non-gated (e.g., Relu2).
Also applies to: 467-475
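The same rule, restated as a small Python sketch; the helper here is a hypothetical Python-side mirror of the C++ isGatedActivation(), and the set of gated members follows the enum discussed earlier in this review:

def is_gated_activation(activation: "ActivationType") -> bool:
    # Hypothetical mirror of the C++ isGatedActivation(); gate-and-up activations
    # (Swiglu family, Geglu) are gated, Relu2 is not.
    return activation.name in {"Swiglu", "SwigluBias", "Geglu"}

def expected_fc1_inter_dim(inter_size: int, activation: "ActivationType") -> int:
    # Gated activations pack gate and up projections, so FC1 spans 2 * inter_size
    # columns; non-gated activations such as Relu2 use inter_size directly.
    return inter_size * (2 if is_gated_activation(activation) else 1)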
592-593: ActivationType wiring through FFI/profiler looks solid. runGemmProfile and dispatchers now take the enum; init() passes it to the profiler consistently.
If possible, run a quick smoke profile on both Swiglu and Relu2 to ensure distinct tactic selections are exercised.
Also applies to: 644-651, 680-686, 700-706, 719-727
/bot run
djns99 left a comment
This looks good to me. If you aren't planning to support the fallback Ampere-style path, could you add an explicit check? Otherwise, you will just get this error, which will likely confuse end users.
Unspecified = 6

# Copied from csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/common.h
Any chance we could add this to the python bindings instead of duplicating this manually?
Not sure if there is a good way to do this for enums, but it would be nicer if we could keep these in sync
I agree with this but there's an issue I'm not sure how to solve:
For the bindings to be present, we need to load the compiled binaries, and this is done lazily from get_cutlass_fused_moe_module(). This means we won't be able to expose it from the python package and then this enum will be redundant.
Do you have a suggestion on how to solve this? If so, I'd try to apply it to RoutingMethodType as well, as I think it has the same issue.
One option could be making these enum classes not just-in-time loaded; we can do that in a future PR.
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
[FAILED] Pipeline #36951423: 1/17 passed
Update TRTLLM Cutlass MoE kernels with ReLU2 activation. Nemotron-6 requires ReLU2 (i.e. squared ReLU) MoE activation function. The PR adds this and adds an API to set the activation function, in general. The ReLU2 changes are based on this FlashInfer PR: flashinfer-ai/flashinfer#1954. The PR also updates the Auto Deploy MoE backend for 16-bit and FP8 from Triton (`torch.ops.auto_deploy.triton_moe_fused`, `torch.ops.auto_deploy.triton_quant_fp8_moe`) to TRTLLM/Cutlass (`torch.ops.auto_deploy.trtllm_moe_fused`, `torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused`). Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com> Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com> Co-authored-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
📌 Description
Added support for Relu2 activation in cutlass fp8 FusedMoE path.
Relu2(x) = Relu(x)^2. Validated this works correctly on H100 and B200.
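A schematic of where the new argument plugs in; the import path follows the files touched in this PR, and the full argument list of the fused MoE call is elided since a working FP8 MoE setup (weights, routing, quant scales) is out of scope here:

from flashinfer.fused_moe.core import ActivationType

# Select the non-gated squared-ReLU epilogue added in this PR; the default
# remains ActivationType.Swiglu.
activation = ActivationType.Relu2
# out = cutlass_fused_moe(..., activation_type=activation)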