
[NPU][Bugfix] fix Qwen3-VL-30B-A3B-Instruct accuracy loss#15597

Merged
iforgetmyname merged 9 commits into sgl-project:main from cen121212:main-12-22
Dec 31, 2025

Conversation

@cen121212 (Contributor) commented Dec 22, 2025

Motivation

Add mRoPE (multimodal rotary position embedding) handling for Qwen VL models.

Modifications

Add mRoPE handling in get_cos_sin_with_position (python/sglang/srt/layers/rotary_embedding.py).
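As a rough sketch of what mRoPE handling in get_cos_sin_with_position entails (the signature, cache layout, and section logic below are illustrative assumptions, not the actual sglang code): for multimodal inputs the position ids carry three rows (temporal, height, width), and each mrope section of the rotary dimension draws its cos/sin values from the corresponding row.

```python
import torch

def get_cos_sin_with_position(cos_cache, sin_cache, positions, mrope_section=None):
    # Illustrative sketch. For mRoPE, `positions` has shape (3, seq_len):
    # one row each for temporal/height/width ids. Each sub-section of the
    # rotary dim takes its cos/sin from the matching position row.
    if positions.ndim == 2 and mrope_section is not None:
        cos = cos_cache[positions]  # (3, seq_len, rot_dim)
        sin = sin_cache[positions]
        # Split the rotary dim into the mrope sections, pick section i
        # from position row i, and concatenate back to (seq_len, rot_dim).
        cos = torch.cat(
            [chunk[i] for i, chunk in enumerate(cos.split(mrope_section, dim=-1))],
            dim=-1,
        )
        sin = torch.cat(
            [chunk[i] for i, chunk in enumerate(sin.split(mrope_section, dim=-1))],
            dim=-1,
        )
    else:
        cos = cos_cache[positions]  # plain 1D RoPE path
        sin = sin_cache[positions]
    return cos, sin
```

Text-only requests keep 1D position ids and fall through to the plain lookup, which is why missing this branch only hurts multimodal accuracy.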

Accuracy Tests

test command:

nohup evalscope eval \
 --model /mnt/share/weights/Qwen3-VL-30B-A3B-Instruct \
 --api-url http://127.0.0.1:22001/v1 \
 --api-key EMPTY \
 --eval-type openai_api \
 --generation-config '{"max_tokens":20000, "top_p":0.8, "top_k":20, "temperature": 0.7, "n":1, "timeout": 60, "stream": true,"do_sample": true,"presence_penalty": 1.5, "repetition_penalty": 1,"seed": "3407"}' \
 --datasets  mm_star \
 --dataset-hub Local \
 --dataset-args '{"mm_star": {"local_path": "/home/***/MMStar"}}' \
 --eval-batch-size 5 \
 --ignore-errors \
 > evalscope.log 2>&1 &

Before this PR:
+---------------------------+-----------+----------+-------------------------+-------+---------+---------+
| Model | Dataset | Metric | Subset | Num | Score | Cat.0 |
+===========================+===========+==========+=========================+=======+=========+=========+
| Qwen3-VL-30B-A3B-Instruct | mm_star | mean_acc | coarse perception | 250 | 0.648 | default |
+---------------------------+-----------+----------+-------------------------+-------+---------+---------+
| Qwen3-VL-30B-A3B-Instruct | mm_star | mean_acc | fine-grained perception | 250 | 0.504 | default |
+---------------------------+-----------+----------+-------------------------+-------+---------+---------+
| Qwen3-VL-30B-A3B-Instruct | mm_star | mean_acc | instance reasoning | 250 | 0.676 | default |
+---------------------------+-----------+----------+-------------------------+-------+---------+---------+
| Qwen3-VL-30B-A3B-Instruct | mm_star | mean_acc | logical reasoning | 250 | 0.624 | default |
+---------------------------+-----------+----------+-------------------------+-------+---------+---------+
| Qwen3-VL-30B-A3B-Instruct | mm_star | mean_acc | math | 250 | 0.748 | default |
+---------------------------+-----------+----------+-------------------------+-------+---------+---------+
| Qwen3-VL-30B-A3B-Instruct | mm_star | mean_acc | science & technology | 250 | 0.388 | default |
+---------------------------+-----------+----------+-------------------------+-------+---------+---------+
| Qwen3-VL-30B-A3B-Instruct | mm_star | mean_acc | OVERALL | 1500 | 0.598 | - |
+---------------------------+-----------+----------+-------------------------+-------+---------+---------+

After this PR:
+---------------------------+-----------+----------+-------------------------+-------+---------+---------+
| Model | Dataset | Metric | Subset | Num | Score | Cat.0 |
+===========================+===========+==========+=========================+=======+=========+=========+
| Qwen3-VL-30B-A3B-Instruct | mm_star | mean_acc | coarse perception | 250 | 0.668 | default |
+---------------------------+-----------+----------+-------------------------+-------+---------+---------+
| Qwen3-VL-30B-A3B-Instruct | mm_star | mean_acc | fine-grained perception | 250 | 0.544 | default |
+---------------------------+-----------+----------+-------------------------+-------+---------+---------+
| Qwen3-VL-30B-A3B-Instruct | mm_star | mean_acc | instance reasoning | 250 | 0.744 | default |
+---------------------------+-----------+----------+-------------------------+-------+---------+---------+
| Qwen3-VL-30B-A3B-Instruct | mm_star | mean_acc | logical reasoning | 250 | 0.748 | default |
+---------------------------+-----------+----------+-------------------------+-------+---------+---------+
| Qwen3-VL-30B-A3B-Instruct | mm_star | mean_acc | math | 250 | 0.816 | default |
+---------------------------+-----------+----------+-------------------------+-------+---------+---------+
| Qwen3-VL-30B-A3B-Instruct | mm_star | mean_acc | science & technology | 250 | 0.508 | default |
+---------------------------+-----------+----------+-------------------------+-------+---------+---------+
| Qwen3-VL-30B-A3B-Instruct | mm_star | mean_acc | OVERALL | 1500 | 0.6713 | - |
+---------------------------+-----------+----------+-------------------------+-------+---------+---------+

Benchmarking and Profiling

Checklist

@gemini-code-assist (Contributor)

Summary of Changes

Hello @cen121212, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This Work In Progress pull request aims to resolve an observed accuracy loss when running the Qwen3-VL-30B-A3B-Instruct model on NPU hardware. The changes primarily involve refining how expert weights are normalized within the Mixture of Experts layer and strategically disabling NPU-specific code paths for Rotary Embedding and the model's forward preparation, opting for more generic or native implementations to mitigate the accuracy degradation.

Highlights

  • NPU MoE Weight Normalization Fix: Corrected the renormalization of topk_weights within the NPU Mixture of Experts (MoE) implementation. The l1_norm function was replaced with a direct sum-based normalization (topk_weights / topk_weights.sum(dim=-1, keepdim=True)) to ensure proper weight distribution and address potential accuracy issues.
  • NPU Rotary Embedding Bypass: The NPU-specific implementation for Rotary Embedding (_forward_npu) has been temporarily disabled by commenting out its conditional call. This change forces the system to fall back to the more generic _forward_native method, likely as a workaround to improve stability or accuracy on NPU.
  • NPU Qwen3 MoE Forward Prepare Bypass: The forward_prepare method for the Qwen3 MoE model now unconditionally calls forward_prepare_native. This bypasses any NPU-specific logic that might have been present or intended for this critical preparation step, aiming to resolve accuracy degradation.
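The rotary-embedding bypass in the second highlight follows a common dispatch pattern: a platform-specific branch in forward() is commented out so every platform falls back to the native implementation. A minimal sketch of that pattern (class shape and method bodies are illustrative, not the actual sglang code):

```python
import torch

class RotaryEmbedding(torch.nn.Module):
    """Simplified sketch of the platform-dispatch pattern described above."""

    def __init__(self, is_npu: bool = False):
        super().__init__()
        self._is_npu = is_npu

    def _forward_native(self, positions, query, key):
        # Generic PyTorch implementation (placeholder body here).
        return query, key

    def _forward_npu(self, positions, query, key):
        # NPU-specific kernel; bypassed by this PR due to accuracy loss.
        raise NotImplementedError

    def forward(self, positions, query, key):
        # Before this PR the NPU branch was taken on Ascend hardware:
        # if self._is_npu:
        #     return self._forward_npu(positions, query, key)
        # With the branch commented out, all platforms use the native path.
        return self._forward_native(positions, query, key)
```

The trade-off is correctness over speed: the native path is slower than a fused NPU kernel but produces the reference results until the kernel-level discrepancy is fixed.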


gemini-code-assist Bot left a comment

Code Review

The pull request modifies NPU-related code paths across three files. In topk.py, the custom l1_norm for topk_weights was replaced with a standard PyTorch sum-based normalization; the reviewer noted potential numerical instability, the removal of special handling for fused shared experts, and asked for the commented-out code to be removed. NPU-specific code paths were also commented out in rotary_embedding.py and effectively bypassed in qwen3_moe.py; the reviewer asked whether these changes are temporary and requested they be either restored or removed with an explanation to avoid maintenance issues.

Comment on lines +34 to +39
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
# topk_weights = l1_norm(
# topk_weights
# if topk_config.num_fused_shared_experts == 0
# else topk_weights[:, :-1]
# )

Severity: medium

This change replaces the custom l1_norm with a standard PyTorch implementation, which is good for clarity. However, there are a few points to consider:

  1. Numerical Stability: The division could be unstable if topk_weights.sum() is zero. Adding a small epsilon to the denominator would make it more robust.
  2. Logic Change: The original code had special handling for topk_config.num_fused_shared_experts, excluding the last expert from normalization. This logic is now removed. If this is intentional, a comment explaining the reasoning would be helpful. If not, this could be a regression.
  3. Code Cleanliness: The commented-out old code should be removed before merging.

Here is a suggestion addressing points 1 and 3:

Suggested change
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
# topk_weights = l1_norm(
# topk_weights
# if topk_config.num_fused_shared_experts == 0
# else topk_weights[:, :-1]
# )
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-9)
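The suggested normalization can be sketched as a standalone helper (the function name is hypothetical; the epsilon value follows the reviewer's suggestion):

```python
import torch

def renormalize_topk_weights(topk_weights: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    # Sum-based L1 renormalization as adopted in this PR, with the
    # reviewer-suggested epsilon guarding against an all-zero row.
    return topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + eps)
```

With the epsilon in place, a row of all-zero weights normalizes to zeros instead of NaN, while nonzero rows still sum to (approximately) one.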

Comment on lines +1527 to +1528
# elif _is_npu:
# return self._forward_npu(positions, query, key)

Severity: medium

This NPU-specific code path has been commented out. If this is a temporary change for debugging, please ensure it's either restored or removed with a proper explanation before the final merge. Leaving commented-out code can lead to confusion and maintenance issues.

@cen121212 changed the title from "【WIP】fix: NPU Qwen3-VL-30B-A3B-Instruct accuracy loss" to "fix: NPU Qwen3-VL-30B-A3B-Instruct accuracy loss" on Dec 24, 2025
@iforgetmyname changed the title from "fix: NPU Qwen3-VL-30B-A3B-Instruct accuracy loss" to "[NPU][Bugfix] fix Qwen3-VL-30B-A3B-Instruct accuracy loss" on Dec 26, 2025
@iforgetmyname (Collaborator)

/tag-and-rerun-ci

Liwansi added a commit to iforgetmyname/sglang that referenced this pull request Dec 29, 2025
…glang into eagle-sche

* 'ifmn/eagle-dp-attn' of https://github.com/sgl-project/sglang: (22 commits)
  dp scheduler enhance support with chunked prefill (sgl-project#16071)
  modify suffix decoding
  CI dependency update (sgl-project#16063)
  fix rotary_embedding init npu (sgl-project#16011)
  feat: bugfix and accuracy fix for stablelm2_1_6b (sgl-project#15932)
  Update model and feature support for Ascend NPU (sgl-project#16005)
  Bugfix for Llama4 (sgl-project#15929)
  Bugfix for ds-vl2 (sgl-project#15894)
  gme qwen vl runners fix (sgl-project#15899)
  add profiling in scheduler (sgl-project#15876)
  llama use triton rope op (sgl-project#15855)
  suffix decoding adapt npu
  suffix decoding adapt npu
  Add suffix decoding speculative algorithm from feature 13553
  cherry sgl-project#15434: qwen3 vl performance update
  cherry sgl-project#15597: fix Qwen3-VL-30B-A3B-Instruct accuracy loss
  [Schedule] bug fix for schedule enhancer (sgl-project#15834)
  minilb support roundrobin (sgl-project#15824)
  fix torchair compile issue
  cherry sgl-project#15187: lora fix
  ...

# Conflicts:
#	python/sglang/srt/managers/scheduler.py
#	python/sglang/srt/managers/scheduler_enhancer.py
@iforgetmyname iforgetmyname merged commit 25b4856 into sgl-project:main Dec 31, 2025
275 of 326 checks passed
YChange01 pushed a commit to YChange01/sglang that referenced this pull request Jan 13, 2026