Skip to content

feat: Add b12x_fused_moe / B12xMoEWrapper SM120 APIs with micro kernel and ReLU2#3080

Merged
bkryu merged 14 commits into
flashinfer-ai:mainfrom
bkryu:b12x_micro_kernel
Apr 20, 2026
Merged

feat: Add b12x_fused_moe / B12xMoEWrapper SM120 APIs with micro kernel and ReLU2#3080
bkryu merged 14 commits into
flashinfer-ai:mainfrom
bkryu:b12x_micro_kernel

Conversation

@bkryu

@bkryu bkryu commented Apr 15, 2026

Copy link
Copy Markdown
Collaborator

📌 Description

Summary

New SM120/SM121 MoE APIs (b12x_fused_moe, B12xMoEWrapper) with:

  • Micro kernel for tiny decode batches (≤20-40 routed rows) on SM120/SM121, with Triton routing compaction pre-pass and MAC tuning ladder
  • ReLU2 activation (max(0,x)²) for non-gated MoE (Nemotron-Super) across all three SM120 kernel backends (micro, static, dynamic)
  • Benchmark ReLU2 support for both cutlass_fused_moe and cute_dsl_fp4_block_scale_moe routines, with corrected TFLOPS/bandwidth calculations for non-gated activations
  • Clean API separation: SM120 uses b12x_fused_moe, SM100 keeps cute_dsl_fused_moe_nvfp4

API separation

GPU Functional API Wrapper API
SM100/SM103 cute_dsl_fused_moe_nvfp4 (FP4 input) CuteDslMoEWrapper
SM120/SM121 b12x_fused_moe (bf16 input) B12xMoEWrapper

The SM100 APIs (cute_dsl_fused_moe_nvfp4, CuteDslMoEWrapper) are restored to SM100-only scope — no SM120 dispatch, no activation_type parameter.

Micro kernel

Ported from b12x. Selected automatically when routed_rows ≤ 20 (top_k=1) or ≤ 40 (top_k>1). Key optimizations vs the static kernel:

  • Triton compact pre-pass: remaps global expert IDs to dense local indices, eliminating CAS-based expert discovery inside the kernel
  • all_rows_unique fast path: when num_tokens=1 and every expert is unique, skips atomic row counting and uses O(1) work-tile assignment
  • MAC tuning ladder: per-routed-row optimal cluster counts from b12x decode profiling, capped against hardware SM count to prevent deadlocks

ReLU2 activation

Added activation parameter ("silu" default, "relu2") to all SM120 kernel classes via self.is_gated compile-time branching (cutlass.const_expr):

  • Storage: StorageGated (3 pipelines, gate+up buffers) vs StorageRelu2 (2 pipelines, single FC1 buffer)
  • FC1: dual GEMM (gate+up) for SiLU vs single GEMM for ReLU2
  • Activation: silu(gate) * up vs relu(x)²
  • DMA: up-projection TMA loads eliminated for ReLU2

Exposed through activation_type parameter on CuteDslMoEWrapper and cute_dsl_fused_moe_nvfp4 APIs.

API usage

Functional

from flashinfer import b12x_fused_moe  

output = b12x_fused_moe(
    x=hidden_states_bf16,       # bf16 input (kernel fuses quantization)                                                                          
    w1_weight=w1_fp4, w1_weight_sf=w1_sf, w1_alpha=w1_alpha,                                                                                      
    fc2_input_scale=fc2_scale,                                                                                                                    
    w2_weight=w2_fp4, w2_weight_sf=w2_sf, w2_alpha=w2_alpha,                                                                                      
    token_selected_experts=topk_ids,                                                                                                              
    token_final_scales=topk_weights,                                                                                                              
    num_experts=512, top_k=22,                                                                                                                    
    activation="relu2",  # or "silu" (default)                                                                                                    
)                                                                                                                                                 

Wrapper (CUDA graph compatible)

