Skip to content

Enable B12x backend for non-gated MoEs (like Nemotron) #41244

Closed
askliar wants to merge 16 commits into
vllm-project:mainfrom
askliar:askliar/b12x-wrapper-on-pr40082
Closed

Enable B12x backend for non-gated MoEs (like Nemotron) #41244
askliar wants to merge 16 commits into
vllm-project:mainfrom
askliar:askliar/b12x-wrapper-on-pr40082

Conversation

@askliar

@askliar askliar commented Apr 29, 2026

Copy link
Copy Markdown
Contributor

Summary

Stacked on top of #40082.

This PR refines the FlashInfer B12x MoE integration by switching the SM12x MoE path to FlashInfer's B12xMoEWrapper API and adding ReLU2 / non-gated MoE coverage.

Key changes:

  • Use B12xMoEWrapper for FlashInferB12xExperts
  • Keep BF16 hidden states as unquantized inputs; B12x handles FP4 activation quantization internally
  • Support both SiLU gated MoE and ReLU2 non-gated MoE
  • Add ReLU2 test coverage
  • Update the test helper to allow non-default activation and is_act_and_mul

Duplicate-work check

This is intentionally not a duplicate of #40082: it is an incremental stacked change on top of #40082.

meena-at-work and others added 15 commits April 27, 2026 21:37
Adds FlashInferCuteDSLSM12xExperts targeting SM120/SM121 (RTX Pro
6000 / DGX Spark) using cute_dsl_fused_moe_nvfp4 from FlashInfer
PRs vllm-project#3051 and vllm-project#3066. The kernel fuses token dispatch, W1 GEMM, SwiGLU,
and W2 GEMM into a single call; BF16 hidden states are passed directly
as activation quantization is fused internally.

- vllm/utils/flashinfer.py: lazy import wrappers for
  cute_dsl_fused_moe_nvfp4 and convert_sf_to_mma_layout; adds
  has_flashinfer_cutedsl_sm12x_moe() availability probe
- experts/flashinfer_cutedsl_moe.py: FlashInferCuteDSLSM12xExperts
  with TODO to adopt plan/run() API from PR vllm-project#3066
- oracle/nvfp4.py: FLASHINFER_CUTEDSL_SM12X backend enum and routing;
  falls back to FLASHINFER_CUTLASS on SM12x when PRs are absent
- flashinfer_fp4_moe.py: SM12X added to FI weight-prep path and
  w1/w3 → w3/w1 reorder list
- tests/kernels/moe/test_cutedsl_sm12x_moe.py: correctness tests vs
  BF16 torch reference; module-level skip when SM120 hw or FlashInfer
  PRs are absent

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Integrates FlashInfer PR vllm-project#3051 b12x dense GEMM backend into the NVFP4
linear layer path. b12x uses CuTe DSL warp-level MMA with adaptive tile
sizing to improve SM utilization on small-M decode shapes.

Changes:
- has_flashinfer_b12x_gemm(): availability check via Sm120BlockScaledDenseGemmKernel
- FlashInferB12xNvFp4LinearKernel: new NvFp4LinearKernel subclass
- Auto-selects b12x on SM120/SM121 (has_device_capability(120)), falls
  back to FLASHINFER_CUTLASS when unavailable
- Adds "flashinfer-b12x" to VLLM_NVFP4_GEMM_BACKEND valid choices
- b12x test cases in test_flashinfer_nvfp4_scaled_mm.py

Measured on DGX Spark (SM121, Qwen3-30B-A3B-NVFP4, same MoE backend):
  b12x:               71.81 out tok/s (1P), 229.24 (8P)
  flashinfer-cutlass:  70.52 out tok/s (1P), 216.28 (8P)
  (+1.8% 1P, +6.0% 8P)

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Preserve a2_gscale; pass torch.ones_like(a2_gscale) to kernel instead
  of fill_(1.0) which destroyed the calibrated value in-place
- Precompute w1_sf_mma/w2_sf_mma in process_weights_after_loading
  instead of converting on every forward pass
- Fix x_sf_placeholder dtype: float8_e4m3fn (was bfloat16)
- Pass topk_weights.float() for float32 routing weights as kernel expects

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Align user-visible backend name with FlashInfer PR vllm-project#3080 which establishes
b12x as the canonical namespace for SM12x MoE kernels. Renames:
  - MoEBackend Literal: "flashinfer_cutedsl_sm12x" -> "flashinfer_b12x"
  - NvFp4MoeBackend enum: FLASHINFER_CUTEDSL_SM12X -> FLASHINFER_B12X
  - vLLM helper: has_flashinfer_cutedsl_sm12x_moe -> has_flashinfer_b12x_moe

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
…lashInfer PR vllm-project#3080)

FlashInfer PR vllm-project#3080 split SM100 and SM120/121 MoE APIs:
  - cute_dsl_fused_moe_nvfp4 is now SM100-only
  - b12x_fused_moe is the new SM120/121 entry point

