
[NPU]support model MiniCPM3-4B for npu #16866

Merged
iforgetmyname merged 13 commits into sgl-project:main from McZyWu:minicpm3-4B on Jan 24, 2026

Conversation

@McZyWu
Contributor

@McZyWu McZyWu commented Jan 10, 2026

Motivation

Previously, the NPU backend did not support the MiniCPM3-4B model.

The NPU operators in the current codebase could not run MiniCPM3-4B, making it impossible to execute the model on NPU hardware. Additionally, enabling Data Parallelism (DP) for the attention layers triggered a division-by-zero exception (when the first dimension of the DP attention tensor was 0), further blocking functional deployment of the model.

Modifications

The changes are as follows.
The changes to the ascend backend resolve the missing operator support and the accuracy problems.
The changes to rotary_embedding.py and minicpm3.py ensure full compatibility with DP by fixing the division-by-zero exception raised when DP attention is enabled, eliminating the runtime error and enabling distributed execution.
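
As an illustration only, here is a minimal sketch of the kind of empty-batch guard that avoids the division-by-zero error described above; the function and argument names are hypothetical and do not reproduce the actual changes in rotary_embedding.py or minicpm3.py.

```python
import torch

def apply_rope_dp_safe(rotary_emb, positions, q, k, head_size: int):
    """Hypothetical guard for DP attention: a rank may receive an empty
    batch (first dimension 0), and reshaping it with an inferred -1
    dimension raises the division-by-zero style error mentioned above."""
    num_tokens = q.shape[0]
    if num_tokens == 0:
        # Nothing to rotate on this DP rank; return the empty tensors as-is.
        return q, k
    # Use explicit sizes instead of -1 so the reshape stays well defined.
    q = q.view(num_tokens, q.shape[-1] // head_size, head_size)
    k = k.view(num_tokens, k.shape[-1] // head_size, head_size)
    return rotary_emb(positions, q, k)
```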

Accuracy Tests

70% Accuracy on GSM8k 5-shot evaluation

Test method:
1. Change the model path in test_ascend_minicpm3_4b.py.
2. Run: python test/registered/ascend/llm_models/test_ascend_minicpm3_4b.py

Supported features: TP, DP
Unsupported features: graph mode, prefix tuning, ...

@gemini-code-assist
Contributor

Summary of Changes

Hello @McZyWu, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the MiniCPM3-4B model with NPU hardware, significantly expanding the supported model ecosystem. The changes primarily involve adapting the attention mechanisms and rotary embedding calculations to leverage NPU capabilities, resulting in a verified 70% accuracy on the GSM8k benchmark. This enables efficient execution of the MiniCPM3-4B model on NPU devices.

Highlights

  • NPU Model Support: Added support for the MiniCPM3-4B model on NPU (Neural Processing Unit) hardware.
  • Accuracy Benchmark: Achieved 70% accuracy on the GSM8k 5-shot evaluation for the MiniCPM3-4B model on NPU.
  • Configuration Updates: Introduced new configuration parameters (qk_nope_head_dim, not_use_fused_infer_attention_score) to model_config.py for NPU-specific attention handling.
  • Attention Mechanism Adaptation: Implemented conditional logic in ascend_backend.py to switch between fused NPU attention and a custom scaled dot product attention based on the not_use_fused_infer_attention_score flag (see the sketch after this list).
  • Tensor Reshaping Improvements: Updated tensor reshaping operations in rotary_embedding.py and minicpm3.py for compatibility and potential performance improvements.
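
For readers unfamiliar with the flag, the sketch below illustrates the kind of dispatch described in the attention-mechanism highlight above; everything except torch.nn.functional.scaled_dot_product_attention and the flag name is hypothetical, and this is not the actual code in ascend_backend.py.

```python
import torch

def npu_attention_dispatch(backend, layer, q, k, v):
    """Hypothetical helper: choose between the fused NPU kernel and a plain
    scaled_dot_product_attention fallback based on the config flag."""
    if getattr(backend, "not_use_fused_infer_attention_score", True):
        # Fallback path: standard PyTorch attention.
        return torch.nn.functional.scaled_dot_product_attention(
            q.unsqueeze(0),
            k.unsqueeze(0),
            v.unsqueeze(0),
            scale=layer.scaling,
            is_causal=True,
        )
    # Fused path: delegate to the NPU fused-infer-attention kernel
    # (the real call and its arguments live in ascend_backend.py).
    return backend.fused_infer_attention(q, k, v, layer)
```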


Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature | Command | Description
--- | --- | ---
Code Review | /gemini review | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state.
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help | /gemini help | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for the MiniCPM3-4B model on NPU hardware. The changes involve adding a new model configuration, a fallback attention implementation for NPU when the fused kernel is not available, and some minor refactoring in related components. While the core logic for adding model support is sound, I've identified a critical bug in the new attention implementation and a few areas for improvement in terms of code style and clarity. The most significant issue is in ascend_backend.py, where the value tensor passed to scaled_dot_product_attention is incorrect, which will lead to wrong attention outputs. I've provided a detailed comment and suggestion to fix this. Other comments focus on improving code readability and maintainability.

Comment on lines +894 to +903
k_kpe = torch.cat([k_, v_], dim=-1).transpose(0, 1)
o_ = (
    torch.nn.functional.scaled_dot_product_attention(
        q_.unsqueeze(0),
        k_kpe.unsqueeze(0),
        v.unsqueeze(0),
        enable_gqa=use_gqa,
        scale=layer.scaling,
        is_causal=True,
    )
Contributor


critical

There appears to be a bug in how the value tensor is passed to scaled_dot_product_attention. The v tensor, which represents the value for the entire batch, is being used directly without being sliced for the current sequence within the loop. This will result in incorrect attention calculations as the value tensor will have mismatched sequence lengths and will not correspond to the correct sequence's keys.

For MiniCPM-like models, the value in attention is the kv_c part. In your implementation, k_ represents the kv_c slice for the current sequence. Therefore, k_ should be used as the value tensor. It needs to be correctly shaped for the scaled_dot_product_attention function.

I've suggested a fix that uses k_ as the value tensor, ensuring its dimensions are correctly aligned for the attention calculation.

Suggested change

Before:

    k_kpe = torch.cat([k_, v_], dim=-1).transpose(0, 1)
    o_ = (
        torch.nn.functional.scaled_dot_product_attention(
            q_.unsqueeze(0),
            k_kpe.unsqueeze(0),
            v.unsqueeze(0),
            enable_gqa=use_gqa,
            scale=layer.scaling,
            is_causal=True,
        )

After:

    k_kpe = torch.cat([k_, v_], dim=-1).transpose(0, 1)
    v_ = k_.transpose(0, 1)
    o_ = (
        torch.nn.functional.scaled_dot_product_attention(
            q_.unsqueeze(0),
            k_kpe.unsqueeze(0),
            v_.unsqueeze(0),
            enable_gqa=use_gqa,
            scale=layer.scaling,
            is_causal=True,
        )

Comment thread python/sglang/srt/layers/rotary_embedding.py
Comment thread python/sglang/srt/models/minicpm3.py
Comment thread python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py
assert (
    layer.qk_head_dim != layer.v_head_dim
), "FIA only supports qk_head_dim != v_head_dim"
if not hasattr(self, "not_use_fused_infer_attention_score"):
Contributor


What if not_use_fused_infer_attention_score=False?

Contributor Author


It is set to True.
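
For context, here is a minimal sketch of the pattern discussed in this thread, assuming the flag simply defaults to True when the model config does not set it; the helper name is illustrative, not the actual ascend_backend.py code.

```python
def resolve_fia_flag(model_config) -> bool:
    """Hypothetical helper: read not_use_fused_infer_attention_score from
    the model config, defaulting to True (take the SDPA fallback) when
    the attribute is not set."""
    return getattr(model_config, "not_use_fused_infer_attention_score", True)
```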

@ping1jing2 ping1jing2 changed the title from "support model MiniCPM3-4B for npu, accuracy gsm8k 70%" to "[NPU]support model MiniCPM3-4B for npu, accuracy gsm8k 70%" on Jan 20, 2026
@McZyWu McZyWu force-pushed the minicpm3-4B branch 2 times, most recently from f4a376d to 1019a2e, on January 20, 2026 13:32
@McZyWu McZyWu requested a review from hnyls2002 as a code owner on January 20, 2026 14:08
@McZyWu McZyWu force-pushed the minicpm3-4B branch 2 times, most recently from 98c552f to 02dc5c8, on January 21, 2026 12:51
@iforgetmyname
Collaborator

/tag-and-rerun-ci

Comment thread python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py Outdated
@iforgetmyname iforgetmyname changed the title from "[NPU]support model MiniCPM3-4B for npu, accuracy gsm8k 70%" to "[NPU]support model MiniCPM3-4B for npu" on Jan 24, 2026
@iforgetmyname iforgetmyname merged commit 8a5ed24 into sgl-project:main Jan 24, 2026
325 of 343 checks passed
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026