from flashinfer import B12xMoEWrapper

moe = B12xMoEWrapper(                                                                                                                             
    num_experts=512, top_k=22,
    hidden_size=1024, intermediate_size=2688,                                                                                                     
    use_cuda_graph=True, activation="relu2",                                            
)                                                                                                                                                 
output = moe.run(x=hidden_states_bf16, ...)

Example micro benchmarks

# b12x cute dsl MoE for 1-token Nemotron 3 Super Size
python benchmarks/flashinfer_benchmark.py --routine cute_dsl_fp4_block_scale_moe --activation-type Relu2 --num_tokens 1 --hidden_size 1024 --intermediate_size 2688 --num_experts 512 --top_k 22 --use_cuda_events --num_iters 50
# Equivalent cutlass_fused_moe benchmark
python benchmarks/flashinfer_benchmark.py --routine cutlass_fused_moe --cutlass_variant nvfp4 --activation-type Relu2 --num_tokens 1 --hidden_size 1024 --intermediate_size 2688 --num_experts 512 --top_k 22 --quantized_input --use_cuda_events --num_iters 50

🔍 Related Issues

#3013

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Non‑gated ReLU2 activation and dual gated/non‑gated FC1 layouts for MoE; activation selectable at runtime.
    • New micro‑kernel backend plus routing‑ID compaction for improved single‑token/small‑batch performance.
    • SM12x (b12x) fused‑MoE functional API and CUDA‑graph‑friendly wrapper exported for SM12x workflows; runtime maps activations to kernel implementations.
    • CuTe‑DSL helpers added to support ReLU2 + FP4 quantization.
  • Tests

    • End‑to‑end tests for ReLU2, gated vs non‑gated flows, micro‑kernel paths, CUDA graph replay, and FP4 numerical agreement.

@coderabbitai

coderabbitai Bot commented Apr 15, 2026

Copy link
Copy Markdown
Contributor

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Added gated (SiLU) vs non‑gated (ReLU²) activation support across benchmarks and fused MoE CuTe‑DSL kernels; introduced SM12x b12x functional API and wrapper, a tiny‑decode micro backend with routing compaction, activation‑aware kernel compilation/storage, weight/layout/quantization changes, and new tests.

Changes

Cohort / File(s) Summary
Benchmarks & Utilities
benchmarks/routines/moe.py, benchmarks/routines/moe_utils.py
Propagated activation / is_gated into benchmarks and TFLOPs/bandwidth models; adjusted w1_rows/w1_cols accounting and benchmark wiring for gated vs non‑gated layouts.
FP4 Helpers
flashinfer/cute_dsl/fp4_common.py
Added relu2_16 and relu2_quantize_block_fp4 JIT helpers to support ReLU² fused with FP4 block quantization.
CuTe‑DSL Public Surface
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/__init__.py
Exported MoEMicroKernel in package __all__.
Routing Compaction Kernel
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py
New Triton kernel compact_topk_ids to compact flattened top‑k IDs into dense local indices and produce active_expert_count / weight_expert_ids mapping.
Dispatch & Micro Backend
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
Added micro (tiny‑decode) backend with micro kernel compile/cache path, routing compaction pre‑pass, activation propagation, MAC override ladder, and workspace field for compacted IDs.
Static Kernel
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.py
Added activation param; storage, pipelines, FC1 tiling, shared‑memory layout, and fused activation/quant logic made conditional on gated (SiLU) vs non‑gated (ReLU²).
Dynamic Kernel
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
Added activation param and self.is_gated; adjusted producer/consumer tiling, pipelines, and fused activation computation for gated vs non‑gated paths.
SM12x Functional API & Wrapper
flashinfer/fused_moe/cute_dsl/b12x_moe.py, flashinfer/fused_moe/cute_dsl/__init__.py, flashinfer/fused_moe/__init__.py, flashinfer/__init__.py
New b12x_fused_moe functional API and B12xMoEWrapper; enforce CUDA13+, activation propagation, backend selection, CUDA‑graph buffer preallocation, and public re‑exports.
CuteDSL NVFP4 Path & Wrapper Simplification
flashinfer/fused_moe/cute_dsl/fused_moe.py
Removed SM12x special‑case paths from NVFP4 flow and reduced supported compute capability decorators to SM100/103; unified NVFP4 runner/autotune path.
Tests
tests/moe/test_cute_dsl_fused_moe.py, tests/moe/test_b12x_fused_moe.py
Adjusted SM‑family gating and weight prep; added comprehensive SM12x b12x test suite (functional, wrapper, micro‑kernel, CUDA graph) including ReLU² coverage and numerical checks.

Sequence Diagram(s)

sequenceDiagram
    participant Caller as Caller
    participant Dispatch as Dispatch Layer
    participant Compact as Triton Compact
    participant Cache as Kernel Cache
    participant Micro as Micro Kernel
    participant Static as Static Kernel
    participant Dynamic as Dynamic Kernel
    participant Activation as Activation Func

    Caller->>Dispatch: submit token_selected_experts, weights, scales, activation
    Dispatch->>Dispatch: compute routed_rows = num_tokens * top_k
    Dispatch->>Dispatch: select backend (micro/static/dynamic) using cutovers

    alt Micro Path
        Dispatch->>Compact: compact_topk_ids(topk_ids)
        Compact-->>Dispatch: compact_ids, active_expert_count, weight_expert_ids
        Dispatch->>Cache: lookup/compile micro kernel (activation, mac_override)
        Cache-->>Dispatch: micro kernel
        Dispatch->>Micro: launch(compact_ids, activation, weights, scales)
        Micro->>Activation: apply SiLU or ReLU²
        Micro->>Caller: write outputs
    else Static Path
        Dispatch->>Cache: lookup/compile static kernel (activation)
        Cache-->>Dispatch: static kernel
        Dispatch->>Static: launch(topk_ids, activation, weights, scales)
        Static->>Activation: apply SiLU or ReLU²
        Static->>Caller: write outputs
    else Dynamic Path
        Dispatch->>Dynamic: launch dynamic kernel (activation)
        Dynamic->>Activation: apply activation variant
        Dynamic->>Caller: write outputs
    end
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • yzh119
  • cyx-6
  • jiahanc
  • IwakuraRein
  • samuellees
  • jimmyzho
  • aleozlx

Poem

🐰 Hoppity-hop, experts line the track,
SiLU gates wiggle, ReLU² leaps back.
Tiny kernels hustle, compaction sings,
Quantized carrots fuel faster things.
A rabbit cheers: kernels, hop—great spring!

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately summarizes the main feature additions: new SM120 MoE APIs (b12x_fused_moe/B12xMoEWrapper) with micro kernel and ReLU2 activation support.
Description check ✅ Passed The description provides comprehensive detail: clear summary of changes, API separation table, micro kernel explanation, ReLU2 activation details, usage examples, and related issues. All required sections are substantially filled out.
Docstring Coverage ✅ Passed Docstring coverage is 81.48% which is sufficient. The required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for non-gated MoE activations (ReLU2) and introduces a specialized micro-kernel for small decode batches on Blackwell SM120/SM121 architectures. The changes include updates to the static and dynamic CuTe DSL kernels, a new Triton-based ID compaction pre-pass, and updated benchmarking and testing utilities. Feedback indicates that the "moe_micro_kernel.py" file is missing from the PR and identifies a potential out-of-bounds risk when slicing the workspace buffer for compact IDs.

Comment thread flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
from .triton_compact import compact_topk_ids as _triton_compact_topk_ids