Update FlashInferCuteDSLSM12xExperts to call b12x_fused_moe instead of
cute_dsl_fused_moe_nvfp4. The new API drops x_sf and local_expert_offset,
uses output= instead of moe_output=, and takes topk_weights directly
(no .float() cast). Also simplify process_weights_after_loading to
fill a2_gscale in-place with 1.0 rather than creating a separate ones
tensor.

Add flashinfer_b12x_fused_moe lazy wrapper and update
has_flashinfer_b12x_moe() to check b12x_fused_moe.

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
…12xExperts

- Remove get_cute_dtype, flashinfer_cutedsl_moe_masked, and associated
  unused imports from flashinfer_cutedsl_moe.py — copy-pasted from
  flashinfer_cutedsl_batched_moe.py during an earlier draft that used
  the masked kernel path; the final SM12x implementation calls
  b12x_fused_moe directly so they were never used.
- Rename FlashInferCuteDSLSM12xExperts -> FlashInferB12xExperts to align
  with the FlashInferB12xNvFp4LinearKernel naming convention.

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Consistent with the one-class-per-file convention in fused_moe/experts/.
flashinfer_cutedsl_moe.py now contains only FlashInferCuteDSLExperts (SM100).
Also renames the test to test_flashinfer_b12x_moe.py to match.

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
…12x_moe

process_weights_after_loading computes MMA-layout block scales after
normalizing. The test was constructing FlashInferB12xExperts without
setting these, so the kernel received uninitialized scale buffers. Fix:
compute w1_sf_mma / w2_sf_mma via flashinfer_convert_sf_to_mma_layout
directly from the test's pre-normalized scales before passing the experts
object to FusedMoEKernel.

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
FLASHINFER_CUTEDSL is routed to prepare_nvfp4_moe_layer_for_flashinfer_cutedsl
before the elif that calls prepare_nvfp4_moe_layer_for_fi_or_cutlass, so it
can never reach this assert. Added inadvertently in the b12x commit.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
CUTLASS SM121 MMA op guard causes errors loading NVFP4 models on DGX Spark
when b12x is auto-selected. Remove FLASHINFER_B12X from AVAILABLE_BACKENDS
in the MoE oracle and FlashInferB12xNvFp4LinearKernel from _POSSIBLE_NVFP4_KERNELS.
Both remain reachable via explicit opt-in (moe_backend="flashinfer_b12x" and
VLLM_NVFP4_GEMM_BACKEND=flashinfer-b12x).

Also restore pre-existing lazy imports in utils/flashinfer.py to their
original positions and module paths; only flashinfer_b12x_fused_moe is new.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for FlashInfer B12x kernels (SM120+) for NVFP4 quantized linear layers and fused MoE, including the implementation of FlashInferB12xNvFp4LinearKernel and FlashInferB12xExperts. It also updates kernel configurations and adds unit tests. Feedback suggests removing redundant assertions in the MoE apply method to avoid performance overhead in the hot path.

Comment on lines +197 to +205
assert self.w1_scale is not None and self.w2_scale is not None, (
"w1_scale and w2_scale must not be None for FlashInferB12xExperts"
)
assert self.g1_alphas is not None and self.g2_alphas is not None, (
"g1_alphas and g2_alphas must not be None for FlashInferB12xExperts"
)
assert self.a2_gscale is not None, (
"a2_gscale must not be None for FlashInferB12xExperts"
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The assertion checks for w1_scale, w2_scale, g1_alphas, g2_alphas, and a2_gscale are redundant because these attributes are expected to be initialized in the base class or during the __init__ phase. If they are missing, the kernel will fail later anyway. More importantly, these assertions are executed on every apply call, which is in the hot path and can impact performance.

@askliar askliar changed the title Enable B12x backend for non-gated MoEs (like Nemotron) WIP: Enable B12x backend for non-gated MoEs (like Nemotron) Apr 29, 2026
@askliar askliar closed this May 21, 2026
@github-project-automation github-project-automation Bot moved this to Done in NVIDIA May 21, 2026
@mergify mergify Bot added ci/build deepseek Related to DeepSeek models frontend llama Related to Llama models multi-modality Related to multi-modality (#4194) mistral Related to Mistral models new-model Requests to new models performance Performance-related issues qwen Related to Qwen models gpt-oss Related to GPT-OSS models rocm Related to AMD ROCm intel-gpu Related to Intel GPU cpu Related to CPU backends structured-output speculative-decoding v1 labels May 21, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build cpu Related to CPU backends deepseek Related to DeepSeek models frontend gpt-oss Related to GPT-OSS models intel-gpu Related to Intel GPU kv-connector llama Related to Llama models mistral Related to Mistral models multi-modality Related to multi-modality (#4194) new-model Requests to new models nvidia performance Performance-related issues qwen Related to Qwen models rocm Related to AMD ROCm speculative-decoding structured-output tool-calling v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

4 participants