
feat(Qwen3.5): hybrid linear attention support PP+PD#19254

Open
zhangxiaolei123456 wants to merge 22 commits into sgl-project:main from bytedance-iaas:main_qwen3.5_0224_fix_pp

Conversation


@zhangxiaolei123456 zhangxiaolei123456 commented Feb 24, 2026

Motivation

Hybrid linear attention does not support PP or PD.

PP:

`GLOO_SOCKET_IFNAME=eth0 NCCL_MIN_NCHANNELS=24 NCCL_IB_QPS_PER_CONNECTION=8 SGLANG_PP_LAYER_PARTITION="15,15,15,15" python -m sglang.launch_server --model-path /data00/models/Qwen3.5-397B-A17B-FP8 --port 8000 --tp-size 2 --pp-size 4 --mem-fraction-static 0.85 --reasoning-parser qwen3 --tool-call-parser qwen3_coder --mamba-ssm-dtype float16 --kv-cache-dtype fp8_e4m3 --host 0.0.0.0 --port 8000 --disable-radix-cache --pp-async-batch-depth 1 --pp-max-micro-batch-size 1 --max-running-requests 128  --chunked-prefill-size 4096 --max-prefill-tokens 16384 --page-size 64`

PP+PD:
Prefill:

`GLOO_SOCKET_IFNAME=eth0 NCCL_MIN_NCHANNELS=24 NCCL_IB_QPS_PER_CONNECTION=8 SGLANG_PP_LAYER_PARTITION="15,15,15,15" python -m sglang.launch_server --model-path /data00/models/Qwen3.5-397B-A17B-FP8 --port 8000 --tp-size 2 --pp-size 4 --mem-fraction-static 0.85 --reasoning-parser qwen3 --tool-call-parser qwen3_coder --mamba-ssm-dtype float16 --kv-cache-dtype fp8_e4m3 --disaggregation-mode prefill  --disaggregation-ib-device  "mlx5_1,mlx5_2,mlx5_3,mlx5_4" --host 0.0.0.0 --port 30300 --disable-radix-cache --pp-async-batch-depth 1 --pp-max-micro-batch-size 1 --max-running-requests 128  --chunked-prefill-size 8192 --max-prefill-tokens 16384 --page-size 64`

Decode:

`GLOO_SOCKET_IFNAME=eth0 NCCL_MIN_NCHANNELS=24 NCCL_IB_QPS_PER_CONNECTION=8 SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=128 python -m sglang.launch_server --model-path /data00/models/Qwen3.5-397B-A17B-FP8 --port 8000 --tp-size 8 --ep-size 8 --mem-fraction-static 0.75 --context-length 131072 --reasoning-parser qwen3 --tool-call-parser qwen3_coder --cuda-graph-bs 1 8 16 32 64 --max-running-requests 256 --enable-dp-attention --dp-size 4 --mamba-ssm-dtype float16 --kv-cache-dtype fp8_e4m3 --disaggregation-mode decode  --disaggregation-ib-device  "mlx5_1,mlx5_2,mlx5_3,mlx5_4" --moe-runner-backend deep_gemm --moe-a2a-backend deepep --deepep-mode low_latency --host 0.0.0.0 --port 30300 --speculative-algo EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --enable-metrics --disable-radix-cache --page-size 64`

Modifications

| PR | input | output | chunked-prefill-size | mappings | QPS | max_concurrency | num prompts | TTFT (ms) | TPOT (ms) | tps/gpu |
|---|---|---|---|---|---|---|---|---|---|---|
| TP8 | 3500 | 1 | 16384 | Linear and GQA | 7 | 256 | 1024 | 27239.36 | NA | 2840.00 |
| PP4TP2 | 3500 | 1 | 16384 | Linear and GQA | 10 | 256 | 1024 | 7881.33 | | 3997.00 |

Throughput improves by about 40% (tps/gpu: 2840.00 → 3997.00).

Accuracy Tests

GSM8K
PP:

python3 bench_sglang.py --host http://localhost  --port 8000 --data-path /data00
100%|██████████████████████████████████████████| 200/200 [02:22<00:00,  1.40it/s]
Accuracy: 0.960
Invalid: 0.005
Latency: 142.941 s
Output throughput: 228.080 token/s

PP+PD:

python3 bench_sglang.py --host http://localhost  --port 8000 --data-path /data00 --num-questions 500 --parallel 10
100%|█████████████████████████████| 500/500 [02:01<00:00,  4.11it/s]
Accuracy: 0.946
Invalid: 0.012
Latency: 121.697 s
Output throughput: 653.556 token/s

@ShangmingCai

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist

Summary of Changes

Hello @zhangxiaolei123456, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces crucial modifications to enable pipeline parallelism (PP) for Qwen3.5 models. The changes ensure that model layers and their corresponding weights are correctly distributed and loaded across different pipeline stages, resolving a previous limitation where Qwen3.5 did not support PP. This enhancement improves the model's scalability and efficiency in distributed training and inference environments.

Highlights

  • Pipeline Parallelism Support: Enabled pipeline parallelism for Qwen3.5 models by distributing layers across different pipeline stages, addressing the previous lack of support.
  • Conditional Weight Loading: Implemented logic to selectively load model weights based on the assigned pipeline stage, ensuring that only weights relevant to the current rank's layers are loaded.
  • Layer Initialization Refinement: Modified the make_layers function to return start_layer and end_layer for pipeline parallelism and introduced a PPMissingLayer placeholder for embedding tokens when the embedding layer is not on the current pipeline rank.


Changelog
  • python/sglang/srt/models/qwen3_5.py
    • Imported get_layer_id utility for identifying layer indices.
    • Modified the __init__ method to assign PPMissingLayer to self.embed_tokens when pipeline parallelism is active and the embedding layer is not present on the current rank.
    • Updated the make_layers function call to return self.layers, self.start_layer, and self.end_layer, and passed pp_rank and pp_size for proper layer distribution.
    • Added conditional logic within load_weights and load_fused_expert_weights methods to skip loading weights for layers that fall outside the current pipeline stage's assigned range.

@zhangxiaolei123456 changed the title from "fix PP bug for Qwen3.5" to "fix Qwen3.5 support PP" on Feb 24, 2026

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for Pipeline Parallelism (PP) to Qwen3.5 models. The changes correctly handle layer instantiation for different PP ranks and skip loading weights for layers that are not on the current rank. The implementation looks solid. My only suggestion is to refactor the duplicated weight-skipping logic in the load_weights methods into a shared helper function to improve code maintainability.

Comment on lines +797 to +803
    layer_id = get_layer_id(name)
    if (
        layer_id is not None
        and hasattr(self, "start_layer")
        and (layer_id < self.start_layer or layer_id >= self.end_layer)
    ):
        continue


Severity: medium

This logic to skip loading weights for layers not on the current pipeline parallel rank is duplicated in three other load_weights methods in this file (at lines 925, 1091, and 1241). To improve maintainability and reduce code duplication, consider extracting this logic into a single helper function.
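One way to realize that suggestion is a small shared helper, sketched below. This is an illustration, not the PR's code: `should_skip_weight` is a hypothetical name, and the regex inside `get_layer_id` is an assumption standing in for sglang's utility of the same name.

```python
import re


def get_layer_id(name):
    """Parse the layer index out of a weight name like 'model.layers.3.mlp.w1'.

    Returns None for weights that do not belong to a numbered layer
    (e.g. 'lm_head.weight'). Sketch of sglang's get_layer_id utility.
    """
    m = re.search(r"layers\.(\d+)\.", name)
    return int(m.group(1)) if m else None


def should_skip_weight(model, name):
    """True if `name` belongs to a layer outside this PP rank's [start, end) range."""
    layer_id = get_layer_id(name)
    return (
        layer_id is not None
        and hasattr(model, "start_layer")
        and (layer_id < model.start_layer or layer_id >= model.end_layer)
    )
```

Each `load_weights` loop could then open with `if should_skip_weight(self, name): continue`, replacing the four duplicated copies of the check.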

@zhangxiaolei123456 changed the title from "fix Qwen3.5 support PP" to "feat(Qwen3.5): hybrid linear attention support PP" on Feb 28, 2026
@zhangxiaolei123456 changed the title from "feat(Qwen3.5): hybrid linear attention support PP" to "feat(Qwen3.5): hybrid linear attention support PP+PD" on Mar 1, 2026

yuan-luo commented Mar 2, 2026

Hi @zhangxiaolei123456, I'm the author of #19582. When working on that PR, I was not aware that you were doing similar work. Could we collaborate on the Qwen3.5 PP function?

@zhangxiaolei123456
Contributor Author

> Hi @zhangxiaolei123456, I'm the author of #19582. When working on that PR, I was not aware that you were doing similar work. Could we collaborate on the Qwen3.5 PP function?

OK, I will test your PR.


yuan-luo commented Mar 3, 2026

> Hi @zhangxiaolei123456, I'm the author of #19582. When working on that PR, I was not aware that you were doing similar work. Could we collaborate on the Qwen3.5 PP function?
>
> OK, I will test your PR.

Thanks. Please leave message here if there's any update.