# Run Triton pre-pass to compact global expert IDs to dense local indices
compact_ids = workspace.compact_topk_ids[: flat_ids.numel()]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is a potential out-of-bounds risk if flat_ids.numel() exceeds state_E. While the micro kernel path is currently restricted to routed_rows <= 40, the workspace allocation for compact_topk_ids uses state_E. It would be safer to ensure that the slice does not exceed the allocated size of the workspace buffer, or add an explicit check.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch on the defensive coding. In practice this can't overflow — compact_topk_ids is sized [state_E] (typically 256-512) while flat_ids.numel() is at most 40 on the micro path (the cutover threshold). But the invariant should be explicit. Added an assertion in the next commit

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
benchmarks/routines/moe.py (1)

1790-1813: ⚠️ Potential issue | 🟡 Minor

Pass is_gated into the FP8 bandwidth model too.

TFLOPS now distinguishes gated vs non-gated activations, but both FP8 bandwidth calls still rely on the default path. ReLU2 runs will therefore report inconsistent bandwidth.

🛠️ Suggested fix
     tb_per_sec = calculate_moe_kernel_bandwidth(
         num_tokens,
         hidden_size,
         intermediate_size,
         num_experts,
         top_k,
         median_time,
         input_dtype,
         weight_dtype,
         input_format="fp8",
         weight_format="fp8",
         routing_logits_dtype=routing_logits.dtype,
         active_experts=int(selected_experts.unique().numel()),
         verbose=args.verbose,
+        is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu),
     )
     tb_per_sec = calculate_moe_kernel_bandwidth(
         num_tokens,
         hidden_size,
         intermediate_size,
         num_experts,
         top_k,
         median_time,
         input_dtype,
         weight_dtype,
         input_format="fp8",
         weight_format="fp8",
         routing_logits_dtype=routing_logits.dtype,
         active_experts=int(selected_experts.unique().numel()),
         verbose=args.verbose,
+        is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu),
     )

Also applies to: 2025-2048

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/moe.py` around lines 1790 - 1813, The FP8 bandwidth call
is missing the is_gated flag, causing gated vs non-gated activations to report
inconsistent bandwidth; update the calculate_moe_kernel_bandwidth invocation(s)
to pass the same is_gated boolean used for calculate_moe_tflops (e.g.,
is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu))
so both calculate_moe_tflops(...) and calculate_moe_kernel_bandwidth(...)
receive the same gating hint (apply the same change to the other occurrences
around the later block that mirrors this code).
flashinfer/fused_moe/cute_dsl/fused_moe.py (1)

362-394: ⚠️ Potential issue | 🟠 Major

Reject relu2 outside the SM120/SM121 path.

These new public parameters are only honored in the SM120 branch. The fallback path still goes through _moe_core_impl(), which hard-wires the SwiGLU fusion helper, so activation_type="relu2" on SM100/SM103 can run the wrong math or hit mismatched FC1 shapes instead of failing fast.

🛠️ Suggested guard
@@
-        self.activation_type = activation_type
+        if activation_type not in {"silu", "relu2"}:
+            raise ValueError(f"Unsupported activation_type: {activation_type!r}")
+        self.activation_type = activation_type
@@
         major, minor = torch.cuda.get_device_capability(device)
         self._is_sm120 = major == 12
+        if activation_type != "silu" and not self._is_sm120:
+            raise ValueError(
+                "activation_type='relu2' is only supported on SM120/SM121"
+            )
 def cute_dsl_fused_moe_nvfp4(
@@
-    if num_local_experts is None:
+    if activation_type not in {"silu", "relu2"}:
+        raise ValueError(f"Unsupported activation_type: {activation_type!r}")
+
+    if num_local_experts is None:
         num_local_experts = num_experts
@@
     major, _ = torch.cuda.get_device_capability(x.device)
     if major == 12:
         ...
+    elif activation_type != "silu":
+        raise ValueError(
+            "activation_type='relu2' is only supported on SM120/SM121"
+        )

Also applies to: 827-916

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/fused_moe.py` around lines 362 - 394, The
constructor accepts activation_type but the non-SM120/SM121 path still calls
_moe_core_impl which assumes SwiGLU; add a runtime guard in the initializer (or
immediately before dispatch to _moe_core_impl) that checks activation_type and
the detected GPU SM version and either raise a clear error or restrict allowed
values when SM < 120 (e.g., if activation_type == "relu2" and not on SM120/121,
raise ValueError). Update the dispatch code path that calls _moe_core_impl (and
any fallback branches referenced around the alternate implementation) to enforce
this same check so relu2 is only honored on the SM120/SM121 branch and cannot
silently run with the SwiGLU helper.
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py (1)

