feat: Add b12x_fused_moe / B12xMoEWrapper SM120 APIs with micro kernel and ReLU2#3080
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdded gated (SiLU) vs non‑gated (ReLU²) activation support across benchmarks and fused MoE CuTe‑DSL kernels; introduced SM12x b12x functional API and wrapper, a tiny‑decode micro backend with routing compaction, activation‑aware kernel compilation/storage, weight/layout/quantization changes, and new tests. Changes
Sequence Diagram(s)sequenceDiagram
participant Caller as Caller
participant Dispatch as Dispatch Layer
participant Compact as Triton Compact
participant Cache as Kernel Cache
participant Micro as Micro Kernel
participant Static as Static Kernel
participant Dynamic as Dynamic Kernel
participant Activation as Activation Func
Caller->>Dispatch: submit token_selected_experts, weights, scales, activation
Dispatch->>Dispatch: compute routed_rows = num_tokens * top_k
Dispatch->>Dispatch: select backend (micro/static/dynamic) using cutovers
alt Micro Path
Dispatch->>Compact: compact_topk_ids(topk_ids)
Compact-->>Dispatch: compact_ids, active_expert_count, weight_expert_ids
Dispatch->>Cache: lookup/compile micro kernel (activation, mac_override)
Cache-->>Dispatch: micro kernel
Dispatch->>Micro: launch(compact_ids, activation, weights, scales)
Micro->>Activation: apply SiLU or ReLU²
Micro->>Caller: write outputs
else Static Path
Dispatch->>Cache: lookup/compile static kernel (activation)
Cache-->>Dispatch: static kernel
Dispatch->>Static: launch(topk_ids, activation, weights, scales)
Static->>Activation: apply SiLU or ReLU²
Static->>Caller: write outputs
else Dynamic Path
Dispatch->>Dynamic: launch dynamic kernel (activation)
Dynamic->>Activation: apply activation variant
Dynamic->>Caller: write outputs
end
Estimated code review effort🎯 4 (Complex) | ⏱️ ~75 minutes Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 3✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request adds support for non-gated MoE activations (ReLU2) and introduces a specialized micro-kernel for small decode batches on Blackwell SM120/SM121 architectures. The changes include updates to the static and dynamic CuTe DSL kernels, a new Triton-based ID compaction pre-pass, and updated benchmarking and testing utilities. Feedback indicates that the "moe_micro_kernel.py" file is missing from the PR and identifies a potential out-of-bounds risk when slicing the workspace buffer for compact IDs.
| from .triton_compact import compact_topk_ids as _triton_compact_topk_ids | ||
|
|
||
| # Run Triton pre-pass to compact global expert IDs to dense local indices | ||
| compact_ids = workspace.compact_topk_ids[: flat_ids.numel()] |
There was a problem hiding this comment.
There is a potential out-of-bounds risk if flat_ids.numel() exceeds state_E. While the micro kernel path is currently restricted to routed_rows <= 40, the workspace allocation for compact_topk_ids uses state_E. It would be safer to ensure that the slice does not exceed the allocated size of the workspace buffer, or add an explicit check.
There was a problem hiding this comment.
Good catch on the defensive coding. In practice this can't overflow — compact_topk_ids is sized [state_E] (typically 256-512) while flat_ids.numel() is at most 40 on the micro path (the cutover threshold). But the invariant should be explicit. Added an assertion in the next commit
There was a problem hiding this comment.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
benchmarks/routines/moe.py (1)
1790-1813:⚠️ Potential issue | 🟡 MinorPass
is_gatedinto the FP8 bandwidth model too.TFLOPS now distinguishes gated vs non-gated activations, but both FP8 bandwidth calls still rely on the default path. ReLU2 runs will therefore report inconsistent bandwidth.
🛠️ Suggested fix
tb_per_sec = calculate_moe_kernel_bandwidth( num_tokens, hidden_size, intermediate_size, num_experts, top_k, median_time, input_dtype, weight_dtype, input_format="fp8", weight_format="fp8", routing_logits_dtype=routing_logits.dtype, active_experts=int(selected_experts.unique().numel()), verbose=args.verbose, + is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu), )tb_per_sec = calculate_moe_kernel_bandwidth( num_tokens, hidden_size, intermediate_size, num_experts, top_k, median_time, input_dtype, weight_dtype, input_format="fp8", weight_format="fp8", routing_logits_dtype=routing_logits.dtype, active_experts=int(selected_experts.unique().numel()), verbose=args.verbose, + is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu), )Also applies to: 2025-2048
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/moe.py` around lines 1790 - 1813, The FP8 bandwidth call is missing the is_gated flag, causing gated vs non-gated activations to report inconsistent bandwidth; update the calculate_moe_kernel_bandwidth invocation(s) to pass the same is_gated boolean used for calculate_moe_tflops (e.g., is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu)) so both calculate_moe_tflops(...) and calculate_moe_kernel_bandwidth(...) receive the same gating hint (apply the same change to the other occurrences around the later block that mirrors this code).flashinfer/fused_moe/cute_dsl/fused_moe.py (1)
362-394:⚠️ Potential issue | 🟠 MajorReject
relu2outside the SM120/SM121 path.These new public parameters are only honored in the SM120 branch. The fallback path still goes through
_moe_core_impl(), which hard-wires the SwiGLU fusion helper, soactivation_type="relu2"on SM100/SM103 can run the wrong math or hit mismatched FC1 shapes instead of failing fast.🛠️ Suggested guard
@@ - self.activation_type = activation_type + if activation_type not in {"silu", "relu2"}: + raise ValueError(f"Unsupported activation_type: {activation_type!r}") + self.activation_type = activation_type @@ major, minor = torch.cuda.get_device_capability(device) self._is_sm120 = major == 12 + if activation_type != "silu" and not self._is_sm120: + raise ValueError( + "activation_type='relu2' is only supported on SM120/SM121" + )def cute_dsl_fused_moe_nvfp4( @@ - if num_local_experts is None: + if activation_type not in {"silu", "relu2"}: + raise ValueError(f"Unsupported activation_type: {activation_type!r}") + + if num_local_experts is None: num_local_experts = num_experts @@ major, _ = torch.cuda.get_device_capability(x.device) if major == 12: ... + elif activation_type != "silu": + raise ValueError( + "activation_type='relu2' is only supported on SM120/SM121" + )Also applies to: 827-916
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` around lines 362 - 394, The constructor accepts activation_type but the non-SM120/SM121 path still calls _moe_core_impl which assumes SwiGLU; add a runtime guard in the initializer (or immediately before dispatch to _moe_core_impl) that checks activation_type and the detected GPU SM version and either raise a clear error or restrict allowed values when SM < 120 (e.g., if activation_type == "relu2" and not on SM120/121, raise ValueError). Update the dispatch code path that calls _moe_core_impl (and any fallback branches referenced around the alternate implementation) to enforce this same check so relu2 is only honored on the SM120/SM121 branch and cannot silently run with the SwiGLU helper.flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py (1)
138-185:⚠️ Potential issue | 🔴 CriticalSize
compact_topk_idsfor routed rows, not experts.Line 824 slices this buffer to
flat_ids.numel()(num_tokens * top_k), but the workspace only allocatesstate_Eentries. Any micro launch with more routed rows than local experts will write past the end of the buffer.🛠️ Suggested fix
- compact_topk_ids: torch.Tensor # [state_E] int32, for micro kernel pre-pass + compact_topk_ids: torch.Tensor # [max_rows] int32, for micro kernel pre-pass @@ - compact_topk_ids=torch.empty(state_E, dtype=torch.int32, device=device), + compact_topk_ids=torch.empty(max_rows, dtype=torch.int32, device=device),🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines 138 - 185, compact_topk_ids is currently allocated with length state_E but is later sliced to flat_ids.numel() (num_tokens * top_k) in the micro-kernel pre-pass, causing out-of-bounds writes when num_tokens > state_E; in allocate_sm120_static_workspace change the compact_topk_ids allocation in Sm120StaticMoEWorkspace to have capacity for the worst-case routed rows times top-k (e.g. torch.empty(state_E * max_rows * num_topk, dtype=torch.int32, device=device) or at minimum torch.empty(max_rows * state_E * num_topk, ...)) so flat_ids.numel() can always fit, and keep references to compact_topk_ids, allocate_sm120_static_workspace, Sm120StaticMoEWorkspace, num_topk, max_rows, and state_E to locate the change.
🧹 Nitpick comments (1)
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py (1)
76-86: Add a defensive size guard for micro-only usage.This kernel is O(BLOCK²) in a single program; adding an explicit upper bound makes accidental large launches fail fast with a clear message.
Possible guardrail
block = triton.next_power_of_2(total_pairs) + if block > 256: + raise ValueError( + f"compact_topk_ids is intended for micro batches; got total_pairs={total_pairs}" + ) num_warps = 1 if block <= 16 else 2🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py` around lines 76 - 86, Add a defensive size guard before launching _compact_topk_ids_kernel to prevent accidental large O(BLOCK²) launches: compute block = triton.next_power_of_2(total_pairs) as you do, then check against a small hard limit (e.g. MAX_BLOCK = 64 or 128) and/or a MAX_PAIRS derived limit and raise a clear RuntimeError if block > MAX_BLOCK (include block and total_pairs in the message). Keep the existing num_warps logic and kernel args (_compact_topk_ids_kernel, topk_ids, compact_topk_ids, weight_expert_ids, active_expert_count, total_pairs, BLOCK=block, num_warps=num_warps) unchanged; just insert the guard using the same symbols so oversized launches fail fast with a descriptive error.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/routines/moe.py`:
- Around line 1367-1371: The code silently maps unsupported ActivationType
values to "silu"; change the logic to validate args.activation_type against the
supported mapping instead of defaulting. Use the _ACT_STR dict to look up
activation_str and if the activation_type is not present raise a clear exception
(e.g., ValueError) mentioning the unsupported ActivationType and listing
supported keys; also compute is_gated from ActivationType.Geglu and
ActivationType.Swiglu as before but ensure Geglu is rejected if not in _ACT_STR
so it cannot silently run the SiLU kernel.
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py`:
- Around line 539-541: Comments describing w13 tile ordering conflict with
actual code: update the inline comments around the w13 TMA descriptor creation
(the lines that mention "Gate tiles at N=..." and "up tiles at N=...") to
reflect the actual ordering used by the code path where gate_slice_idx =
intermediate_slice + gate_tile_cnt (i.e., up tiles occupy the first half of
N-tiles and gate tiles the second half). Locate the call to
self._dense_cls._make_tma_atoms_and_tensors that produces tma_b_w13/gB_w13 and
any other similar comments (also around the other occurrence near lines
~1168-1171) and change the wording so it states "Up tiles at N=0..I_tp/tile_N-1,
Gate tiles at N=I_tp/tile_N..2*I_tp/tile_N-1" or equivalent that matches the
gate_slice_idx logic.
In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 1093-1096: Replace the custom skip gating that defines
sm120_cuda13 (which currently uses is_sm120_family() and _has_cuda_13()) with
the repository-standard capability checks from flashinfer.utils or the API
capability method; specifically, remove is_sm120_family()/_has_cuda_13() and use
the appropriate flashinfer.utils helper (e.g., is_sm120_supported() or analogous
is_sm90a_supported()/is_sm100a_supported()) or call
api_name.is_compute_capability_supported(cc) to decide the skip. Ensure the new
marker still uses pytest.mark.skipif(...) with a descriptive reason string
indicating the required SM/CUDA capability.
---
Outside diff comments:
In `@benchmarks/routines/moe.py`:
- Around line 1790-1813: The FP8 bandwidth call is missing the is_gated flag,
causing gated vs non-gated activations to report inconsistent bandwidth; update
the calculate_moe_kernel_bandwidth invocation(s) to pass the same is_gated
boolean used for calculate_moe_tflops (e.g., is_gated=args.activation_type in
(ActivationType.Swiglu, ActivationType.Geglu)) so both calculate_moe_tflops(...)
and calculate_moe_kernel_bandwidth(...) receive the same gating hint (apply the
same change to the other occurrences around the later block that mirrors this
code).
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py`:
- Around line 138-185: compact_topk_ids is currently allocated with length
state_E but is later sliced to flat_ids.numel() (num_tokens * top_k) in the
micro-kernel pre-pass, causing out-of-bounds writes when num_tokens > state_E;
in allocate_sm120_static_workspace change the compact_topk_ids allocation in
Sm120StaticMoEWorkspace to have capacity for the worst-case routed rows times
top-k (e.g. torch.empty(state_E * max_rows * num_topk, dtype=torch.int32,
device=device) or at minimum torch.empty(max_rows * state_E * num_topk, ...)) so
flat_ids.numel() can always fit, and keep references to compact_topk_ids,
allocate_sm120_static_workspace, Sm120StaticMoEWorkspace, num_topk, max_rows,
and state_E to locate the change.
In `@flashinfer/fused_moe/cute_dsl/fused_moe.py`:
- Around line 362-394: The constructor accepts activation_type but the
non-SM120/SM121 path still calls _moe_core_impl which assumes SwiGLU; add a
runtime guard in the initializer (or immediately before dispatch to
_moe_core_impl) that checks activation_type and the detected GPU SM version and
either raise a clear error or restrict allowed values when SM < 120 (e.g., if
activation_type == "relu2" and not on SM120/121, raise ValueError). Update the
dispatch code path that calls _moe_core_impl (and any fallback branches
referenced around the alternate implementation) to enforce this same check so
relu2 is only honored on the SM120/SM121 branch and cannot silently run with the
SwiGLU helper.
---
Nitpick comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py`:
- Around line 76-86: Add a defensive size guard before launching
_compact_topk_ids_kernel to prevent accidental large O(BLOCK²) launches: compute
block = triton.next_power_of_2(total_pairs) as you do, then check against a
small hard limit (e.g. MAX_BLOCK = 64 or 128) and/or a MAX_PAIRS derived limit
and raise a clear RuntimeError if block > MAX_BLOCK (include block and
total_pairs in the message). Keep the existing num_warps logic and kernel args
(_compact_topk_ids_kernel, topk_ids, compact_topk_ids, weight_expert_ids,
active_expert_count, total_pairs, BLOCK=block, num_warps=num_warps) unchanged;
just insert the guard using the same symbols so oversized launches fail fast
with a descriptive error.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 5e3f31a2-4325-47e8-9db8-0acd0b10bef6
📒 Files selected for processing (11)
benchmarks/routines/moe.pybenchmarks/routines/moe_utils.pyflashinfer/cute_dsl/fp4_common.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/__init__.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.pyflashinfer/fused_moe/cute_dsl/fused_moe.pytests/moe/test_cute_dsl_fused_moe.py
|
/bot run |
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/routines/moe.py`:
- Around line 2030-2037: The bandwidth computation is still using the default
gated accounting while TFLOPS now selects non-gated accounting for some
activations; update the call to calculate_moe_kernel_bandwidth to pass the same
is_gated boolean used for calculate_moe_tflops (i.e.,
is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu)
or the expression that produced non-gated for Relu2) so both TFLOPS and kernel
bandwidth use the same activation-aware gating flag (refer to
calculate_moe_tflops and calculate_moe_kernel_bandwidth to locate the calls).
- Around line 1795-1802: The TFLOPS call uses args.activation_type to set
is_gated but this routine still constructs gated FC1 tensors and
run_fp8_block_moe never receives an activation flag, so reported TFLOPS can
diverge from the executed kernel; fix by not switching the is_gated flag based
on args.activation_type here — either hard-code is_gated=True (gated-only path)
or derive is_gated from the same gated-only indicator used when building tensors
(e.g., the 2 * intermediate_size gated FC1 logic) and/or update
run_fp8_block_moe to accept and forward an activation_type so activation-based
toggles are consistent with calculate_moe_tflops; reference
calculate_moe_tflops, run_fp8_block_moe, args.activation_type and the gated FC1
construction (2 * intermediate_size) when making the change.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 43a02c6d-a59b-416a-b805-d1e691a5397e
📒 Files selected for processing (6)
benchmarks/routines/moe.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.pyflashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.pytests/moe/test_cute_dsl_fused_moe.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/moe/test_cute_dsl_fused_moe.py
- flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
|
/bot stop |
|
/bot run |
|
The GitLab CI pipeline #48713073 has been cancelled. |
|
/bot stop |
|
The GitLab CI pipeline #48714380 has been cancelled. |
|
/bot run |
| @@ -0,0 +1,1334 @@ | |||
| """ | |||
| Copyright (c) 2025 by FlashInfer team. | |||
There was a problem hiding this comment.
Will try to capture in a subsequent PR that also addresses non-blocking mm_fp4(backend='b12x') comments to avoid needing to re-trigger the CI again 😅
|
@flashinfer-bot run |
|
/bot run |
|
[FAILED] Pipeline #49005075: 15/20 passed |
<!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> Follow-up to #3051 (`backend="b12x"` for `mm_fp4` on SM120) and #3080 (`b12x_fused_moe` / `B12xMoEWrapper` SM120 APIs) addressing four reviewer comments that landed after merge. No public API changes; no kernel behavior changes. - **Copyright**: bump `tests/moe/test_b12x_fused_moe.py` to 2026. - **Benchmark split**: new `b12x_fused_moe` routine (SM120/121, BF16 input, SwiGLU + ReLU²); `cute_dsl_fp4_block_scale_moe` is now SM100/103-only. Aligns with the `B12xMoEWrapper` / `CuteDslMoEWrapper` Python API split. - **Cache SM count**: replace a hot-path `torch.cuda.get_device_properties(...).multi_processor_count` in the `b12x` FP4 GEMM runner with the cached `get_device_sm_count()` helper. - **Rename for provenance**: `dense_blockscaled_gemm_sm120.py` → `dense_blockscaled_gemm_sm120_b12x.py` and `Sm120BlockScaledDenseGemmKernel` → `Sm120B12xBlockScaledDenseGemmKernel` (via `git mv`, 6 import sites updated). `backend="b12x"` string unchanged. ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added new `b12x_fused_moe` benchmark routine for NVFP4 MoE inference with support for both SwiGLU and ReLU2 activation types. * Extended Blackwell architecture support with updated kernel implementations. * **Documentation** * Updated benchmark samples with new `b12x_fused_moe` test configurations. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
📌 Description
Summary
New SM120/SM121 MoE APIs (
b12x_fused_moe,B12xMoEWrapper) with:max(0,x)²) for non-gated MoE (Nemotron-Super) across all three SM120 kernel backends (micro, static, dynamic)cutlass_fused_moeandcute_dsl_fp4_block_scale_moeroutines, with corrected TFLOPS/bandwidth calculations for non-gated activationsb12x_fused_moe, SM100 keepscute_dsl_fused_moe_nvfp4API separation
cute_dsl_fused_moe_nvfp4(FP4 input)CuteDslMoEWrapperb12x_fused_moe(bf16 input)B12xMoEWrapperThe SM100 APIs (
cute_dsl_fused_moe_nvfp4,CuteDslMoEWrapper) are restored to SM100-only scope — no SM120 dispatch, noactivation_typeparameter.Micro kernel
Ported from b12x. Selected automatically when
routed_rows ≤ 20(top_k=1) or≤ 40(top_k>1). Key optimizations vs the static kernel:all_rows_uniquefast path: whennum_tokens=1and every expert is unique, skips atomic row counting and uses O(1) work-tile assignmentReLU2 activation
Added
activationparameter ("silu"default,"relu2") to all SM120 kernel classes viaself.is_gatedcompile-time branching (cutlass.const_expr):StorageGated(3 pipelines, gate+up buffers) vsStorageRelu2(2 pipelines, single FC1 buffer)silu(gate) * upvsrelu(x)²Exposed through
activation_typeparameter onCuteDslMoEWrapperandcute_dsl_fused_moe_nvfp4APIs.API usage
Functional
Wrapper (CUDA graph compatible)
Example micro benchmarks
🔍 Related Issues
#3013
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit
New Features
Tests