[Qwen3.5] Fuse split/reshape/cat ops in GDN projection with Triton kernel #21019
BBuf merged 3 commits into sgl-project:main
Conversation
Summary of Changes (Gemini Code Assist)
This pull request optimizes the Qwen3.5 Gated Delta Net by fusing several projection operations into a single, efficient Triton kernel. The fusion reduces computational overhead and memory footprint by minimizing kernel launches and intermediate tensor allocations, improving inference performance: benchmarks show a +6.5% increase in output throughput and notable reductions in Time Per Output Token (TPOT) and Time To First Token (TTFT). The changes also refactor the projection layers and extend the weight loading mechanism to support both fused and split checkpoint formats, preserving compatibility and maintainability.
/tag-and-rerun-ci again
Force-pushed from c149767 to 294ab57
/gemini review
Force-pushed from 8aaf698 to ffec15a
Could you paste FP8 test results here?
It encountered some errors. Investigating.
The reason is that FP8 and unquantized models use different weight parameter types. Specifically, when running self.in_proj_qkvz.weight.weight_loader = self._make_packed_weight_loader(), the FP8 ModelWeightParameter exposes weight_loader as a read-only property with no setter, so the assignment raises an error. Fix in progress.
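For context, a minimal, self-contained sketch of this failure mode; the class and attribute names here are simplified stand-ins, not the actual SGLang definitions:

```python
class ModelWeightParameter:
    """Simplified stand-in for the FP8 parameter class."""

    def __init__(self, loader):
        self._loader = loader

    @property
    def weight_loader(self):  # read-only: no setter is defined
        return self._loader


param = ModelWeightParameter(loader=lambda p, w: None)
try:
    # What the original code effectively did on the unquantized path:
    param.weight_loader = lambda p, w: None
except AttributeError as e:
    print(f"assignment fails: {e}")

# One possible workaround (an assumption, not necessarily the PR's fix):
# write to the underlying attribute that backs the property instead.
param._loader = lambda p, w: None
```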
FP8 problem fixed.
Force-pushed from ffec15a to 77cec9e
Thanks @yuan-luo for taking a look.
The GEMM fusion follows the same approach as #19321 for Qwen3-Next. |
jasperjiaguo left a comment:
Yes, the approach LGTM. I will keep a separate tab on the small-model perf.
Force-pushed from 77cec9e to 1b0fa5f
…rnel (sgl-project#21019) Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Fuse split/reshape/cat ops in GDN projection, adapted for contiguous layout.
… contiguous kernel
…sed split/reshape/cat ops in gdn
Adds FP8 quantization support for the fused GDN projection layers:
- _override_weight_loader: robust loader override for FP8/quantized params
- _bind_packed_weight_loaders: covers weight, weight_scale_inv, weight_scale, input_scale
- _get_split_sizes_for_param: handles BlockQuantScaleParameter and PerTensorScaleParameter
- Updated _make_packed_weight_loader to support FP8 scale parameters
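As a rough illustration of the split-size logic named above: a loader has to derive per-shard sizes differently for each parameter type, because quantization scales are smaller than the weights they describe. The block size and divisor logic below are assumptions for illustration, not the PR's actual code:

```python
def get_split_sizes_for_param(param, output_sizes, block_n=128):
    """Derive per-shard split sizes along the output dim for one parameter."""
    cls = type(param).__name__
    if cls == "BlockQuantScaleParameter":
        # Block-wise FP8 scales carry one value per block_n output rows.
        return [(s + block_n - 1) // block_n for s in output_sizes]
    if cls == "PerTensorScaleParameter":
        # One scale per logical shard, regardless of its width.
        return [1 for _ in output_sizes]
    # Plain weights keep their full row counts.
    return list(output_sizes)
```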
- Cherry-pick PR sgl-project#21019: Fuse GDN split/reshape/cat ops with FP8/BF16 quant support - Add BF16 qkv z b a fusion and PTPC quant config
- Cherry-pick PR sgl-project#21019 load weight func - Add BF16 qkv z b a fusion and PTPC quant config
Motivation
In PR #19321 we fused Qwen3-Next GDN's qkvz_proj and ba_proj. This PR is a follow-up; the background is that Qwen3-Next's and Qwen3.5's checkpoint layouts differ.
Qwen3-Next weight loading path
The Qwen3-Next checkpoint directly stores the fused in_proj_qkvz weight (loaded_shard_id=None). This weight is already in the interleaved layout.
During loading, it goes through the contiguous TP-slice case, so the interleaved layout is preserved. As a result, the matmul output is also interleaved, and the existing Triton kernel reads it as interleaved data.
Qwen3.5 weight loading path
The Qwen3.5 checkpoint stores in_proj_qkv and in_proj_z separately. They are mapped through stacked_params_mapping with shard_id=(0,1,2) for q, k, v and shard_id=3 for z.
During loading, it goes through a different case, where MergedColumnParallelLinear.weight_loader places q, k, and v into contiguous regions according to output_sizes. Therefore, the matmul output is contiguous, and a new Triton kernel is needed that reads from contiguous positions, as the sketch below illustrates.
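To make the two layouts concrete, here is a minimal PyTorch sketch with made-up head counts and dimensions; it only illustrates the element ordering the two kernels see, not the actual model code:

```python
import torch

num_heads, head_dim = 2, 4  # toy sizes; q, k, v, z all use head_dim here
out = torch.arange(num_heads * 4 * head_dim)  # one token's qkvz projection

# Qwen3-Next (interleaved): the fused checkpoint weight interleaves
# (q, k, v, z) per head, so the matmul output looks like
#   [q0, k0, v0, z0, q1, k1, v1, z1]
per_head = out.view(num_heads, 4 * head_dim)
q_i, k_i, v_i, z_i = per_head.split([head_dim] * 4, dim=-1)

# Qwen3.5 (contiguous): MergedColumnParallelLinear.weight_loader packs each
# projection into its own contiguous region per output_sizes, giving
#   [q0, q1, k0, k1, v0, v1, z0, z1]
q_c, k_c, v_c, z_c = out.split([num_heads * head_dim] * 4, dim=-1)
q_c = q_c.view(num_heads, head_dim)  # per-head view after the split

# The two kernels differ only in where they read each (head, tensor) slice.
```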
In summary, this PR fuses the split → reshape → cat operations in Qwen3_5GatedDeltaNet into a single Triton kernel (fused_qkvzba_split_reshape_cat), eliminating multiple kernel launches and intermediate tensor allocations during both prefill and decode. More details are in the Modifications section below.
Modifications
- Added a new Triton kernel, fused_qkvzba_split_reshape_cat_contiguous, to fuse split, reshape, and cat operations within the Qwen3.5 Gated Delta Net (GDN) projection, reducing kernel launches and intermediate memory allocations.
- Merged the in_proj_qkv, in_proj_z, in_proj_b, and in_proj_a projection layers into two fused layers: in_proj_qkvz and in_proj_ba.
- Added _make_packed_weight_loader to correctly handle weight loading for both fused (packed) and split checkpoint formats, ensuring proper parameter initialization (see the sketch after this list).
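A minimal sketch of what such a packed loader can look like; the function name follows the PR, but the signature and offset bookkeeping are assumptions for illustration:

```python
import torch

def make_packed_weight_loader(split_sizes):
    """Return a loader for a fused parameter built from per-shard sizes."""
    offsets = [0]
    for s in split_sizes:
        offsets.append(offsets[-1] + s)

    def loader(param, loaded_weight, loaded_shard_id=None):
        if loaded_shard_id is None:
            # Fused checkpoint: the shard already covers the whole weight.
            param.data.copy_(loaded_weight)
        else:
            # Split checkpoint: copy this shard into its contiguous region.
            start, end = offsets[loaded_shard_id], offsets[loaded_shard_id + 1]
            param.data[start:end].copy_(loaded_weight)

    return loader

# Usage: in_proj_qkvz fuses q, k, v, z with per-shard output sizes.
param = torch.nn.Parameter(torch.empty(8 + 8 + 8 + 8, 16))
loader = make_packed_weight_loader([8, 8, 8, 8])
loader(param, torch.randn(8, 16), loaded_shard_id=2)  # load the v shard
```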
Accuracy Tests
GSM8K
Main:
PR:
The LLM results show no accuracy regression.
Benchmarking and Profiling
H200
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci