
[Qwen3-Next] Fuse Qwen3-Next GDN's qkvz_proj and ba_proj#19321

Merged
BBuf merged 3 commits into sgl-project:main from antgroup:fuse_gdn_proj on Mar 20, 2026

Conversation

@yuan-luo (Collaborator) commented Feb 25, 2026

Motivation

This PR fuses Qwen3-Next GDN's qkvz_proj and ba_proj with MergedColumnParallelLinear to improve performance.

TTFT speedup: 2.6%. E2E throughput increase: 2.6% (stable across several test runs).

We plan to fuse Qwen3.5 GDN's qkvz_proj and ba_proj in a follow-up PR.
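
(For intuition, a toy sketch of the merged-projection idea behind MergedColumnParallelLinear; names and sizes are illustrative, not the actual sglang code: several logical projections share one weight, so a single GEMM produces all outputs, which are then split.)

```python
# Toy illustration of a merged column-parallel projection (sizes are made up;
# TP sharding and the real sglang layer classes are omitted).
import torch

hidden, q_out, z_out = 512, 1024, 256
x = torch.randn(8, hidden)

w_q = torch.randn(q_out, hidden)
w_z = torch.randn(z_out, hidden)

# Unfused: two separate GEMM launches.
q, z = x @ w_q.T, x @ w_z.T

# Merged: one wider GEMM over the concatenated weight, then a split.
w_merged = torch.cat([w_q, w_z], dim=0)
q_m, z_m = (x @ w_merged.T).split([q_out, z_out], dim=-1)

assert torch.allclose(q, q_m, atol=1e-4) and torch.allclose(z, z_m, atol=1e-4)
```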

Server:
CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m sglang.launch_server --model Qwen/Qwen3-Next-80B-A3B-Thinking --tp 4 --dp 4 --enable-dp-attention --speculative-num-steps 3  --speculative-eagle-topk 1  --speculative-num-draft-tokens 4 --speculative-algo NEXTN --reasoning-parser deepseek-r1

Client:
python3 -m sglang.bench_serving --backend sglang --host 127.0.0.1 --port 30000 --dataset-name random --random-input-len 8000 --random-output-len 1500 --dataset-path /data/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 200

MAIN:
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 not set
Successful requests:                     200
Benchmark duration (s):                  61.54
Total input tokens:                      788704
Total input text tokens:                 788704
Total generated tokens:                  153723
Total generated tokens (retokenized):    153714
Request throughput (req/s):              3.25
Input token throughput (tok/s):          12816.74
Output token throughput (tok/s):         2498.06
Peak output token throughput (tok/s):    6013.00
Peak concurrent requests:                200
Total token throughput (tok/s):          15314.80
Concurrency:                             117.91
Accept length:                           3.83
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   36279.77
Median E2E Latency (ms):                 37102.72
P90 E2E Latency (ms):                    57863.55
P99 E2E Latency (ms):                    60756.83
---------------Time to First Token----------------
Mean TTFT (ms):                          23169.69
Median TTFT (ms):                        21970.22
P99 TTFT (ms):                           54750.62
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          17.76
Median TPOT (ms):                        17.16
P99 TPOT (ms):                           28.55
---------------Inter-Token Latency----------------
Mean ITL (ms):                           17.08
Median ITL (ms):                         5.65
P95 ITL (ms):                            86.60
P99 ITL (ms):                            113.28
Max ITL (ms):                            695.47
==================================================


PR:
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 not set
Successful requests:                     200
Benchmark duration (s):                  59.90
Total input tokens:                      788704
Total input text tokens:                 788704
Total generated tokens:                  153723
Total generated tokens (retokenized):    153714
Request throughput (req/s):              3.34
Input token throughput (tok/s):          13167.35
Output token throughput (tok/s):         2566.39
Peak output token throughput (tok/s):    5970.00
Peak concurrent requests:                200
Total token throughput (tok/s):          15733.74
Concurrency:                             118.04
Accept length:                           3.83
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   35351.34
Median E2E Latency (ms):                 36184.38
P90 E2E Latency (ms):                    56356.19
P99 E2E Latency (ms):                    59344.63
---------------Time to First Token----------------
Mean TTFT (ms):                          22555.29
Median TTFT (ms):                        21430.94
P99 TTFT (ms):                           53358.06
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          17.36
Median TPOT (ms):                        16.70
P99 TPOT (ms):                           28.37
---------------Inter-Token Latency----------------
Mean ITL (ms):                           16.67
Median ITL (ms):                         5.64
P95 ITL (ms):                            84.05
P99 ITL (ms):                            114.39
Max ITL (ms):                            687.11
==================================================

Modifications

Accuracy Tests

GSM8K accuracy shows no drop:

➜  sglang_dev2 git:(fuse_gdn_proj) ✗ lm_eval --model local-completions --tasks gsm8k   --model_args base_url=http://localhost:30000/v1/completions,model=Qwen/Qwen3-Next-80B-A3B-Instruct,num_concurrent=109;
2026-02-25:08:10:54 INFO     [_cli.run:376] Selected Tasks: ['gsm8k']
2026-02-25:08:10:54 WARNING  [evaluator:181] pretrained=None appears to be an instruct or chat variant but chat template is not applied. Recommend setting `apply_chat_template`
        (optionally `fewshot_as_multiturn`).
2026-02-25:08:10:54 INFO     [evaluator:211] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234 | Setting fewshot manual seed to 1234
2026-02-25:08:10:54 INFO     [evaluator:236] Initializing local-completions model, with arguments: {'base_url': 'http://localhost:30000/v1/completions', 'model': 'Qwen/Qwen3-Next-80B-A3B-Instruct', 'num_concurrent': 109}
2026-02-25:08:10:54 INFO     [models.openai_completions:42] Remote tokenizer not supported. Using huggingface tokenizer backend.
2026-02-25:08:10:54 INFO     [models.api_models:172] Using max length 2048 - 1
2026-02-25:08:10:54 INFO     [models.api_models:193] Using tokenizer huggingface
2026-02-25:08:10:58 INFO     [tasks:700] Selected tasks:
2026-02-25:08:10:58 INFO     [tasks:691] Task: gsm8k (gsm8k/gsm8k.yaml)
2026-02-25:08:10:58 INFO     [evaluator:314] gsm8k: Using gen_kwargs: {'until': ['Question:', '</s>', '<|im_end|>'], 'do_sample': False, 'temperature': 0.0}
2026-02-25:08:10:58 INFO     [api.task:311] Building contexts for gsm8k on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:04<00:00, 291.61it/s]
2026-02-25:08:11:02 INFO     [evaluator:584] Running generate_until requests
Requesting API: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [02:37<00:00,  8.39it/s]
2026-02-25:08:13:50 INFO     [loggers.evaluation_tracker:316] Output path not provided, skipping saving results aggregated
local-completions ({'base_url': 'http://localhost:30000/v1/completions', 'model': 'Qwen/Qwen3-Next-80B-A3B-Instruct', 'num_concurrent': 109}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8650|±  |0.0094|
|     |       |strict-match    |     5|exact_match|↑  |0.8491|±  |0.0099|

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist (Bot):

Summary of Changes

Hello @yuan-luo, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant performance optimization for the Qwen3-Next model by integrating MergedColumnParallelLinear for its Gated DeltaNet (GDN) attention mechanism. Specifically, it combines the qkvz_proj and ba_proj layers, which streamlines operations and reduces latency. The changes adapt the core linear-layer utilities to handle these merged structures and update the Qwen3-Next model definition to use the new fused components, yielding measurable improvements in both Time To First Token and overall throughput.

Highlights

  • Performance Optimization: The pull request fuses Qwen3-Next GDN's qkvz_proj and ba_proj using MergedColumnParallelLinear to enhance performance, resulting in a 2.6% TTFT speedup and a 2.6% E2E throughput increase.
  • Weight Loader Enhancements: The weight_loader in linear.py was updated to support loaded_shard_id as a tuple, enabling proper handling of weights for merged linear layers, and _load_fused_module_from_checkpoint now accepts an optional output_sizes argument (sketched after this list).
  • Qwen3-Next Model Adaptation: The Qwen3NextAttention class was refactored to utilize MergedColumnParallelLinear for in_proj_ba and introduced a new create_qkvz_proj method to construct in_proj_qkvz as a merged linear layer.
  • Weight Mapping Updates: The load_weights function in qwen3_next.py was extended to include correct mappings for the newly fused in_proj_qkvz and in_proj_ba parameters.
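
A hedged sketch of the tuple-shard-id idea mentioned above (the function name and signature are illustrative, not sglang's actual weight_loader; TP rank slicing is omitted):

```python
# Illustrative-only loader showing how a tuple loaded_shard_id can place a
# checkpoint tensor spanning several contiguous slots of a merged parameter.
from typing import Sequence, Tuple, Union
import torch

def load_merged_shard(
    param: torch.Tensor,            # merged weight, shape [sum(output_sizes), in_dim]
    loaded_weight: torch.Tensor,    # checkpoint tensor for one or more slots
    output_sizes: Sequence[int],    # per-slot widths, e.g. [q, k, v, z]
    loaded_shard_id: Union[int, Tuple[int, ...], None],
) -> None:
    if loaded_shard_id is None:
        # Fully fused checkpoint: direct copy (the pre-existing path).
        param.copy_(loaded_weight)
        return
    ids = (loaded_shard_id,) if isinstance(loaded_shard_id, int) else loaded_shard_id
    # Assumes the ids are contiguous, as in (0, 1, 2) for a qkv tensor.
    start = sum(output_sizes[: ids[0]])
    width = sum(output_sizes[i] for i in ids)
    param.narrow(0, start, width).copy_(loaded_weight)
```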


Changelog
  • python/sglang/srt/layers/linear.py
    • Modified weight_loader to accept loaded_shard_id as a tuple and delegate to weight_loader_v2 for merged layers.
    • Updated _load_fused_module_from_checkpoint to include an output_sizes parameter.
    • Adjusted weight_loader_v2 to correctly handle tuple loaded_shard_id and pass output_sizes to the fused module loading function.
  • python/sglang/srt/models/qwen3_next.py
    • Imported MergedColumnParallelLinear for use in the model.
    • Refactored in_proj_ba in Qwen3NextAttention to use MergedColumnParallelLinear.
    • Added a new method create_qkvz_proj to Qwen3NextAttention to instantiate in_proj_qkvz as a MergedColumnParallelLinear (see the sketch after this list).
    • Extended load_weights to include new parameter mappings for the fused in_proj_qkvz and in_proj_ba.
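
A rough, hypothetical reconstruction of the model-side change (the real create_qkvz_proj lives in python/sglang/srt/models/qwen3_next.py; the constructor arguments and head/dim names below are assumptions):

```python
from sglang.srt.layers.linear import MergedColumnParallelLinear

def create_qkvz_proj(self) -> MergedColumnParallelLinear:
    # One merged column-parallel layer whose output concatenates the q, k, v,
    # and z (gate) projections, so the forward pass stays a single GEMM while
    # the weight loader sees four logical partitions.
    return MergedColumnParallelLinear(
        self.hidden_size,
        [
            self.num_heads * self.head_dim,     # q
            self.num_kv_heads * self.head_dim,  # k
            self.num_kv_heads * self.head_dim,  # v
            self.num_heads * self.head_dim,     # z
        ],
        bias=False,
    )
```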
Activity
  • The author provided benchmark results demonstrating a 2.6% TTFT speedup and a 2.6% E2E throughput increase after applying the changes.
  • The pull request description includes a checklist for code formatting, unit tests, documentation, and benchmarking, none of which are marked as completed.
  • The PR outlines a standard review process involving Merge Oncalls, CODEOWNERS, and CI tests.

@gemini-code-assist (Bot) left a comment:

Code Review

This PR aims to improve performance by fusing qkvz_proj and ba_proj in Qwen3-Next's Gated Delta Net (GDN) using MergedColumnParallelLinear. The changes in python/sglang/srt/layers/linear.py correctly add support for loading fused weights with tuple-based shard IDs. The fusion of ba_proj in python/sglang/srt/models/qwen3_next.py also appears correct. However, there is a critical issue in the implementation of in_proj_qkvz fusion. The create_qkvz_proj method defines a MergedColumnParallelLinear with only two output partitions, which contradicts the weight loading logic in load_weights that expects four partitions. This will lead to an IndexError at runtime. I've provided a comment with a suggested fix.
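
(To make the flagged failure mode concrete, a toy illustration with made-up numbers, not the actual code:)

```python
# If the merged layer declares only two partitions, a four-partition shard id
# has nowhere to go: looking up the shard's width raises IndexError.
output_sizes = [4096, 4096]     # two partitions, as the review says was defined
shard_id = 3                    # load_weights addresses slot 3 of q/k/v/z
width = output_sizes[shard_id]  # IndexError: list index out of range
```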

Review thread on python/sglang/srt/models/qwen3_next.py (outdated)
@yuan-luo (Collaborator, Author):

/tag-and-rerun-ci

@yuan-luo (Collaborator, Author):

/rerun-failed-ci

@yuan-luo (Collaborator, Author):

/rerun-failed-ci

@yuan-luo (Collaborator, Author) commented Feb 26, 2026

I found ./test/registered/models/test_qwen3_next_models_pcg.py failing in CI due to an accuracy drop. This PR should not affect accuracy, which is odd. Investigating; setting WIP for now.

@yuan-luo changed the title [Qwen3-Next] Fuse Qwen3-Next GDN's qkvz_proj and ba_proj → [WIP][Qwen3-Next] Fuse Qwen3-Next GDN's qkvz_proj and ba_proj on Feb 26, 2026
@yuan-luo requested a review from yizhang2077 on February 27, 2026 02:23
@yuan-luo force-pushed the fuse_gdn_proj branch 2 times, most recently from 1f36219 to d36e7b5, on February 27, 2026 06:36
@yuan-luo marked this pull request as draft on March 1, 2026 01:02
@yuan-luo force-pushed the fuse_gdn_proj branch 2 times, most recently from 290f9fd to 8461701, on March 7, 2026 13:07
@yuan-luo force-pushed the fuse_gdn_proj branch 4 times, most recently from ed019b7 to 4b111d2, on March 17, 2026 18:07
@yuan-luo (Collaborator, Author):

Problem fixed. With PCG, the result is correct now.

➜  bench_script python test_openai.py
ChatCompletion(id='66f5be8c74274aa3b2e82530194de7e6', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Sure! Here are three countries and their capitals:\n\n1. **France** – Paris  \n2. **Japan** – Tokyo  \n3. **Brazil** – Brasília  \n\n### How I Ranked Them:\nI ranked these countries **alphabetically by country name**:\n\n- **Brazil** (B)  \n- **France** (F)  \n- **Japan** (J)  \n\nThis is a neutral, objective ranking method — no value judgments about size, population, economy, or cultural influence. Alphabetical order ensures fairness and consistency. If you’d like them ranked by population, GDP, or something else, just let me know!', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=151645)], created=1773771281, model='default', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=129, prompt_tokens=33, total_tokens=162, completion_tokens_details=None, prompt_tokens_details=None, reasoning_tokens=0), metadata={'weight_version': 'default'})
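
(For context, a client along these lines would produce the output above; this is a hypothetical reconstruction, since the real test_openai.py, prompt, and port are not shown in this thread.)

```python
# Hypothetical reconstruction of a test_openai.py-style sanity check.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="default",
    messages=[{
        "role": "user",
        "content": "Give me three countries and their capitals, and explain how you ranked them.",
    }],
)
print(resp.choices[0].message.content)
```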

@yuan-luo marked this pull request as ready for review on March 17, 2026 18:15
@gemini-code-assist (Bot):

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@yuan-luo (Collaborator, Author):

/tag-and-rerun-ci

@yuan-luo requested a review from kaixih on March 17, 2026 18:16
@yuan-luo changed the title [WIP][Qwen3-Next] Fuse Qwen3-Next GDN's qkvz_proj and ba_proj → [Qwen3-Next] Fuse Qwen3-Next GDN's qkvz_proj and ba_proj on Mar 17, 2026
Review thread on python/sglang/srt/models/qwen3_next.py (outdated)
@zminglei (Collaborator) left a comment:

Thanks Yuan for this nice optimization! Could we also confirm it is still on par with or better than the baseline in decode-heavy cases?

@kaixih (Collaborator) commented Mar 18, 2026

Thanks for the PR! A couple of questions to align my understanding:

For Qwen3-Next: The HF checkpoint already stores in_proj_qkvz.weight and in_proj_ba.weight as fused tensors, and the original ColumnParallelLinear was already doing a single GEMM for each. Changing to MergedColumnParallelLinear doesn't change the forward pass (same GEMM) or the weight loading behavior (fused checkpoint hits the loaded_shard_id=None path in _make_packed_weight_loader, which does a direct TP slice + copy, same as before). So I'm not sure where the 2.6% TTFT/throughput speedup comes from: could you clarify what's driving the improvement here?

For Qwen3.5: The HF checkpoint stores weights split as in_proj_qkv.weight, in_proj_z.weight, in_proj_b.weight, in_proj_a.weight, and the current qwen3_5.py loads them into 4 separate layers doing 4 GEMMs. The follow-up PR you mentioned would fuse these into 2 MergedColumnParallelLinear layers (in_proj_qkvz and in_proj_ba), using the new tuple shard_id (0,1,2) to load in_proj_qkv.weight into slots 0,1,2 and in_proj_z.weight into slot 3 of the merged in_proj_qkvz parameter, and similarly loading in_proj_b.weight and in_proj_a.weight into the merged in_proj_ba parameter: reducing to 2 GEMMs. Could you confirm this is the intended design for the follow-up?
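
(In sketch form, the loading plan described above; widths and names are illustrative, pending the actual follow-up PR.)

```python
# Four Qwen3.5 checkpoint tensors folded into two merged parameters -> 2 GEMMs.
# The tuple (0, 1, 2) is the new capability this PR adds to the linear
# framework: one checkpoint tensor may span several contiguous slots of a
# MergedColumnParallelLinear parameter.
weight_map = {
    # checkpoint tensor       (merged param,   loaded_shard_id)
    "in_proj_qkv.weight": ("in_proj_qkvz", (0, 1, 2)),  # fills slots q, k, v
    "in_proj_z.weight":   ("in_proj_qkvz", 3),          # fills slot z
    "in_proj_b.weight":   ("in_proj_ba",   0),
    "in_proj_a.weight":   ("in_proj_ba",   1),
}
```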

@yuan-luo (Collaborator, Author):

/tag-and-rerun-ci

@edwingao28 (Contributor) commented Mar 19, 2026

@zminglei @yuan-luo I tested on A100 with tp=4 in a decode-heavy scenario. Results for the main branch and this PR are averaged over 3 runs (the first run served as warmup and was excluded).


Median TTFT improved 9.2%, Median E2E improved 1.8%

Server:

python3 -m sglang.launch_server \
  --model /workspace/models/qwen3-next-80b-instruct \
  --tp-size 4 \
  --chunked-prefill-size 2048 \
  --mamba-scheduler-strategy extra_buffer \
  --mamba-track-interval 128 \
  --port 30000

Client:

python3 -m sglang.bench_serving --backend sglang \
  --host 127.0.0.1 --port 30000 --dataset-name random \
  --random-input-len 512 --random-output-len 4096 --num-prompts 200
| Metric | PR | Main | Change |
|---|---:|---:|---:|
| Benchmark duration (s) | 124.63 | 126.13 | -1.2% |
| Request throughput (req/s) | 1.60 | 1.59 | +0.6% |
| Input token throughput (tok/s) | 419.77 | 414.80 | +1.2% |
| Output token throughput (tok/s) | 3260.17 | 3221.55 | +1.2% |
| Peak output token throughput (tok/s) | 4669.00 | 4791.33 | -2.5% |
| Total token throughput (tok/s) | 3679.93 | 3636.35 | +1.2% |
| **End-to-End Latency** | | | |
| Mean E2E Latency (ms) | 72847.32 | 74072.81 | -1.7% |
| Median E2E Latency (ms) | 76746.49 | 78137.16 | -1.8% |
| P90 E2E Latency (ms) | 115137.39 | 116613.65 | -1.3% |
| P99 E2E Latency (ms) | 121034.95 | 122515.38 | -1.2% |
| **Time to First Token** | | | |
| Mean TTFT (ms) | 4719.11 | 4852.31 | -2.7% |
| Median TTFT (ms) | 876.80 | 965.52 | -9.2% |
| P99 TTFT (ms) | 35931.79 | 36896.83 | -2.6% |
| **Time per Output Token** | | | |
| Mean TPOT (ms) | 35.47 | 36.13 | -1.8% |
| Median TPOT (ms) | 36.02 | 36.65 | -1.7% |
| P99 TPOT (ms) | 40.98 | 42.50 | -3.6% |
| **Inter-Token Latency** | | | |
| Mean ITL (ms) | 33.55 | 34.09 | -1.6% |
| Median ITL (ms) | 33.40 | 33.69 | -0.9% |
| P95 ITL (ms) | 41.39 | 41.60 | -0.5% |
| P99 ITL (ms) | 135.44 | 133.53 | +1.4% |
| Max ITL (ms) | 575.22 | 578.54 | -0.6% |

@edwingao28 (Contributor) commented Mar 19, 2026

Tested with the GSM8K few-shot test: accuracy 0.955, no regression observed.

(venv) root@a202db6cd5da:/workspace/sglang# python3 -m sglang.test.few_shot_gsm8k --host http://127.0.0.1/ --port 21000
/workspace/sglang/python/sglang/test/few_shot_gsm8k.py:54: DeprecationWarning: Including the scheme in --host ('http://127.0.0.1/') is deprecated. Pass just the hostname (e.g. '127.0.0.1') instead.
  set_default_backend(RuntimeEndpoint(normalize_base_url(args.host, args.port)))
Downloading from https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl to /tmp/test.jsonl
/tmp/test.jsonl: 732kB [00:00, 18.9MB/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:33<00:00,  5.96it/s]
Accuracy: 0.955
Invalid: 0.000
Latency: 35.254 s
Output throughput: 917.448 token/s
(venv) root@a202db6cd5da:/workspace/sglang# git branch
  main
  pcg/cicd_testing
  test/fuse_gdn_proj_latest
* test/fuse_verify
(venv) root@a202db6cd5da:/workspace/sglang#

@yuan-luo (Collaborator, Author) commented Mar 19, 2026

> Thanks for the PR! A couple of questions to align my understanding:
>
> For Qwen3-Next: The HF checkpoint already stores in_proj_qkvz.weight and in_proj_ba.weight as fused tensors, and the original ColumnParallelLinear was already doing a single GEMM for each. Changing to MergedColumnParallelLinear doesn't change the forward pass (same GEMM) or the weight loading behavior (fused checkpoint hits the loaded_shard_id=None path in _make_packed_weight_loader, which does a direct TP slice + copy, same as before). So I'm not sure where the 2.6% TTFT/throughput speedup comes from: could you clarify what's driving the improvement here?
>
> For Qwen3.5: The HF checkpoint stores weights split as in_proj_qkv.weight, in_proj_z.weight, in_proj_b.weight, in_proj_a.weight, and the current qwen3_5.py loads them into 4 separate layers doing 4 GEMMs. The follow-up PR you mentioned would fuse these into 2 MergedColumnParallelLinear layers (in_proj_qkvz and in_proj_ba), using the new tuple shard_id (0,1,2) to load in_proj_qkv.weight into slots 0,1,2 and in_proj_z.weight into slot 3 of the merged in_proj_qkvz parameter, and similarly loading in_proj_b.weight and in_proj_a.weight into the merged in_proj_ba parameter: reducing to 2 GEMMs. Could you confirm this is the intended design for the follow-up?

@kaixih Agree with you. This PR's core value is introducing tuple shard_id support in the linear framework, which makes it possible to fuse Qwen3.5's 4 GEMMs into 2 in the next step.

@yuan-luo (Collaborator, Author):

/rerun-failed-ci

@BBuf merged commit d9794ef into sgl-project:main on Mar 20, 2026
319 of 359 checks passed
@yuan-luo deleted the fuse_gdn_proj branch on March 20, 2026 01:30
@yuan-luo (Collaborator, Author) commented Mar 20, 2026

> Thanks for the PR! A couple of questions to align my understanding:
>
> For Qwen3-Next: The HF checkpoint already stores in_proj_qkvz.weight and in_proj_ba.weight as fused tensors, and the original ColumnParallelLinear was already doing a single GEMM for each. Changing to MergedColumnParallelLinear doesn't change the forward pass (same GEMM) or the weight loading behavior (fused checkpoint hits the loaded_shard_id=None path in _make_packed_weight_loader, which does a direct TP slice + copy, same as before). So I'm not sure where the 2.6% TTFT/throughput speedup comes from: could you clarify what's driving the improvement here?
>
> For Qwen3.5: The HF checkpoint stores weights split as in_proj_qkv.weight, in_proj_z.weight, in_proj_b.weight, in_proj_a.weight, and the current qwen3_5.py loads them into 4 separate layers doing 4 GEMMs. The follow-up PR you mentioned would fuse these into 2 MergedColumnParallelLinear layers (in_proj_qkvz and in_proj_ba), using the new tuple shard_id (0,1,2) to load in_proj_qkv.weight into slots 0,1,2 and in_proj_z.weight into slot 3 of the merged in_proj_qkvz parameter, and similarly loading in_proj_b.weight and in_proj_a.weight into the merged in_proj_ba parameter: reducing to 2 GEMMs. Could you confirm this is the intended design for the follow-up?

@kaixih I addressed the second point in #21019. Could you please help to review it? Thanks.

Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
…#19321)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
…#19321)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
dutsc pushed a commit to dutsc/sglang that referenced this pull request Mar 30, 2026
…#19321)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
…#19321)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
…#19321)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>