138-185: ⚠️ Potential issue | 🔴 Critical

Size compact_topk_ids for routed rows, not experts.

Line 824 slices this buffer to flat_ids.numel() (num_tokens * top_k), but the workspace only allocates state_E entries. Any micro launch with more routed rows than local experts will write past the end of the buffer.

🛠️ Suggested fix
-    compact_topk_ids: torch.Tensor  # [state_E] int32, for micro kernel pre-pass
+    compact_topk_ids: torch.Tensor  # [max_rows] int32, for micro kernel pre-pass
@@
-        compact_topk_ids=torch.empty(state_E, dtype=torch.int32, device=device),
+        compact_topk_ids=torch.empty(max_rows, dtype=torch.int32, device=device),
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py` around lines
138 - 185, compact_topk_ids is currently allocated with length state_E but is
later sliced to flat_ids.numel() (num_tokens * top_k) in the micro-kernel
pre-pass, causing out-of-bounds writes when num_tokens > state_E; in
allocate_sm120_static_workspace change the compact_topk_ids allocation in
Sm120StaticMoEWorkspace to have capacity for the worst-case routed rows times
top-k (e.g. torch.empty(state_E * max_rows * num_topk, dtype=torch.int32,
device=device) or at minimum torch.empty(max_rows * state_E * num_topk, ...)) so
flat_ids.numel() can always fit, and keep references to compact_topk_ids,
allocate_sm120_static_workspace, Sm120StaticMoEWorkspace, num_topk, max_rows,
and state_E to locate the change.
🧹 Nitpick comments (1)
flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py (1)

76-86: Add a defensive size guard for micro-only usage.

This kernel is O(BLOCK²) in a single program; adding an explicit upper bound makes accidental large launches fail fast with a clear message.

Possible guardrail
     block = triton.next_power_of_2(total_pairs)
+    if block > 256:
+        raise ValueError(
+            f"compact_topk_ids is intended for micro batches; got total_pairs={total_pairs}"
+        )
     num_warps = 1 if block <= 16 else 2
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py` around lines
76 - 86, Add a defensive size guard before launching _compact_topk_ids_kernel to
prevent accidental large O(BLOCK²) launches: compute block =
triton.next_power_of_2(total_pairs) as you do, then check against a small hard
limit (e.g. MAX_BLOCK = 64 or 128) and/or a MAX_PAIRS derived limit and raise a
clear RuntimeError if block > MAX_BLOCK (include block and total_pairs in the
message). Keep the existing num_warps logic and kernel args
(_compact_topk_ids_kernel, topk_ids, compact_topk_ids, weight_expert_ids,
active_expert_count, total_pairs, BLOCK=block, num_warps=num_warps) unchanged;
just insert the guard using the same symbols so oversized launches fail fast
with a descriptive error.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/routines/moe.py`:
- Around line 1367-1371: The code silently maps unsupported ActivationType
values to "silu"; change the logic to validate args.activation_type against the
supported mapping instead of defaulting. Use the _ACT_STR dict to look up
activation_str and if the activation_type is not present raise a clear exception
(e.g., ValueError) mentioning the unsupported ActivationType and listing
supported keys; also compute is_gated from ActivationType.Geglu and
ActivationType.Swiglu as before but ensure Geglu is rejected if not in _ACT_STR
so it cannot silently run the SiLU kernel.

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py`:
- Around line 539-541: Comments describing w13 tile ordering conflict with
actual code: update the inline comments around the w13 TMA descriptor creation
(the lines that mention "Gate tiles at N=..." and "up tiles at N=...") to
reflect the actual ordering used by the code path where gate_slice_idx =
intermediate_slice + gate_tile_cnt (i.e., up tiles occupy the first half of
N-tiles and gate tiles the second half). Locate the call to
self._dense_cls._make_tma_atoms_and_tensors that produces tma_b_w13/gB_w13 and
any other similar comments (also around the other occurrence near lines
~1168-1171) and change the wording so it states "Up tiles at N=0..I_tp/tile_N-1,
Gate tiles at N=I_tp/tile_N..2*I_tp/tile_N-1" or equivalent that matches the
gate_slice_idx logic.

