Skip to content

[Feat][NVFP4] Enable NVFP4 MoE for Qwen series models (eg. Qwen3-Next) #13761#13761

Merged
Fridge003 merged 13 commits intosgl-project:mainfrom
samuellees:trtllm-moe-nvfp4
Nov 27, 2025
Merged

[Feat][NVFP4] Enable NVFP4 MoE for Qwen series models (eg. Qwen3-Next) #13761#13761
Fridge003 merged 13 commits intosgl-project:mainfrom
samuellees:trtllm-moe-nvfp4

Conversation

@samuellees
Copy link
Copy Markdown
Contributor

PR Dependency

Motivation

Enable NVFP4 MoE for Qwen series models (eg. Qwen3-Next) on Blackwell GPUs

TODO

  • Add unit test
  • Add command for reproduce accuracy results
  • Refactor this PR (Co-work with @kaixih PR13556)

Accuracy Tests

# Qwen3-Next: NVFP4 linear + NVFP4 MoE + FP8 Attention
export SGL_ENABLE_JIT_DEEPGEMM=false
python3 -m sglang.launch_server --model-path qwen3-next-80b-a3b-instruct-nvfp4-all --chunked-prefill-size 16384 --max-prefill-tokens 16384 --max-running-requests 512 --tp-size 4 --ep-size 4 --mem-fraction-static 0.7 --cuda-graph-bs 1 2 4 8 16 32 64 128 256 512 1024 --disable-radix-cache --log-level info --host 0.0.0.0 --port 8001 --random-seed 0  --quantization modelopt_fp4   --kv-cache-dtype fp8_e4m3  --moe-runner-backend flashinfer_trtllm --attention-backend trtllm_mha  --mamba-ssm-dtype bfloat16
Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 8 exact_match 0.9583 ± 0.0055
strict-match 8 exact_match 0.7324 ± 0.0122
# Qwen3-Next: FP8 linear + FP8 MoE + FP8 Attention
export SGL_ENABLE_JIT_DEEPGEMM=false
python3 -m sglang.launch_server --model-path Qwen3-Next/Qwen3-Next-80B-A3B-Instruct-FP8 --chunked-prefill-size 16384 --max-prefill-tokens 16384 --max-running-requests 512 --tp-size 4 --ep-size 4 --mem-fraction-static 0.7 --cuda-graph-bs 1 2 4 8 16 32 64 128 256 512 1024 --disable-radix-cache --log-level info --host 0.0.0.0 --port 8001 --random-seed 0  --quantization fp8    --kv-cache-dtype fp8_e4m3  --moe-runner-backend flashinfer_trtllm --attention-backend trtllm_mha  --mamba-ssm-dtype bfloat16 
Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 8 exact_match 0.9575 ± 0.0056
strict-match 8 exact_match 0.8370 ± 0.0102

I also reproduced the Qwen3-30B-A3B NVFP4 accuracy result as it was shown in PR13556

Accuracy: 0.899
Invalid: 0.000
Latency: 29.658 s
Output throughput: 5636.438 token/s

Modifications

Benchmarking and Profiling

Checklist

samuellees and others added 3 commits November 21, 2025 21:45
Co-authored-by: Sam Li <lsam@nvidia.com>
Co-authored-by: Kaixi Hou <kaixih@nvidia.com>
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @samuellees, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for NVFP4 Mixture-of-Experts (MoE) for Qwen series models, optimized for Blackwell GPUs. The changes involve adapting the MoE layer to dynamically configure routing parameters and updating the server's argument parsing to correctly handle modelopt_fp4 quantization. A new unit test has been added to ensure the stability and accuracy of this new configuration, demonstrating its readiness for deployment.

Highlights

  • NVFP4 MoE Support: Enabled NVFP4 Mixture-of-Experts (MoE) for Qwen series models, specifically targeting Blackwell GPUs, to enhance performance and efficiency.
  • Dynamic Routing Configuration: Modified the MoE layer to dynamically handle routing methods and correction biases, removing hardcoded values and improving flexibility.
  • Quantization Backend Integration: Updated server arguments to correctly recognize modelopt_fp4 quantization for automatic MoE runner backend selection, ensuring proper configuration.
  • New Unit Test and Validation: Introduced a dedicated unit test to validate the NVFP4 MoE functionality for Qwen3-30B-A3B models, including accuracy verification, and integrated it into the nightly test suite.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request enables NVFP4 MoE for Qwen series models on Blackwell GPUs. The changes involve updating server arguments to recognize modelopt_fp4 quantization for MoE backend selection and generalizing the MoE layer to support different routing methods. A new nightly test is added to verify the functionality and accuracy. The changes appear correct and well-tested. I've identified a minor code duplication issue in server_args.py that could be refactored for better maintainability.

@samuellees samuellees changed the title Trtllm moe nvfp4 [Feat][NVFP4] Enable NVFP4 MoE for Qwen series models (eg. Qwen3-Next) #13427 Nov 22, 2025
@samuellees samuellees changed the title [Feat][NVFP4] Enable NVFP4 MoE for Qwen series models (eg. Qwen3-Next) #13427 [Feat][NVFP4] Enable NVFP4 MoE for Qwen series models (eg. Qwen3-Next) #13761 Nov 22, 2025
Comment thread python/sglang/srt/server_args.py Outdated
self.quantization = quant_method
if (
self.quantization == "fp8"
(self.quantization == "fp8" or self.quantization == "modelopt_fp4")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe self.quantization in ("fp8", "modelopt_fp4")?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@b8zhong b8zhong added the run-ci label Nov 23, 2025
@Fridge003 Fridge003 merged commit 91e8dc3 into sgl-project:main Nov 27, 2025
108 of 128 checks passed
@samuellees samuellees deleted the trtllm-moe-nvfp4 branch November 27, 2025 01:24
harvenstar pushed a commit to harvenstar/sglang that referenced this pull request Dec 4, 2025
@samuellees samuellees mentioned this pull request Feb 11, 2026
12 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants