
[NPU] Support MTP for Qwen3.5#20918

Merged
sglang-npu-bot merged 36 commits into sgl-project:main from
iridiumine:feature/ascend-mtp-adapt
Apr 27, 2026

Conversation

@iridiumine
Contributor

@iridiumine iridiumine commented Mar 19, 2026

Motivation

Adapt the MTP (Multi-Token Prediction) speculative decoding feature for the Qwen3.5 model on the Ascend NPU platform, fix inference errors, and ensure stable and efficient model operation.

Modifications

  1. Add a dedicated GDN attention backend tailored for Ascend NPU, designed to address hardware-specific compatibility and performance needs;
  2. Complete end-to-end MTP speculative decoding adaptation for the Qwen3.5 model on Ascend NPU, ensuring full functionality and stable inference;
  3. Adjust the attention backend routing logic to automatically switch to the NPU-specific GDN backend when running on Ascend hardware (a sketch of this routing is shown below).
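
For illustration, here is a minimal sketch of the routing described in item 3. The helper and class names (resolve_gdn_attention_backend, AscendGDNAttnBackend, GDNAttnBackend) and the import paths are placeholders, not the exact identifiers from this PR:

from sglang.srt.utils import is_npu

def resolve_gdn_attention_backend(model_runner):
    """Pick the GDN attention backend for the current hardware."""
    if is_npu():
        # Route to the NPU-specific GDN backend added for Ascend
        # (this PR adds it under
        # python/sglang/srt/hardware_backend/npu/attention/).
        from sglang.srt.hardware_backend.npu.attention.ascend_gdn_backend import (
            AscendGDNAttnBackend,  # placeholder class name
        )
        return AscendGDNAttnBackend(model_runner)
    # Otherwise fall back to the default implementation used on GPUs
    # (placeholder import path).
    from sglang.srt.layers.attention.gdn_backend import GDNAttnBackend
    return GDNAttnBackend(model_runner)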

Accuracy Tests

Script

# High-performance CPU settings
echo performance | tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
sysctl -w vm.swappiness=0
sysctl -w kernel.numa_balancing=0
sysctl -w kernel.sched_migration_cost_ns=50000
# Bind CPU affinity
export SGLANG_SET_CPU_AFFINITY=1

unset https_proxy
unset http_proxy
unset HTTPS_PROXY
unset HTTP_PROXY
export ASCEND_LAUNCH_BLOCKING=1
# Source the CANN toolkit and ATB environments
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh

export PYTHONPATH=${PWD}/python:$PYTHONPATH

export STREAMS_PER_DEVICE=32
export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=32
export HCCL_BUFFSIZE=3000
export HCCL_OP_EXPANSION_MODE=AIV
export HCCL_SOCKET_IFNAME=lo
export GLOO_SOCKET_IFNAME=lo
export SGLANG_NPU_PROFILING=0
export SGLANG_NPU_PROFILING_STAGE="prefill"
export DEEPEP_NORMAL_LONG_SEQ_ROUND=32
export DEEPEP_NORMAL_LONG_SEQ_PER_ROUND_TOKENS=3584
export ASCEND_MF_STORE_URL="tcp://127.0.0.1:24669"
export SGLANG_DISAGGREGATION_WAITING_TIMEOUT=3600
export SGLANG_ENABLE_SPEC_V2=1
export SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1
python3 -m sglang.launch_server \
        --model-path /home/weights/Qwen3.5-27B-W8A8 \
        --attention-backend ascend \
        --device npu \
        --tp-size 4 --nnodes 1 --node-rank 0 \
        --chunked-prefill-size -1 --max-prefill-tokens 16384 \
        --disable-radix-cache \
        --trust-remote-code \
        --host 127.0.0.1 --max-running-requests 60 --max-mamba-cache-size 60 \
        --mem-fraction-static 0.9 \
        --port 8000 \
        --cuda-graph-bs 4 8 12 16 20 24 28 32 36 40 44 48 52 56 60 \
        --enable-multimodal \
        --quantization modelslim \
        --mm-attention-backend ascend_attn \
        --dtype bfloat16 --mamba-ssm-dtype bfloat16 --max-total-tokens 800000 \
        --speculative-algorithm NEXTN --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 
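
Once the server is up, it can be smoke-tested against SGLang's native /generate endpoint before running the full evaluation. This snippet is illustrative and was not part of the original script; host and port match the --host 127.0.0.1 --port 8000 flags above:

import requests

# Minimal smoke test for the server launched above (illustrative only).
resp = requests.post(
    "http://127.0.0.1:8000/generate",
    json={
        "text": "Question: Natalia sold 48 clips in April and half as many in May. "
                "How many clips did she sell in total?\nAnswer:",
        "sampling_params": {"max_new_tokens": 64, "temperature": 0.0},
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["text"])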

Result

+------------------+-----------+----------+----------+-------+---------+---------+
| Model            | Dataset   | Metric   | Subset   |   Num |   Score | Cat.0   |
+==================+===========+==========+==========+=======+=========+=========+
| Qwen3.5-27B-W8A8 | gsm8k     | mean_acc | main     |   256 |  0.9805 | default |
+------------------+-----------+----------+----------+-------+---------+---------+

Benchmarking and Profiling

No performance impact.

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the SGLang framework by enabling Multi-Token Prediction (MTP) speculative decoding for the Qwen3.5 model on Ascend NPU. It introduces a specialized GDN attention backend, refines attention routing, and integrates NPU-specific memory and unquantization logic to ensure efficient and stable model operation on Ascend hardware. These changes are crucial for leveraging the performance capabilities of NPUs for advanced language model inference.

Highlights

  • Ascend NPU GDN Attention Backend: A dedicated GDN (Gated Delta Network) attention backend has been introduced specifically for the Ascend NPU platform, addressing hardware-specific compatibility and performance requirements.
  • Qwen3.5 MTP Speculative Decoding Adaptation: End-to-end Multi-Token Prediction (MTP) speculative decoding has been fully adapted for the Qwen3.5 model on Ascend NPU, ensuring stable and functional inference.
  • Dynamic Attention Backend Routing: The attention backend routing logic has been updated to automatically select the NPU-specific GDN backend when operating on Ascend hardware.
  • NPU-Specific Memory Management: Memory pool initialization for convolutional states (conv_state) now includes NPU-specific handling, particularly for speculative decoding with draft tokens.
  • NPU Unquantization Logic for MTP Models: Environment variables are dynamically set to manage unquantization for MTP models (Qwen3.5 and Qwen3Next) when running on NPU without quantization, ensuring correct data types for NPU operations.




@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for MTP (Multi-Token Prediction) speculative decoding for Qwen3.5 models on Ascend NPUs. The changes include a new NPU-specific GDN attention backend, updates to memory management and metadata handling, and model-level modifications to enable this functionality. My review identified two critical thread-safety issues in the model forward passes related to the modification of global environment variables, which could lead to race conditions in a server environment. Additionally, a minor performance issue due to a redundant computation was found in the new attention backend.

Comment thread python/sglang/srt/models/qwen3_5_mtp.py Outdated
Comment on lines +126 to +129
if is_npu() and self.quant_config is None:
# ascend mtp unquant
os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "1"
os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "0"

critical

Modifying environment variables (os.environ) within the forward method is not thread-safe and can lead to race conditions in a concurrent server environment. When multiple requests are processed in parallel, one request might change the environment variables while another is in the middle of its execution, leading to incorrect and unpredictable behavior. This is a critical issue for a production server.

Configuration should ideally be passed through function arguments or other thread-safe mechanisms. If the underlying DEEPEP library only supports configuration via environment variables, this section of code should be protected by a lock, though this would have a significant performance impact. The best approach would be to investigate if the library can be configured in a thread-safe manner.
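
As one illustration of the lock-based workaround mentioned above (a sketch, not code from this PR), a context manager could serialize the env-var mutation and restore the previous values afterwards:

import os
import threading
from contextlib import contextmanager

_ENV_LOCK = threading.Lock()  # process-wide lock for env-var mutation

@contextmanager
def scoped_env(overrides):
    """Temporarily set environment variables under a lock.

    Trades throughput for correctness: concurrent forward passes that
    need conflicting values will serialize on _ENV_LOCK.
    """
    with _ENV_LOCK:
        saved = {k: os.environ.get(k) for k in overrides}
        os.environ.update(overrides)
        try:
            yield
        finally:
            # Restore the previous state, removing keys that were unset.
            for k, v in saved.items():
                if v is None:
                    os.environ.pop(k, None)
                else:
                    os.environ[k] = v

# Hypothetical usage inside forward():
# with scoped_env({"SGLANG_DEEPEP_BF16_DISPATCH": "1",
#                  "DEEP_NORMAL_MODE_USE_INT8_QUANT": "0"}):
#     ...  # code that reads these variables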

Comment thread python/sglang/srt/models/qwen3_5_mtp.py Outdated
Comment on lines +161 to +164
if is_npu() and self.quant_config is None:
# ascend mtp unquant
os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "0"
os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "1"

critical

Similar to the modification at the beginning of the forward method, resetting environment variables here is not thread-safe and can cause race conditions. A request could reset the variables while another concurrent request requires them to be set, leading to incorrect behavior. This pattern of setting and unsetting environment variables per-request is unsafe in a multithreaded context.

Comment on lines +93 to +96
if is_npu() and self.quant_config is None:
# ascend mtp unquant
os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "1"
os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "0"

critical

Modifying environment variables (os.environ) within the forward method is not thread-safe. In a concurrent server environment, this can lead to race conditions where one request's configuration leaks into another, causing unpredictable behavior. This is a critical issue.

Please consider a thread-safe way to pass this configuration. If the underlying library absolutely requires environment variables, access to this part of the code might need to be serialized (e.g., with a lock), which would be a major performance bottleneck. The preferred solution is to avoid environment variables for per-request configuration.

Comment on lines +114 to +117
if is_npu() and self.quant_config is None:
# ascend mtp unquant
os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "0"
os.environ["DEEP_NORMAL_MODE_USE_INT8_QUANT"] = "1"

critical

Resetting environment variables here shares the same critical thread-safety issue as setting them at the start of the method. This can lead to race conditions in a concurrent server, where one request undoes the configuration needed by another. This pattern is unsafe and should be refactored.

Comment thread python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py Outdated
@iridiumine iridiumine requested a review from yuan-luo as a code owner March 24, 2026 09:38


@triton.jit
def fused_gdn_gating_kernel_without_sigmoid_kernel(
Contributor

Is this op Ascend-only? If it is only used on Ascend, please move it to the sgl-kernel-npu repo.

Contributor Author

I have moved it to the sgl-kernel-npu repo (sgl-project/sgl-kernel-npu#429).

req_pool_indices[bs - num_padding :] = 0
mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
mamba_indices[bs - num_padding :] = -1
mamba_indices[bs - num_padding :] = 0
Contributor

This change needs to be reviewed to determine whether it affects the GPU implementation!

Contributor Author

This change does not affect the GPU implementation.

and get_global_server_args().quantization is not None
):
# ascend mtp unquant
os.environ["SGLANG_DEEPEP_BF16_DISPATCH"] = "1"
Contributor

Consider using env.get here.

@cjy0x

cjy0x commented Apr 22, 2026

Sorry to bother you, but I wonder what your triton-ascend version is.

I tested this PR on sglang v0.5.10 with triton-ascend v3.2.0 on an A3 machine, and I encountered the following error:

(SGLangEngine pid=3061557) [2026-04-22 09:32:22 TP2] Scheduler hit an exception: Traceback (most recent call last):
(SGLangEngine pid=3061557)   File "/root/anaconda3/envs/_slime_/lib/python3.11/site-packages/triton/compiler/compiler.py", line 310, in compile
(SGLangEngine pid=3061557)     next_module = compile_ir(module, metadata)
(SGLangEngine pid=3061557)                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(SGLangEngine pid=3061557)   File "/root/anaconda3/envs/_slime_/lib/python3.11/site-packages/triton/backends/ascend/compiler.py", line 778, in <lambda>
(SGLangEngine pid=3061557)     lambda src, metadata: linalg_to_bin_enable_npu_compile_A2_A3(
(SGLangEngine pid=3061557)                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(SGLangEngine pid=3061557)   File "/root/anaconda3/envs/_slime_/lib/python3.11/site-packages/triton/backends/ascend/compiler.py", line 549, in linalg_to_bin_enable_npu_compile_A2_A3
(SGLangEngine pid=3061557)     ret = subprocess.run(cmd_list, capture_output=True, check=True)
(SGLangEngine pid=3061557)           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(SGLangEngine pid=3061557)   File "/root/anaconda3/envs/_slime_/lib/python3.11/subprocess.py", line 571, in run
(SGLangEngine pid=3061557)     raise CalledProcessError(retcode, process.args,
(SGLangEngine pid=3061557) subprocess.CalledProcessError: Command '['/home//851b080/cann-8.5.1/bin/bishengir-compile', '/tmp/tmpp8bki1m1/kernel.ttadapter.mlir', '--target=Ascend910_9392', '--enable-auto-multi-buffer=True', '--enable-auto-bind-sub-block=True', '--enable-hfusion-compile=true', '--enable-hivm-compile=true', '--enable-triton-kernel-compile=true', '-o', '/tmp/tmpp8bki1m1/kernel']' returned non-zero exit status 1.
(SGLangEngine pid=3061557)
(SGLangEngine pid=3061557) During handling of the above exception, another exception occurred:
(SGLangEngine pid=3061557)
(SGLangEngine pid=3061557) Traceback (most recent call last):
(SGLangEngine pid=3061557)   File "/home//slime-proj/sglang/python/sglang/srt/managers/scheduler.py", line 3623, in run_scheduler_process
(SGLangEngine pid=3061557)     scheduler.run_event_loop()
(SGLangEngine pid=3061557)   File "/home//slime-proj/sglang/python/sglang/srt/managers/scheduler.py", line 1307, in run_event_loop
(SGLangEngine pid=3061557)     dispatch_event_loop(self)
(SGLangEngine pid=3061557)   File "/home//slime-proj/sglang/python/sglang/srt/managers/scheduler.py", line 3506, in dispatch_event_loop
(SGLangEngine pid=3061557)     scheduler.event_loop_normal()
(SGLangEngine pid=3061557)   File "/root/anaconda3/envs/_slime_/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(SGLangEngine pid=3061557)     return func(*args, **kwargs)
(SGLangEngine pid=3061557)            ^^^^^^^^^^^^^^^^^^^^^
(SGLangEngine pid=3061557)   File "/home//slime-proj/sglang/python/sglang/srt/managers/scheduler.py", line 1326, in event_loop_normal
(SGLangEngine pid=3061557)     result = self.run_batch(batch)
(SGLangEngine pid=3061557)              ^^^^^^^^^^^^^^^^^^^^^
(SGLangEngine pid=3061557)   File "/home//slime-proj/sglang/python/sglang/srt/managers/scheduler.py", line 2731, in run_batch
(SGLangEngine pid=3061557)     batch_result = self.model_worker.forward_batch_generation(
(SGLangEngine pid=3061557)                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(SGLangEngine pid=3061557)   File "/home//slime-proj/sglang/python/sglang/srt/speculative/eagle_worker.py", line 320, in forward_batch_generation
(SGLangEngine pid=3061557)     self.verify(batch, spec_info)
(SGLangEngine pid=3061557)   File "/home//slime-proj/sglang/python/sglang/srt/speculative/eagle_worker.py", line 782, in verify
(SGLangEngine pid=3061557)     self._mamba_verify_update(
(SGLangEngine pid=3061557)   File "/home//slime-proj/sglang/python/sglang/srt/speculative/eagle_worker.py", line 874, in _mamba_verify_update
(SGLangEngine pid=3061557)     self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
(SGLangEngine pid=3061557)   File "/home//slime-proj/sglang/python/sglang/srt/hardware_backend/npu/attention/ascend_hybrid_linear_attn_backend.py", line 259, in update_mamba_state_after_mtp_verify
(SGLangEngine pid=3061557)     move_intermediate_cache(
(SGLangEngine pid=3061557)   File "/root/anaconda3/envs/_slime_/lib/python3.11/site-packages/sgl_kernel_npu/mamba/mamba_state_update_triton.py", line 126, in move_intermediate_cache
(SGLangEngine pid=3061557)     move_cache_dynamic_last_kernel_h_block[grid](
(SGLangEngine pid=3061557)   File "/root/anaconda3/envs/_slime_/lib/python3.11/site-packages/triton/runtime/jit.py", line 353, in <lambda>
(SGLangEngine pid=3061557)     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
(SGLangEngine pid=3061557)                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(SGLangEngine pid=3061557)   File "/root/anaconda3/envs/_slime_/lib/python3.11/site-packages/triton/runtime/jit.py", line 660, in run
(SGLangEngine pid=3061557)     kernel = self.compile(
(SGLangEngine pid=3061557)              ^^^^^^^^^^^^^
(SGLangEngine pid=3061557)   File "/root/anaconda3/envs/_slime_/lib/python3.11/site-packages/triton/compiler/compiler.py", line 320, in compile
(SGLangEngine pid=3061557)     raise MLIRCompilationError(stage_name, error_detail)
(SGLangEngine pid=3061557) triton.compiler.errors.MLIRCompilationError:
(SGLangEngine pid=3061557) ///------------------[ERROR][Triton][BEG]------------------
(SGLangEngine pid=3061557) [ConvertLinalgRToBinary] encounters error:
(SGLangEngine pid=3061557) loc("/tmp/tmpp8bki1m1/kernel.ttadapter.mlir":1:1): error: Failed to run BiShengHIR pipeline
(SGLangEngine pid=3061557)
(SGLangEngine pid=3061557) loc("/tmp/tmpp8bki1m1/kernel.ttadapter.mlir":2:3): error: ub overflow, requires 2097152 bits while 1572864 bits available! (possible reason: tiling basic block is too large or block number is more than what user expect due to multi-buffer feature is enabled and some ops need extra local buffer.)
(SGLangEngine pid=3061557) loc("/tmp/tmpp8bki1m1/kernel.ttadapter.mlir":1:1): error: Failed to run BiShengHIR pipeline
(SGLangEngine pid=3061557)
(SGLangEngine pid=3061557) loc("/tmp/tmpp8bki1m1/kernel.ttadapter.mlir":2:3): error: ub overflow, requires 2097152 bits while 1572864 bits available! (possible reason: tiling basic block is too large or block number is more than what user expect due to multi-buffer feature is enabled and some ops need extra local buffer.)
(SGLangEngine pid=3061557) loc("/tmp/tmpp8bki1m1/kernel.ttadapter.mlir":1:1): error: Failed to run BiShengHIR pipeline
(SGLangEngine pid=3061557)
(SGLangEngine pid=3061557) loc("/tmp/tmpp8bki1m1/kernel.ttadapter.mlir":2:3): error: ub overflow, requires 2097152 bits while 1572864 bits available! (possible reason: tiling basic block is too large or block number is more than what user expect due to multi-buffer feature is enabled and some ops need extra local buffer.)
(SGLangEngine pid=3061557) loc("/tmp/tmpp8bki1m1/kernel.ttadapter.mlir":1:1): error: Failed to run BiShengHIR pipeline
(SGLangEngine pid=3061557)
(SGLangEngine pid=3061557) loc("/tmp/tmpp8bki1m1/kernel.ttadapter.mlir":2:3): error: ub overflow, requires 2097152 bits while 1572864 bits available! (possible reason: tiling basic block is too large or block number is more than what user expect due to multi-buffer feature is enabled and some ops need extra local buffer.)
(SGLangEngine pid=3061557) loc("/tmp/tmpp8bki1m1/kernel.ttadapter.mlir":1:1): error: Failed to run BiShengHIR pipeline
(SGLangEngine pid=3061557)
(SGLangEngine pid=3061557) loc("/tmp/tmpp8bki1m1/kernel.ttadapter.mlir":2:3): error: ub overflow, requires 2097152 bits while 1572864 bits available! (possible reason: tiling basic block is too large or block number is more than what user expect due to multi-buffer feature is enabled and some ops need extra local buffer.)
(SGLangEngine pid=3061557) [ERROR] Failed to run BiShengIR pipeline
(SGLangEngine pid=3061557)
(SGLangEngine pid=3061557)
(SGLangEngine pid=3061557) [INFO]: The compiled kernel cache is in /root/.triton/cache/xaNvcZrMuMy_f7dS6Iq88V2w0VfI4uQln5MWDIdXUQM
(SGLangEngine pid=3061557)
(SGLangEngine pid=3061557) ///------------------[ERROR][Triton][END]------------------

Do you have any ideas?

@sglang-npu-bot
Collaborator

/tag-run-ci-label

@sglang-npu-bot
Collaborator

/tag-and-rerun-ci

@sglang-npu-bot
Collaborator

/rerun-failed-ci

1 similar comment

@sglang-npu-bot sglang-npu-bot merged commit 32c3513 into sgl-project:main Apr 27, 2026
372 of 438 checks passed
@iridiumine iridiumine deleted the feature/ascend-mtp-adapt branch April 27, 2026 03:15
@OrangeRedeng
Contributor

OrangeRedeng commented Apr 27, 2026

@iridiumine Hi! Which version of sgl-kernel-npu are you using? It seems I have problems with our default 2026.03.10.rc1 version; I think we need to update it in https://github.com/sgl-project/sglang/blob/main/scripts/ci/npu/npu_ci_install_dependency.sh
[error screenshot]

@iridiumine
Contributor Author

@OrangeRedeng Hi, I'm using the following sgl-kernel-npu version: https://github.com/sgl-project/sgl-kernel-npu/releases/tag/2026.04.15.rc4
As for the CI script update, I’ll follow up and discuss it with you later.

vguduruTT pushed a commit to vguduruTT/sglang that referenced this pull request May 2, 2026