In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 1093-1096: Replace the custom skip gating that defines
sm120_cuda13 (which currently uses is_sm120_family() and _has_cuda_13()) with
the repository-standard capability checks from flashinfer.utils or the API
capability method; specifically, remove is_sm120_family()/_has_cuda_13() and use
the appropriate flashinfer.utils helper (e.g., is_sm120_supported() or analogous
is_sm90a_supported()/is_sm100a_supported()) or call
api_name.is_compute_capability_supported(cc) to decide the skip. Ensure the new
marker still uses pytest.mark.skipif(...) with a descriptive reason string
indicating the required SM/CUDA capability.

---

Outside diff comments:
In `@benchmarks/routines/moe.py`:
- Around line 1790-1813: The FP8 bandwidth call is missing the is_gated flag,
causing gated vs non-gated activations to report inconsistent bandwidth; update
the calculate_moe_kernel_bandwidth invocation(s) to pass the same is_gated
boolean used for calculate_moe_tflops (e.g., is_gated=args.activation_type in
(ActivationType.Swiglu, ActivationType.Geglu)) so both calculate_moe_tflops(...)
and calculate_moe_kernel_bandwidth(...) receive the same gating hint (apply the
same change to the other occurrences around the later block that mirrors this
code).

In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py`:
- Around line 138-185: compact_topk_ids is currently allocated with length
state_E but is later sliced to flat_ids.numel() (num_tokens * top_k) in the
micro-kernel pre-pass, causing out-of-bounds writes when num_tokens > state_E;
in allocate_sm120_static_workspace change the compact_topk_ids allocation in
Sm120StaticMoEWorkspace to have capacity for the worst-case routed rows times
top-k (e.g. torch.empty(state_E * max_rows * num_topk, dtype=torch.int32,
device=device) or at minimum torch.empty(max_rows * state_E * num_topk, ...)) so
flat_ids.numel() can always fit, and keep references to compact_topk_ids,
allocate_sm120_static_workspace, Sm120StaticMoEWorkspace, num_topk, max_rows,
and state_E to locate the change.

In `@flashinfer/fused_moe/cute_dsl/fused_moe.py`:
- Around line 362-394: The constructor accepts activation_type but the
non-SM120/SM121 path still calls _moe_core_impl which assumes SwiGLU; add a
runtime guard in the initializer (or immediately before dispatch to
_moe_core_impl) that checks activation_type and the detected GPU SM version and
either raise a clear error or restrict allowed values when SM < 120 (e.g., if
activation_type == "relu2" and not on SM120/121, raise ValueError). Update the
dispatch code path that calls _moe_core_impl (and any fallback branches
referenced around the alternate implementation) to enforce this same check so
relu2 is only honored on the SM120/SM121 branch and cannot silently run with the
SwiGLU helper.

---

Nitpick comments:
In `@flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py`:
- Around line 76-86: Add a defensive size guard before launching
_compact_topk_ids_kernel to prevent accidental large O(BLOCK²) launches: compute
block = triton.next_power_of_2(total_pairs) as you do, then check against a
small hard limit (e.g. MAX_BLOCK = 64 or 128) and/or a MAX_PAIRS derived limit
and raise a clear RuntimeError if block > MAX_BLOCK (include block and
total_pairs in the message). Keep the existing num_warps logic and kernel args
(_compact_topk_ids_kernel, topk_ids, compact_topk_ids, weight_expert_ids,
active_expert_count, total_pairs, BLOCK=block, num_warps=num_warps) unchanged;
just insert the guard using the same symbols so oversized launches fail fast
with a descriptive error.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5e3f31a2-4325-47e8-9db8-0acd0b10bef6

📥 Commits

Reviewing files that changed from the base of the PR and between 25b324d and ea4ad45.

📒 Files selected for processing (11)
  • benchmarks/routines/moe.py
  • benchmarks/routines/moe_utils.py
  • flashinfer/cute_dsl/fp4_common.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/__init__.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py
  • flashinfer/fused_moe/cute_dsl/fused_moe.py
  • tests/moe/test_cute_dsl_fused_moe.py

Comment thread benchmarks/routines/moe.py Outdated
Comment thread flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
Comment thread tests/moe/test_cute_dsl_fused_moe.py Outdated
@bkryu bkryu added the v0.6.9 release blocker label for 0.6.9 label Apr 15, 2026
@bkryu

bkryu commented Apr 15, 2026

Copy link
Copy Markdown
Collaborator Author

/bot run

@flashinfer-bot

Copy link
Copy Markdown
Collaborator

GitLab MR !556 has been created, and the CI pipeline #48636109 is currently running. I'll report back once the pipeline job completes.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/routines/moe.py`:
- Around line 2030-2037: The bandwidth computation is still using the default
gated accounting while TFLOPS now selects non-gated accounting for some
activations; update the call to calculate_moe_kernel_bandwidth to pass the same
is_gated boolean used for calculate_moe_tflops (i.e.,
is_gated=args.activation_type in (ActivationType.Swiglu, ActivationType.Geglu)
or the expression that produced non-gated for Relu2) so both TFLOPS and kernel
bandwidth use the same activation-aware gating flag (refer to
calculate_moe_tflops and calculate_moe_kernel_bandwidth to locate the calls).
- Around line 1795-1802: The TFLOPS call uses args.activation_type to set
is_gated but this routine still constructs gated FC1 tensors and
run_fp8_block_moe never receives an activation flag, so reported TFLOPS can
diverge from the executed kernel; fix by not switching the is_gated flag based
on args.activation_type here — either hard-code is_gated=True (gated-only path)
or derive is_gated from the same gated-only indicator used when building tensors
(e.g., the 2 * intermediate_size gated FC1 logic) and/or update
run_fp8_block_moe to accept and forward an activation_type so activation-based
toggles are consistent with calculate_moe_tflops; reference
calculate_moe_tflops, run_fp8_block_moe, args.activation_type and the gated FC1
construction (2 * intermediate_size) when making the change.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 43a02c6d-a59b-416a-b805-d1e691a5397e

