
[Qwen3.5] Support Qwen3.5 Pipeline Parallelism#19670

Merged
BBuf merged 1 commit into sgl-project:main from antgroup:support_qwen35_pp
Mar 7, 2026

Conversation

@yuan-luo
Collaborator

@yuan-luo yuan-luo commented Mar 2, 2026

Motivation

To close #19500

Currently, running Qwen3.5 with pipeline parallelism crashes with an error. With this PR it works.

Server:

➜  /sgl-workspace python -m sglang.launch_server --model Qwen/Qwen3.5-35B-A3B --pp-size 2

GSM8K shows no accuracy drop:

➜  bench_script lm_eval --model local-completions --tasks gsm8k   --model_args base_url=http://localhost:30000/v1/completions,model=Qwen/Qwen3.5-35B-A3B,num_concurrent=109;
2026-03-02:08:38:36 INFO     [_cli.run:376] Selected Tasks: ['gsm8k']
2026-03-02:08:38:36 INFO     [evaluator:211] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234 | Setting fewshot manual seed to 1234
2026-03-02:08:38:36 INFO     [evaluator:236] Initializing local-completions model, with arguments: {'base_url': 'http://localhost:30000/v1/completions', 'model': 'Qwen/Qwen3.5-35B-A3B', 'num_concurrent': 109}
2026-03-02:08:38:36 INFO     [models.openai_completions:42] Remote tokenizer not supported. Using huggingface tokenizer backend.
2026-03-02:08:38:36 INFO     [models.api_models:172] Using max length 2048 - 1
2026-03-02:08:38:36 INFO     [models.api_models:193] Using tokenizer huggingface
2026-03-02:08:38:40 INFO     [tasks:700] Selected tasks:
2026-03-02:08:38:40 INFO     [tasks:691] Task: gsm8k (gsm8k/gsm8k.yaml)
2026-03-02:08:38:40 INFO     [evaluator:314] gsm8k: Using gen_kwargs: {'until': ['Question:', '</s>', '<|im_end|>'], 'do_sample': False, 'temperature': 0.0}
2026-03-02:08:38:40 INFO     [api.task:311] Building contexts for gsm8k on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:04<00:00, 294.79it/s]
2026-03-02:08:38:44 INFO     [evaluator:584] Running generate_until requests
Requesting API: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [02:02<00:00, 10.73it/s]
fatal: not a git repository (or any of the parent directories): .git
2026-03-02:08:40:56 INFO     [loggers.evaluation_tracker:316] Output path not provided, skipping saving results aggregated
local-completions ({'base_url': 'http://localhost:30000/v1/completions', 'model': 'Qwen/Qwen3.5-35B-A3B', 'num_concurrent': 109}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8529|±  |0.0098|
|     |       |strict-match    |     5|exact_match|↑  |0.8355|±  |0.0102|

Modifications

There are several modifications:

  1. embed_tokens was missing a PPMissingLayer placeholder, which caused the server launch to crash.
  2. make_layers returns a module list covering all layers, so we need to sort out which layers belong to the current PP rank; those that do not must use PPMissingLayer as a placeholder.
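The placeholder pattern described above can be sketched as follows. This is an illustrative sketch only: `PPMissingLayer`, `build_layers`, and the other names here are hypothetical stand-ins, not sglang's actual implementation.

```python
# Hypothetical sketch of the PPMissingLayer pattern: the layer list spans
# all layer indices, but only this rank's slice is materialized.

class PPMissingLayer:
    """Placeholder standing in for a layer owned by another PP rank."""

    def __call__(self, *args, **kwargs):
        raise RuntimeError("Layer is not materialized on this PP rank.")


def build_layers(num_layers, start_layer, end_layer, make_real_layer):
    """Materialize only [start_layer, end_layer); placeholders elsewhere."""
    return [
        make_real_layer(i) if start_layer <= i < end_layer else PPMissingLayer()
        for i in range(num_layers)
    ]


# A rank owning layers 2..4 of an 8-layer model:
layers = build_layers(8, 2, 5, lambda i: f"DecoderLayer{i}")
```

Keeping the list at full length means layer indices (and therefore checkpoint weight names like `layers.{i}.*`) stay aligned across ranks.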

Accuracy Tests

GSM8K shows no accuracy drop (see the results above).

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses and resolves issues preventing Qwen3.5 models from utilizing Pipeline Parallelism. It introduces mechanisms to correctly manage model layers and components across different pipeline stages, ensuring that each parallel rank only initializes and processes the parts of the model relevant to its assigned segment, thereby enabling stable and functional distributed execution.

Highlights

  • Pipeline Parallelism Support: Enabled Pipeline Parallelism for Qwen3.5 models, resolving a previous crash during execution.
  • PPMissingLayer Integration: Introduced PPMissingLayer to correctly handle model components (like embed_tokens and norm) that are not present on a specific pipeline parallel rank, preventing errors.
  • Layer Distribution Logic: Implemented logic to distribute decoder layers across pipeline parallel ranks, ensuring each rank processes only its assigned subset of layers.
  • Conditional Embedding/Head Handling: Modified embedding and language model head weight management to be conditional on the pipeline parallel rank, preventing errors when these components are not on the current rank.


Changelog
  • python/sglang/srt/models/qwen3_5.py
    • Imported PPMissingLayer for handling missing pipeline parallel components.
    • Modified the initialization of embed_tokens to use PPMissingLayer when not on the first pipeline parallel rank.
    • Added logic to calculate the start_layer and end_layer for the current pipeline parallel rank.
    • Updated the initialization of the final norm layer to use PPMissingLayer when not on the last pipeline parallel rank.
    • Adjusted the forward method's loop to iterate only over the layers assigned to the current pipeline parallel rank.
    • Updated get_embed_and_head and set_embed_and_head methods to conditionally access embed_tokens and lm_head weights based on the pipeline parallel rank.
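The start_layer/end_layer calculation mentioned in the changelog can be illustrated with a small sketch. This is not sglang's actual `get_pp_indices`; the even-split-with-remainder scheme below is an assumption for illustration.

```python
# Illustrative sketch: split num_layers contiguously across pp_size ranks,
# spreading any remainder over the earliest ranks.

def get_pp_indices(num_layers: int, pp_rank: int, pp_size: int):
    """Return the [start, end) layer indices owned by pp_rank."""
    per_rank = num_layers // pp_size
    remainder = num_layers % pp_size
    start = pp_rank * per_rank + min(pp_rank, remainder)
    end = start + per_rank + (1 if pp_rank < remainder else 0)
    return start, end


# With 48 layers over 2 ranks, each rank owns 24 contiguous layers.
first = get_pp_indices(48, 0, 2)
second = get_pp_indices(48, 1, 2)
```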
Activity
  • No human activity (comments, reviews, or progress updates) has been recorded for this pull request.

@yuan-luo
Collaborator Author

yuan-luo commented Mar 2, 2026

/tag-and-rerun-ci

github-actions bot added the run-ci label Mar 2, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for pipeline parallelism to Qwen3.5 models. The changes correctly handle pipeline stages by using PPMissingLayer for embeddings and the final normalization layer on ranks where they are not needed. The logic for accessing weights in get_embed_and_head and set_embed_and_head is also correctly updated to be pipeline-aware.

My review includes two main points. First, a high-severity issue regarding memory efficiency: all decoder layers are currently instantiated on all pipeline ranks, which can lead to unnecessary memory consumption. I've provided a suggestion to fix this by using PPMissingLayer for inactive layers. Second, a medium-severity suggestion to refactor duplicated code in Qwen3_5ForConditionalGeneration and Qwen3_5MoeForConditionalGeneration into a common base class to improve maintainability.
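The suggested base-class refactor can be sketched as follows. The class names (`PPAwareBase`, `DenseModel`, `MoeModel`) are hypothetical stand-ins for the two `ForConditionalGeneration` variants, not the actual sglang classes.

```python
# Illustrative sketch of the refactor suggestion: hoist shared
# pipeline-parallel bookkeeping into one base class so both model
# variants reuse it instead of duplicating the logic.

class PPAwareBase:
    def __init__(self, pp_rank: int, pp_size: int):
        self.pp_rank = pp_rank
        self.pp_size = pp_size

    @property
    def is_first_rank(self) -> bool:
        # Only the first rank owns embed_tokens.
        return self.pp_rank == 0

    @property
    def is_last_rank(self) -> bool:
        # Only the last rank owns the final norm and lm_head.
        return self.pp_rank == self.pp_size - 1


class DenseModel(PPAwareBase):
    pass


class MoeModel(PPAwareBase):
    pass
```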

Comment thread python/sglang/srt/models/qwen3_5.py
Comment thread python/sglang/srt/models/qwen3_5.py Outdated
Collaborator

@ShangmingCai ShangmingCai left a comment


You can contact the author of #19254; I see some similar effort, so maybe we can converge the plans a bit.

@yuan-luo yuan-luo force-pushed the support_qwen35_pp branch from 231f212 to 1942cc3 Compare March 3, 2026 02:16
@yuan-luo
Collaborator Author

yuan-luo commented Mar 3, 2026

You can contact the author of #19254; I see some similar effort, so maybe we can converge the plans a bit.

@ShangmingCai I couldn't contact @zhangxiaolei123456 directly, so I left a message in #19254.
Since this PR addresses #19500 independently and its scope is relatively small, shall we proceed with reviewing it now?

@yuan-luo
Collaborator Author

yuan-luo commented Mar 3, 2026

/rerun-failed-ci

1 similar comment
@yuan-luo
Collaborator Author

yuan-luo commented Mar 4, 2026

/rerun-failed-ci

Comment thread test/registered/distributed/test_pp_single_node.py
Collaborator

@BBuf BBuf left a comment


Looks good.

Comment on lines 703 to 707
self.layers = make_layers(
config.num_hidden_layers,
get_layer,
prefix=f"{prefix}.layers",
)
Collaborator


Should we change this block as well? If we pass the PP size and PP rank into make_layers, we can get start_layer and end_layer without calling get_pp_indices separately.

Collaborator Author


Refactored.

Collaborator Author


After double-checking, we can't use make_layers to generate start_layer and end_layer; doing so makes the result incorrect. The reason is that we need to loop over all layers and set the missing layers accordingly, rather than starting from start_layer inside make_layers. Changed back.
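The rank-local forward loop that results from this design can be sketched as below. This is a toy illustration with hypothetical names, not the actual Qwen3_5Model.forward; the point is that only the layers in [start_layer, end_layer) execute on a given rank, while the placeholder slots are never called.

```python
# Hedged sketch of a pipeline-aware forward pass: iterate only over the
# layers this rank owns. Upstream ranks' output would arrive as
# `hidden_states` via pipeline communication in the real model.

def forward_local_layers(layers, start_layer, end_layer, hidden_states):
    for i in range(start_layer, end_layer):
        hidden_states = layers[i](hidden_states)
    return hidden_states


# Toy "decoder layers" that each add their index to the hidden state:
layers = [lambda h, k=i: h + k for i in range(4)]
out = forward_local_layers(layers, 1, 3, 0)  # runs layers 1 and 2
```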

Comment thread python/sglang/srt/models/qwen3_5.py Outdated
Collaborator

@ShangmingCai ShangmingCai left a comment


Others LGTM, as long as the new test passes CI.

@yuan-luo yuan-luo force-pushed the support_qwen35_pp branch 2 times, most recently from 3f17199 to c1e4a2d Compare March 7, 2026 06:30
@yuan-luo yuan-luo force-pushed the support_qwen35_pp branch from c1e4a2d to 5a35581 Compare March 7, 2026 08:54
@yuan-luo
Collaborator Author

yuan-luo commented Mar 7, 2026

/rerun-failed-ci

2 similar comments
@yuan-luo
Collaborator Author

yuan-luo commented Mar 7, 2026

/rerun-failed-ci

@yuan-luo
Collaborator Author

yuan-luo commented Mar 7, 2026

/rerun-failed-ci

@BBuf BBuf merged commit 7da590d into sgl-project:main Mar 7, 2026
228 of 254 checks passed
@yuan-luo yuan-luo deleted the support_qwen35_pp branch March 8, 2026 01:23
@hlu1 hlu1 mentioned this pull request Mar 9, 2026
17 tasks
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
popsiclexu pushed a commit to popsiclexu/sglang that referenced this pull request Mar 25, 2026
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
ShangmingCai added a commit that referenced this pull request Mar 25, 2026
Qwen3.5 PP support and its consistency check were introduced in #19670, but the test turned out to be flaky on H100 and AMD, which blocks CI.

The performance regression cannot be reproduced on H20, so we need some time to investigate before bringing this test back.
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>


Development

Successfully merging this pull request may close these issues.

[Bug] Qwen3.5 does not work with pipeline parallelism

4 participants