Enable B12x backend for non-gated MoEs (like Nemotron) #41244
Adds FlashInferCuteDSLSM12xExperts targeting SM120/SM121 (RTX Pro 6000 / DGX Spark) using cute_dsl_fused_moe_nvfp4 from FlashInfer PRs vllm-project#3051 and vllm-project#3066. The kernel fuses token dispatch, W1 GEMM, SwiGLU, and W2 GEMM into a single call; BF16 hidden states are passed directly as activation quantization is fused internally.
- vllm/utils/flashinfer.py: lazy import wrappers for cute_dsl_fused_moe_nvfp4 and convert_sf_to_mma_layout; adds has_flashinfer_cutedsl_sm12x_moe() availability probe
- experts/flashinfer_cutedsl_moe.py: FlashInferCuteDSLSM12xExperts with TODO to adopt plan/run() API from PR vllm-project#3066
- oracle/nvfp4.py: FLASHINFER_CUTEDSL_SM12X backend enum and routing; falls back to FLASHINFER_CUTLASS on SM12x when PRs are absent
- flashinfer_fp4_moe.py: SM12X added to FI weight-prep path and w1/w3 → w3/w1 reorder list
- tests/kernels/moe/test_cutedsl_sm12x_moe.py: correctness tests vs BF16 torch reference; module-level skip when SM120 hw or FlashInfer PRs are absent
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
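The "lazy import wrapper" and "availability probe" pattern mentioned for vllm/utils/flashinfer.py can be sketched roughly as follows. This is a minimal, generic illustration using only the standard library; the helper names has_module_attr and lazy_import_wrapper are hypothetical and are not the functions in vLLM, and the usage lines point at the math module purely so the sketch runs anywhere.

```python
import functools
import importlib
import importlib.util


@functools.cache
def has_module_attr(module_name: str, attr_name: str) -> bool:
    """Hypothetical probe: True if the module imports and exposes attr_name."""
    try:
        if importlib.util.find_spec(module_name) is None:
            return False
        module = importlib.import_module(module_name)
    except (ImportError, ValueError):
        # Missing parent package or a broken install counts as "unavailable".
        return False
    return hasattr(module, attr_name)


def lazy_import_wrapper(module_name: str, attr_name: str):
    """Hypothetical wrapper: defer the import until the first call."""

    @functools.cache
    def _resolve():
        return getattr(importlib.import_module(module_name), attr_name)

    def wrapper(*args, **kwargs):
        return _resolve()(*args, **kwargs)

    return wrapper


# Stand-in usage with a stdlib module instead of FlashInfer.
lazy_sqrt = lazy_import_wrapper("math", "sqrt")
if has_module_attr("math", "sqrt"):
    print(lazy_sqrt(9.0))
```

The point of splitting probe and wrapper is that backend selection can ask whether a kernel exists without paying the import cost, or failing, on machines where the optional dependency is missing.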
Integrates FlashInfer PR vllm-project#3051 b12x dense GEMM backend into the NVFP4 linear layer path. b12x uses CuTe DSL warp-level MMA with adaptive tile sizing to improve SM utilization on small-M decode shapes.
Changes:
- has_flashinfer_b12x_gemm(): availability check via Sm120BlockScaledDenseGemmKernel
- FlashInferB12xNvFp4LinearKernel: new NvFp4LinearKernel subclass
- Auto-selects b12x on SM120/SM121 (has_device_capability(120)), falls back to FLASHINFER_CUTLASS when unavailable
- Adds "flashinfer-b12x" to VLLM_NVFP4_GEMM_BACKEND valid choices
- b12x test cases in test_flashinfer_nvfp4_scaled_mm.py
Measured on DGX Spark (SM121, Qwen3-30B-A3B-NVFP4, same MoE backend):
b12x: 71.81 out tok/s (1P), 229.24 (8P)
flashinfer-cutlass: 70.52 out tok/s (1P), 216.28 (8P)
(+1.8% 1P, +6.0% 8P)
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
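The quoted percentages follow directly from the throughput figures above. A short check, with a hypothetical helper name relative_gain_percent:

```python
def relative_gain_percent(new: float, baseline: float) -> float:
    """Relative throughput gain of `new` over `baseline`, in percent."""
    return (new / baseline - 1.0) * 100.0


# (b12x, flashinfer-cutlass) output tokens/s, copied from the commit message.
results = {"1P": (71.81, 70.52), "8P": (229.24, 216.28)}
gains = {label: relative_gain_percent(new, base)
         for label, (new, base) in results.items()}

for label, gain in gains.items():
    print(f"{label}: {gain:+.1f}%")
# 1P: +1.8%
# 8P: +6.0%
```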
- Preserve a2_gscale; pass torch.ones_like(a2_gscale) to kernel instead of fill_(1.0) which destroyed the calibrated value in-place
- Precompute w1_sf_mma/w2_sf_mma in process_weights_after_loading instead of converting on every forward pass
- Fix x_sf_placeholder dtype: float8_e4m3fn (was bfloat16)
- Pass topk_weights.float() for float32 routing weights as kernel expects
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Align user-visible backend name with FlashInfer PR vllm-project#3080 which establishes b12x as the canonical namespace for SM12x MoE kernels.
Renames:
- MoEBackend Literal: "flashinfer_cutedsl_sm12x" -> "flashinfer_b12x"
- NvFp4MoeBackend enum: FLASHINFER_CUTEDSL_SM12X -> FLASHINFER_B12X
- vLLM helper: has_flashinfer_cutedsl_sm12x_moe -> has_flashinfer_b12x_moe
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
…lashInfer PR vllm-project#3080)
FlashInfer PR vllm-project#3080 split SM100 and SM120/121 MoE APIs:
- cute_dsl_fused_moe_nvfp4 is now SM100-only
- b12x_fused_moe is the new SM120/121 entry point
Update FlashInferCuteDSLSM12xExperts to call b12x_fused_moe instead of cute_dsl_fused_moe_nvfp4. The new API drops x_sf and local_expert_offset, uses output= instead of moe_output=, and takes topk_weights directly (no .float() cast). Also simplify process_weights_after_loading to fill a2_gscale in-place with 1.0 rather than creating a separate ones tensor. Add flashinfer_b12x_fused_moe lazy wrapper and update has_flashinfer_b12x_moe() to check b12x_fused_moe.
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
…12xExperts
- Remove get_cute_dtype, flashinfer_cutedsl_moe_masked, and associated unused imports from flashinfer_cutedsl_moe.py. They were copy-pasted from flashinfer_cutedsl_batched_moe.py during an earlier draft that used the masked kernel path; the final SM12x implementation calls b12x_fused_moe directly, so they were never used.
- Rename FlashInferCuteDSLSM12xExperts -> FlashInferB12xExperts to align with the FlashInferB12xNvFp4LinearKernel naming convention.
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Consistent with the one-class-per-file convention in fused_moe/experts/. flashinfer_cutedsl_moe.py now contains only FlashInferCuteDSLExperts (SM100). Also renames the test to test_flashinfer_b12x_moe.py to match.
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
…12x_moe
process_weights_after_loading computes MMA-layout block scales after normalizing. The test was constructing FlashInferB12xExperts without setting these, so the kernel received uninitialized scale buffers.
Fix: compute w1_sf_mma / w2_sf_mma via flashinfer_convert_sf_to_mma_layout directly from the test's pre-normalized scales before passing the experts object to FusedMoEKernel.
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com> Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com> Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
FLASHINFER_CUTEDSL is routed to prepare_nvfp4_moe_layer_for_flashinfer_cutedsl before the elif that calls prepare_nvfp4_moe_layer_for_fi_or_cutlass, so it can never reach this assert. Added inadvertently in the b12x commit.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
CUTLASS SM121 MMA op guard causes errors loading NVFP4 models on DGX Spark when b12x is auto-selected. Remove FLASHINFER_B12X from AVAILABLE_BACKENDS in the MoE oracle and FlashInferB12xNvFp4LinearKernel from _POSSIBLE_NVFP4_KERNELS. Both remain reachable via explicit opt-in (moe_backend="flashinfer_b12x" and VLLM_NVFP4_GEMM_BACKEND=flashinfer-b12x).
Also restore pre-existing lazy imports in utils/flashinfer.py to their original positions and module paths; only flashinfer_b12x_fused_moe is new.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
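The distinction between auto-selection and explicit opt-in can be illustrated with a small sketch. It assumes only the environment variable name and the two backend strings that appear in this PR; the function select_gemm_backend, the two tuples, and the is_available callback are hypothetical and do not mirror vLLM's actual selection code.

```python
from collections.abc import Callable, Mapping

# Hypothetical split: backends tried automatically vs. reachable only on request.
AUTO_CANDIDATES = ("flashinfer-cutlass",)
OPT_IN_ONLY = ("flashinfer-b12x",)


def select_gemm_backend(
    env: Mapping[str, str],
    is_available: Callable[[str], bool],
) -> str:
    """Pick a backend: an explicit request wins, otherwise try auto candidates."""
    requested = env.get("VLLM_NVFP4_GEMM_BACKEND")
    if requested is not None:
        if requested not in AUTO_CANDIDATES + OPT_IN_ONLY:
            raise ValueError(f"unknown backend: {requested!r}")
        if not is_available(requested):
            raise RuntimeError(f"backend {requested!r} is not available")
        return requested
    # Opt-in-only backends are never considered on this path.
    for name in AUTO_CANDIDATES:
        if is_available(name):
            return name
    raise RuntimeError("no NVFP4 GEMM backend available")


print(select_gemm_backend({}, lambda name: True))
print(select_gemm_backend({"VLLM_NVFP4_GEMM_BACKEND": "flashinfer-b12x"},
                          lambda name: True))
```

Keeping a backend out of the automatic list means a problem specific to that backend cannot break default model loading, while users who want it can still request it by name.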
Code Review
This pull request adds support for FlashInfer B12x kernels (SM120+) for NVFP4 quantized linear layers and fused MoE, including the implementation of FlashInferB12xNvFp4LinearKernel and FlashInferB12xExperts. It also updates kernel configurations and adds unit tests. Feedback suggests removing redundant assertions in the MoE apply method to avoid performance overhead in the hot path.
assert self.w1_scale is not None and self.w2_scale is not None, (
    "w1_scale and w2_scale must not be None for FlashInferB12xExperts"
)
assert self.g1_alphas is not None and self.g2_alphas is not None, (
    "g1_alphas and g2_alphas must not be None for FlashInferB12xExperts"
)
assert self.a2_gscale is not None, (
    "a2_gscale must not be None for FlashInferB12xExperts"
)
The assertion checks for w1_scale, w2_scale, g1_alphas, g2_alphas, and a2_gscale are redundant because these attributes are expected to be initialized in the base class or during the __init__ phase. If they are missing, the kernel will fail later anyway. More importantly, these assertions are executed on every apply call, which is in the hot path and can impact performance.
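One way to keep the safety check without paying for it on every call is to validate once at construction time. The sketch below is illustrative only: ExpertsSketch is a hypothetical stand-in, not FlashInferB12xExperts, its apply body is a placeholder rather than the real kernel call, and the review comment itself proposes simply removing the assertions.

```python
class ExpertsSketch:
    """Hypothetical stand-in showing one-time validation outside the hot path."""

    _REQUIRED = ("w1_scale", "w2_scale", "g1_alphas", "g2_alphas", "a2_gscale")

    def __init__(self, **scales):
        for name in self._REQUIRED:
            setattr(self, name, scales.get(name))
        # Validate once here instead of asserting on every apply() call.
        missing = [n for n in self._REQUIRED if getattr(self, n) is None]
        if missing:
            raise ValueError(f"missing required scales: {', '.join(missing)}")

    def apply(self, x: float) -> float:
        # Placeholder computation; no per-call None checks are needed here.
        return x * self.a2_gscale


experts = ExpertsSketch(w1_scale=1.0, w2_scale=1.0,
                        g1_alphas=1.0, g2_alphas=1.0, a2_gscale=0.5)
print(experts.apply(4.0))  # 2.0
```

Raising ValueError rather than using assert also means the check survives when Python runs with -O, which strips assert statements.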
Summary
Stacked on top of #40082.
This PR refines the FlashInfer B12x MoE integration by switching the SM12x MoE path to FlashInfer's B12xMoEWrapper API and adding ReLU2 / non-gated MoE coverage.
Key changes:
- B12xMoEWrapper for FlashInferB12xExperts
- is_act_and_mul
Duplicate-work check
This is intentionally not a duplicate of #40082: it is an incremental stacked change on top of #40082.