📥 Commits

Reviewing files that changed from the base of the PR and between ea4ad45 and 58b1168.

📒 Files selected for processing (6)
  • benchmarks/routines/moe.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dynamic_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_micro_kernel.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_static_kernel.py
  • tests/moe/test_cute_dsl_fused_moe.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/moe/test_cute_dsl_fused_moe.py
  • flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py

Comment thread benchmarks/routines/moe.py
Comment thread benchmarks/routines/moe.py
@bkryu bkryu changed the title feat: Add micro kernel + ReLU2 activation for SM120 b12x fused MoE feat: Add b12x_fused_moe / B12xMoEWrapper SM120 APIs with micro kernel and ReLU2 Apr 16, 2026
@bkryu

bkryu commented Apr 16, 2026

Copy link
Copy Markdown
Collaborator Author

/bot stop

@bkryu

bkryu commented Apr 16, 2026

Copy link
Copy Markdown
Collaborator Author

/bot run

@flashinfer-bot

Copy link
Copy Markdown
Collaborator

The GitLab CI pipeline #48713073 has been cancelled.

@flashinfer-bot

Copy link
Copy Markdown
Collaborator

GitLab MR !556 has been updated with latest changes, and the CI pipeline #48714380 is currently running. I'll report back once the pipeline job completes.

