feat(Qwen3.5): hybrid linear attention support PP+PD #19254
zhangxiaolei123456 wants to merge 22 commits into sgl-project:main
Conversation
Summary of Changes: Hello @zhangxiaolei123456, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces modifications to enable pipeline parallelism (PP) for Qwen3.5 models. The changes ensure that model layers and their corresponding weights are correctly distributed and loaded across pipeline stages, resolving a previous limitation where Qwen3.5 did not support PP. This enhancement improves the model's scalability and efficiency in distributed training and inference environments.
Code Review
This pull request adds support for Pipeline Parallelism (PP) to Qwen3.5 models. The changes correctly handle layer instantiation for different PP ranks and skip loading weights for layers that are not on the current rank. The implementation looks solid. My only suggestion is to refactor the duplicated weight-skipping logic in the load_weights methods into a shared helper function to improve code maintainability.
```python
layer_id = get_layer_id(name)
if (
    layer_id is not None
    and hasattr(self, "start_layer")
    and (layer_id < self.start_layer or layer_id >= self.end_layer)
):
    continue
```
This logic to skip loading weights for layers not on the current pipeline parallel rank is duplicated in three other load_weights methods in this file (at lines 925, 1091, and 1241). To improve maintainability and reduce code duplication, consider extracting this logic into a single helper function.
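For illustration, a minimal sketch of such a helper (hypothetical name and signature, not code from this PR), reusing the same `get_layer_id` call as the snippet above:

```python
# Hypothetical shared predicate replacing the four copies of the skip logic.
def _is_layer_on_other_pp_rank(model, name: str) -> bool:
    """True if `name` names a weight of a layer outside this pipeline-parallel
    rank's [start_layer, end_layer) slice, so loading it should be skipped."""
    layer_id = get_layer_id(name)  # same helper already used in the diff
    return (
        layer_id is not None
        and hasattr(model, "start_layer")
        and (layer_id < model.start_layer or layer_id >= model.end_layer)
    )
```

Each `load_weights` loop would then reduce to `if _is_layer_on_other_pp_rank(self, name): continue`.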
Hi @zhangxiaolei123456, I'm the author of #19582. When working on that PR, I wasn't aware that you were working on similar functionality. Could we collaborate on the Qwen3.5 PP feature?
Thanks. Please leave a message here if there's any update.
Motivation
Hybrid linear attention does not currently support PP (pipeline parallelism) or PD (prefill-decode disaggregation).
PP:
```bash
GLOO_SOCKET_IFNAME=eth0 NCCL_MIN_NCHANNELS=24 NCCL_IB_QPS_PER_CONNECTION=8 SGLANG_PP_LAYER_PARTITION="15,15,15,15" python -m sglang.launch_server --model-path /data00/models/Qwen3.5-397B-A17B-FP8 --port 8000 --tp-size 2 --pp-size 4 --mem-fraction-static 0.85 --reasoning-parser qwen3 --tool-call-parser qwen3_coder --mamba-ssm-dtype float16 --kv-cache-dtype fp8_e4m3 --host 0.0.0.0 --port 8000 --disable-radix-cache --pp-async-batch-depth 1 --pp-max-micro-batch-size 1 --max-running-requests 128 --chunked-prefill-size 4096 --max-prefill-tokens 16384 --page-size 64
```
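For reference, a minimal sketch (hypothetical helper, not sglang's actual parsing code) of how a `SGLANG_PP_LAYER_PARTITION` string such as `"15,15,15,15"` maps to the per-rank `[start_layer, end_layer)` ranges that the weight-skipping logic relies on:

```python
# Hypothetical illustration of the SGLANG_PP_LAYER_PARTITION semantics:
# "15,15,15,15" assigns 15 decoder layers to each of the 4 pipeline stages.
def partition_to_layer_range(partition: str, pp_rank: int) -> tuple[int, int]:
    counts = [int(n) for n in partition.split(",")]
    start = sum(counts[:pp_rank])
    return start, start + counts[pp_rank]

# e.g. rank 2 owns the half-open range [30, 45), i.e. layers 30..44
print(partition_to_layer_range("15,15,15,15", 2))  # (30, 45)
```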
PP+PD:

Prefill:
```bash
GLOO_SOCKET_IFNAME=eth0 NCCL_MIN_NCHANNELS=24 NCCL_IB_QPS_PER_CONNECTION=8 SGLANG_PP_LAYER_PARTITION="15,15,15,15" python -m sglang.launch_server --model-path /data00/models/Qwen3.5-397B-A17B-FP8 --port 8000 --tp-size 2 --pp-size 4 --mem-fraction-static 0.85 --reasoning-parser qwen3 --tool-call-parser qwen3_coder --mamba-ssm-dtype float16 --kv-cache-dtype fp8_e4m3 --disaggregation-mode prefill --disaggregation-ib-device "mlx5_1,mlx5_2,mlx5_3,mlx5_4" --host 0.0.0.0 --port 30300 --disable-radix-cache --pp-async-batch-depth 1 --pp-max-micro-batch-size 1 --max-running-requests 128 --chunked-prefill-size 8192 --max-prefill-tokens 16384 --page-size 64
```

Decode:
```bash
GLOO_SOCKET_IFNAME=eth0 NCCL_MIN_NCHANNELS=24 NCCL_IB_QPS_PER_CONNECTION=8 SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=128 python -m sglang.launch_server --model-path /data00/models/Qwen3.5-397B-A17B-FP8 --port 8000 --tp-size 8 --ep-size 8 --mem-fraction-static 0.75 --context-length 131072 --reasoning-parser qwen3 --tool-call-parser qwen3_coder --cuda-graph-bs 1 8 16 32 64 --max-running-requests 256 --enable-dp-attention --dp-size 4 --mamba-ssm-dtype float16 --kv-cache-dtype fp8_e4m3 --disaggregation-mode decode --disaggregation-ib-device "mlx5_1,mlx5_2,mlx5_3,mlx5_4" --moe-runner-backend deep_gemm --moe-a2a-backend deepep --deepep-mode low_latency --host 0.0.0.0 --port 30300 --speculative-algo EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --enable-metrics --disable-radix-cache --page-size 64
```

Modifications

Instantiate model layers per PP rank and skip loading weights for layers that are not on the current rank (see the sketch below).
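A hedged sketch of the PP-aware layer-construction pattern (illustrative only; `build_pp_layers` and the placeholder choice are assumptions, not this PR's actual code):

```python
import torch.nn as nn

# Illustrative pattern: instantiate only this PP rank's slice of the decoder
# stack; off-rank positions get cheap placeholders so that indexing by the
# global layer id still works and load_weights can skip them by range.
def build_pp_layers(num_layers: int, start_layer: int, end_layer: int, make_layer):
    layers = []
    for i in range(num_layers):
        if start_layer <= i < end_layer:
            layers.append(make_layer(i))   # real layer owned by this rank
        else:
            layers.append(nn.Identity())   # placeholder for off-rank layers
    return nn.ModuleList(layers)
```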
Performance improves by 40%.
Accuracy Tests
GSM8K
PP:
PP+PD:
@ShangmingCai
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci