
Add dedicated FlashInferCuteDslMoE layer for standard-path FP4 MoE #21339

Merged
ch-wan merged 34 commits into sgl-project:main from leejnau:integrate_flashinfer_cutedsl_moe on Apr 10, 2026

Conversation

@leejnau (Collaborator) commented Mar 24, 2026

Motivation

We want a more standard --moe-runner-backend flashinfer_cutedsl option that is not tied to DeepEP. This PR integrates the wrapper API exposed in flashinfer-ai/flashinfer#2398.

Modifications

Server

  • Add flashinfer_cutedsl as a modular moe_runner backend for --moe-runner-backend flashinfer_cutedsl with modelopt_fp4 quantization on the standard path (moe_a2a_backend=none). This implements the moe_runner pattern established by flashinfer_trtllm.py: a CuteDslFp4MoeQuantInfo dataclass carries the weights, scales, and wrapper, and a function registered via @register_fused_func("none", "flashinfer_cutedsl") quantizes activations to FP4 and calls CuteDslMoEWrapper.run(), aligning with the MoE refactor roadmap ([Roadmap] MoE Refactor #8715). A sketch of this pattern follows this list.
  • Route through the generic FusedMoE -> StandardDispatcher -> MoeRunner pipeline instead of a standalone layer class, unblocking future moe_a2a_backend=flashinfer support.
  • Support EP=1 (TP-sharded experts) and EP=TP (partitioned experts with all-reduce) configurations. CuteDSL's moe_sort handles EP with global expert IDs internally, so skip_local_expert_mapping is enabled.
  • Enable FlashInfer autotuning for CuteDSL and switch the autotune warmup from torch.inference_mode() to torch.no_grad(), since CuteDSL lazily allocates persistent CUDA graph buffers during the first forward pass (see the inference_mode/no_grad example after this list).
  • Add FP4 weight preprocessing: post-quant W1 gate/up interleave (sketched below) and swizzled-to-MMA block-scale conversion during process_weights_after_loading.
  • Add scale-resolution logic (_resolve_cutedsl_standard_scales) that derives correct per-expert GEMM alphas from scalarized activation scales, handling EP slicing and multiple checkpoint formats (a hedged sketch follows this list).
  • Preserve the existing DeepEP low-latency masked CuteDSL route (unchanged, already on main).
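
To make the registration bullet concrete, here is a minimal sketch of the pattern. Only register_fused_func, CuteDslFp4MoeQuantInfo, and CuteDslMoEWrapper.run() are named in this PR; the field names, signatures, and the quantize_fp4 helper are illustrative assumptions, not SGLang's actual API.

```python
# Hypothetical sketch of the moe_runner registration pattern; field names,
# signatures, and the quantize_fp4 helper are assumptions for illustration.
from dataclasses import dataclass

import torch


@dataclass
class CuteDslFp4MoeQuantInfo:
    w13_weight: torch.Tensor      # FP4-packed gate/up weights, per expert
    w2_weight: torch.Tensor       # FP4-packed down-projection weights
    w13_blockscale: torch.Tensor  # block scales, converted to MMA layout at load time
    w2_blockscale: torch.Tensor
    wrapper: object               # the CuteDslMoEWrapper instance


@register_fused_func("none", "flashinfer_cutedsl")  # standard path + CuteDSL runner
def fused_experts_cutedsl(dispatch_output, quant_info, runner_config):
    # Quantize activations to FP4, then delegate to the CuteDSL wrapper, whose
    # moe_sort handles EP with global expert ids (hence skip_local_expert_mapping).
    x_fp4, x_scale = quantize_fp4(dispatch_output.hidden_states)  # assumed helper
    return quant_info.wrapper.run(
        x_fp4,
        x_scale,
        dispatch_output.topk_ids,
        dispatch_output.topk_weights,
    )
```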
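For the weight-preprocessing bullet, the gate/up interleave can be pictured as below. This is a hedged sketch assuming row-granularity interleaving on unpacked tensors; the real pass operates on FP4-packed weights and may use a different granularity.

```python
import torch


def interleave_gate_up(w13: torch.Tensor) -> torch.Tensor:
    """Reorder fused [gate; up] expert weights into alternating rows
    [g0, u0, g1, u1, ...]. Illustrative only; granularity is an assumption."""
    num_experts, two_n, k = w13.shape
    n = two_n // 2
    gate, up = w13[:, :n, :], w13[:, n:, :]
    # Stack along a new axis, then flatten it so gate/up rows alternate.
    return torch.stack((gate, up), dim=2).reshape(num_experts, two_n, k)
```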
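The scale-resolution bullet boils down to combining scalarized (tensor-level) activation scales with per-expert weight global scales into GEMM alphas, then slicing to the local experts under EP. The combination below is a placeholder assumption, not the exact formula in _resolve_cutedsl_standard_scales:

```python
import torch


def resolve_alphas(
    act_global_scale: torch.Tensor,  # scalar or [num_experts], checkpoint-dependent
    w_global_scale: torch.Tensor,    # [num_experts]
    ep_rank: int,
    num_local_experts: int,
) -> torch.Tensor:
    # Placeholder combination; the real logic also reconciles multiple
    # checkpoint formats before forming per-expert alphas.
    alphas = act_global_scale * w_global_scale
    start = ep_rank * num_local_experts  # EP slicing: keep only local experts
    return alphas[start : start + num_local_experts].contiguous()
```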
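The warmup change reflects a general PyTorch rule rather than anything CuteDSL-specific: tensors created under torch.inference_mode() are inference tensors that cannot be mutated once the context exits, which would break persistent CUDA graph buffers allocated lazily during warmup and reused on later forwards. Tensors created under torch.no_grad() are ordinary:

```python
import torch

with torch.inference_mode():
    buf = torch.zeros(4)  # an "inference tensor"
try:
    buf.add_(1)           # in-place update outside InferenceMode is rejected
except RuntimeError as e:
    print(e)

with torch.no_grad():
    buf = torch.zeros(4)  # a normal tensor, merely created without grad tracking
buf.add_(1)               # fine on subsequent (e.g. graph-replay) passes
```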

Tests

  • test/registered/moe/test_cutedsl_moe.py: unit tests for wrapper accuracy vs. PyTorch reference, CUDA graph parity, and EP-sharded all-reduce correctness.
  • test/registered/backends/test_deepseek_v3_fp4_cutedsl_moe.py: end-to-end GPQA accuracy on DeepSeek-V3 FP4 for both EP=1 and EP=4 configurations (nightly, 4 GPU B200).
  • Verify no regressions on existing flashinfer_trtllm and flashinfer_cutlass backends.

Accuracy Tests

server (EP1):

python3 -m sglang.launch_server --model-path nvidia/DeepSeek-R1-0528-NVFP4-v2 --tensor-parallel-size=8 --cuda-graph-max-bs 256 --max-running-requests 256 --mem-fraction-static 0.85 --ep-size 1 --moe-runner-backend flashinfer_cutedsl --quantization modelopt_fp4

server (EP8):

python3 -m sglang.launch_server --model-path nvidia/DeepSeek-R1-0528-NVFP4-v2 --tensor-parallel-size=8 --cuda-graph-max-bs 256 --max-running-requests 256 --mem-fraction-static 0.85 --ep-size 8 --moe-runner-backend flashinfer_cutedsl --quantization modelopt_fp4

client:

python3 -m sglang.test.run_eval --port 30000 --eval-name gpqa --num-examples 198 --max-tokens 128000 --repeat 8 --thinking-mode deepseek-v3

results (EP1):

'scores': ['0.813', '0.818', '0.803', '0.823', '0.773', '0.798', '0.828', '0.763'], 'mean_score': np.float64(0.8023989898989901)

results (EP8):

'scores': ['0.788', '0.818', '0.778', '0.803', '0.793', '0.793', '0.808', '0.808'], 'mean_score': np.float64(0.798611111111111)

Benchmarking and Profiling

Generally this PR outperforms --moe-runner-backend flashinfer_cutlass but not --moe-runner-backend flashinfer_trtllm.

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.


github-actions bot added the quant, LLM Quantization, and deepseek labels on Mar 24, 2026
leejnau added 5 commits March 25, 2026 14:30
…oe_runner

Dissolve FlashInferCuteDslMoE into the moe_runner pattern, aligning with the MoE refactor roadmap (sgl-project#8715). CuteDSL FP4 now flows through FusedMoE -> StandardDispatcher -> MoeRunner -> @register_fused_func, matching the flashinfer_trtllm integration and unblocking future A2A backend support.
@nvpohanh (Collaborator) commented Apr 7, 2026

@ch-wan could you review this? Thanks!

A review thread on this code:

    self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
    self.moe_runner_config = moe_runner_config
    if self.enable_flashinfer_cutedsl_moe:

Collaborator: Should we provide a default runner when it is auto (or not defined in server args)?

Collaborator (Author): Added FLASHINFER_TRTLLM as the default runner for auto: edf3e16
@ch-wan (Collaborator) commented Apr 8, 2026

/tag-and-rerun-ci

github-actions bot added the run-ci label on Apr 8, 2026
@samuellees (Contributor) commented Apr 9, 2026

/tag-and-rerun-ci (just move test forward, thx~)

@samuellees (Contributor) commented:

Hi @leejnau, could you please fix the failing CI case? thx! https://github.com/sgl-project/sglang/actions/runs/24141598046/job/70546144580?pr=21339#step:7:2420

@ch-wan ch-wan merged commit c554dc5 into sgl-project:main Apr 10, 2026
340 of 378 checks passed
@leejnau leejnau deleted the integrate_flashinfer_cutedsl_moe branch April 20, 2026 16:35