@bkryu

bkryu commented Apr 16, 2026

Copy link
Copy Markdown
Collaborator Author

/bot stop

@flashinfer-bot

Copy link
Copy Markdown
Collaborator

The GitLab CI pipeline #48714380 has been cancelled.

@bkryu

bkryu commented Apr 16, 2026

Copy link
Copy Markdown
Collaborator Author

/bot run

@flashinfer-bot

Copy link
Copy Markdown
Collaborator

GitLab MR !556 has been updated with latest changes, and the CI pipeline #48739189 is currently running. I'll report back once the pipeline job completes.

@@ -0,0 +1,1334 @@
"""
Copyright (c) 2025 by FlashInfer team.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2026?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will try to capture in a subsequent PR that also addresses non-blocking mm_fp4(backend='b12x') comments to avoid needing to re-trigger the CI again 😅

@aleozlx

aleozlx commented Apr 18, 2026

Copy link
Copy Markdown
Collaborator

@flashinfer-bot run

@bkryu

bkryu commented Apr 20, 2026

Copy link
Copy Markdown
Collaborator Author

/bot run

@flashinfer-bot

Copy link
Copy Markdown
Collaborator

GitLab MR !556 has been updated with latest changes, and the CI pipeline #49005075 is currently running. I'll report back once the pipeline job completes.

@nv-yunzheq nv-yunzheq left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@bkryu bkryu merged commit 8a9970b into flashinfer-ai:main Apr 20, 2026
70 of 84 checks passed
@flashinfer-bot

Copy link
Copy Markdown
Collaborator

[FAILED] Pipeline #49005075: 15/20 passed

aleozlx pushed a commit that referenced this pull request Apr 24, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

Follow-up to #3051 (`backend="b12x"` for `mm_fp4` on SM120) and #3080
(`b12x_fused_moe` / `B12xMoEWrapper` SM120 APIs) addressing four
reviewer comments that landed after merge. No public API changes; no
kernel behavior changes.

- **Copyright**: bump `tests/moe/test_b12x_fused_moe.py` to 2026.
- **Benchmark split**: new `b12x_fused_moe` routine (SM120/121, BF16
input, SwiGLU + ReLU²); `cute_dsl_fp4_block_scale_moe` is now
SM100/103-only. Aligns with the `B12xMoEWrapper` / `CuteDslMoEWrapper`
Python API split.
- **Cache SM count**: replace a hot-path
`torch.cuda.get_device_properties(...).multi_processor_count` in the
`b12x` FP4 GEMM runner with the cached `get_device_sm_count()` helper.
- **Rename for provenance**: `dense_blockscaled_gemm_sm120.py` →
`dense_blockscaled_gemm_sm120_b12x.py` and
`Sm120BlockScaledDenseGemmKernel` →
`Sm120B12xBlockScaledDenseGemmKernel` (via `git mv`, 6 import sites
updated). `backend="b12x"` string unchanged.


## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **New Features**
* Added new `b12x_fused_moe` benchmark routine for NVFP4 MoE inference
with support for both SwiGLU and ReLU2 activation types.
* Extended Blackwell architecture support with updated kernel
implementations.

* **Documentation**
* Updated benchmark samples with new `b12x_fused_moe` test
configurations.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
@coderabbitai coderabbitai Bot mentioned this pull request May 15, 2026
4 tasks
@bkryu bkryu deleted the b12x_micro_kernel branch June 8, 2026 17:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

op: moe run-ci v0.6.9 release blocker label for 0.6.9

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants