
[GDN] Support GDN packed decode#20627

Merged
ispobock merged 3 commits into sgl-project:main from antgroup:support_gdn_packed_decode
Mar 18, 2026

Conversation

@yuan-luo
Collaborator

@yuan-luo commented Mar 15, 2026

Motivation

This PR optimizes the GDN decode path with a packed decode kernel. Details follow.

Current decode path

The current path has 6 steps, with substantial overhead from memory copies and small kernel launches.

mixed_qkv [bs, qkv_dim]
    │
    ▼ causal_conv1d_update()
mixed_qkv [bs, qkv_dim]           ← conv1d updated
    │
    ▼ torch.split()                ←  Overhead 1: split creates non-contiguous view
(q_flat, k_flat, v_flat)
    │
    ▼ .view() + (implicit .contiguous()) ←  Overhead 2: memory copy
q [1, bs, H, K]
k [1, bs, H, K]
v [1, bs, HV, V]
    │
    ▼ fused_sigmoid_gating_delta_rule_update()
      Internal:
      (a) Independent kernel computes g = -exp(A_log)*softplus(a+dt_bias) ←  Overhead 3
      (b) Independent kernel computes beta = sigmoid(b) ←  Overhead 4
      (c) Recursion update kernel ← Core computation
    │
    ▼
output [1, bs, HV, V]
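
For illustration, here is a minimal PyTorch sketch of the baseline flow above; the shapes and tensor contents are made up, and only the step names come from this PR:

```python
import torch
import torch.nn.functional as F

bs, H, HV, K, V = 32, 8, 16, 128, 128
mixed_qkv = torch.randn(bs, 2 * H * K + HV * V, dtype=torch.bfloat16)

# Overhead 1: split produces non-contiguous views of the packed tensor.
q_flat, k_flat, v_flat = torch.split(mixed_qkv, [H * K, H * K, HV * V], dim=-1)

# Overhead 2: reshaping the non-contiguous views to head layout copies memory.
q = q_flat.reshape(1, bs, H, K)
k = k_flat.reshape(1, bs, H, K)
v = v_flat.reshape(1, bs, HV, V)

# Overheads 3 and 4: gate and beta come from separate small kernels before
# the core recurrence update runs.
A_log, dt_bias = torch.randn(HV), torch.randn(HV)
a, b = torch.randn(bs, HV), torch.randn(bs, HV)
g = -torch.exp(A_log) * F.softplus(a + dt_bias)  # Overhead 3
beta = torch.sigmoid(b)                          # Overhead 4
```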

Packed decode path

Refactored to 3 steps, with a single kernel handling QKV extraction, gate/beta computation, and the output write.

mixed_qkv [bs, qkv_dim]
    │
    ▼ causal_conv1d_update()
mixed_qkv [bs, qkv_dim]
    │
    ▼ fused_recurrent_gated_delta_rule_packed_decode()  ← Single kernel completes everything
      Internal:
      (a) Directly read q/k/v from packed layout via pointer arithmetic (zero copy)
      (b) Compute g, beta within registers
      (c) Recursion update
      (d) Directly write to output buffer
    │
    ▼
output [bs, 1, HV, V]  → transpose → [1, bs, HV, V]
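
As a reference for what the recursion update (c) computes per (batch, v-head) pair, here is a hedged pure-PyTorch sketch following the common gated delta rule formulation; the fused kernel's exact ordering, scaling, and dtype handling may differ:

```python
import torch

def gated_delta_rule_decode_step(S, q, k, v, g, beta):
    # S: [K, V] recurrent state for one (batch, v-head) pair
    # q, k: [K]; v: [V]; g: log-space decay gate; beta: write strength
    S = S * torch.exp(g)             # decay the state by the gate
    delta = beta * (v - k @ S)       # delta-rule correction term
    S = S + torch.outer(k, delta)    # rank-1 state update
    o = q @ S                        # read out this step's output, [V]
    return S, o
```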

Explanation

  • B — Batch Size. The number of sequences being decoded concurrently. More concurrent requests mean a larger B.
  • H — num_q_heads / num_k_heads. The number of Q and K heads. For Qwen3.5-35B-A3B, the full count is 16; with TP=2, each GPU gets 8.
  • HV — num_v_heads. The number of V (Value) heads. This is where GDN differs from standard Attention: it uses an asymmetric head count design similar to GQA, where the V head count differs from the QK head count. Qwen3.5-35B-A3B has 32 V heads — twice the QK head count. With TP=2, each GPU gets 16.
  • K — head_k_dim. The dimension of each Q/K head, which is 128 here.
  • V — head_v_dim. The dimension of each V head, also 128.

The reason HV is listed separately is that in GDN's recurrent state (SSM state), the shape is [HV, V, K], and the gating parameters a and b have shape [B, HV] — they all follow the V head count, not the QK head count. Inside the packed decode kernel, both H and HV must be handled simultaneously to split mixed_qkv:

mixed_qkv: [B,    2*H*K    +    HV*V]
                 ─────          ─────
                 Q and K       V part
                 uses H        uses HV
                 heads          heads
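
A hypothetical helper makes this pointer arithmetic explicit (the names are illustrative, not taken from the kernel):

```python
def packed_qkv_offsets(h: int, hv: int, H: int, K: int, V: int):
    """Start offsets of head h (Q/K) and head hv (V) in the packed last dim."""
    q_off = h * K                 # Q heads occupy [0, H*K)
    k_off = H * K + h * K         # K heads occupy [H*K, 2*H*K)
    v_off = 2 * H * K + hv * V    # V heads occupy [2*H*K, 2*H*K + HV*V)
    return q_off, k_off, v_off
```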
Correctness and kernel microbenchmark results:
➜  sglang_dev2 git:(support_gdn_packed_decode) ✗ CUDA_VISIBLE_DEVICES=6,7 python bench_gdn_decode.py
Device: NVIDIA H200  (SM 90)
======================================================================
Correctness: Baseline GDN Decode vs Packed GDN Decode
======================================================================
  [PASS] B=   1 H= 8 HV=16 K=128 V=128 pool=  32
  [PASS] B=   4 H= 8 HV=16 K=128 V=128 pool=  32
  [PASS] B=  16 H= 8 HV=16 K=128 V=128 pool=  64
  [PASS] B=  32 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] B=  64 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] B= 128 H= 8 HV=16 K=128 V=128 pool= 256
  [PASS] B= 256 H= 8 HV=16 K=128 V=128 pool= 512
  [PASS] B=   1 H=16 HV=32 K=128 V=128 pool=  32
  [PASS] B=  32 H=16 HV=32 K=128 V=128 pool= 128
  [PASS] B=  64 H=16 HV=32 K=128 V=128 pool= 128
  [PASS] B=  32 H=16 HV=16 K=128 V=128 pool= 128
  [PASS] B=  64 H=16 HV=16 K=128 V=128 pool= 128
  [PASS] B=  32 H= 8 HV=16 K=128 V=128 pool= 128
  [PASS] B=   1 H= 8 HV=16 K=128 V=128 pool=  32
  [PASS] B=   2 H= 8 HV=16 K=128 V=128 pool=  32

  PAD_SLOT_ID test (indices with -1):
  [PASS] PAD_SLOT_ID=-1 handling

ALL PASSED.

=====================================================================================
Benchmark: Baseline GDN Decode vs Packed GDN Decode
=====================================================================================
  Config: K=128, V=128, pool_size=512, dtype=torch.bfloat16
      B    H   HV    K    V |  base (μs) | packed (μs) |  speedup | saved (μs)
  ---------------------------------------------------------------------------
      1    8   16  128  128 |       20.3 |        7.8 |    2.59x |     +12.4
      2    8   16  128  128 |       19.4 |        8.1 |    2.40x |     +11.3
      4    8   16  128  128 |       19.9 |        8.4 |    2.36x |     +11.5
      8    8   16  128  128 |       20.0 |        9.3 |    2.14x |     +10.7
     16    8   16  128  128 |       20.5 |       11.8 |    1.73x |      +8.7
     32    8   16  128  128 |       21.5 |       17.8 |    1.21x |      +3.7
     64    8   16  128  128 |       30.8 |       28.8 |    1.07x |      +2.0
    128    8   16  128  128 |       55.4 |       48.1 |    1.15x |      +7.3
    256    8   16  128  128 |      103.6 |       87.3 |    1.19x |     +16.3
    512    8   16  128  128 |      198.9 |      165.4 |    1.20x |     +33.5
      1   16   32  128  128 |       19.3 |        8.0 |    2.41x |     +11.3
      8   16   32  128  128 |       20.0 |       11.6 |    1.71x |      +8.3
     32   16   32  128  128 |       30.9 |       29.1 |    1.06x |      +1.9
     64   16   32  128  128 |       55.2 |       48.0 |    1.15x |      +7.2
    128   16   32  128  128 |      103.3 |       87.6 |    1.18x |     +15.6
    256   16   32  128  128 |      196.4 |      165.4 |    1.19x |     +31.1

Modifications

Accuracy Tests

GSM8K accuracy shows no drop:
➜ sglang_dev2 git:(support_gdn_packed_decode) ✗ lm_eval --model local-completions --tasks gsm8k --model_args base_url=http://localhost:30000/v1/completions,model=Qwen/Qwen3.5-35B-A3B,num_concurrent=109;

2026-03-15:12:54:49 INFO [_cli.run:376] Selected Tasks: ['gsm8k']
2026-03-15:12:54:49 INFO [evaluator:211] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234 | Setting fewshot manual seed to 1234
2026-03-15:12:54:49 INFO [evaluator:236] Initializing local-completions model, with arguments: {'base_url': 'http://localhost:30000/v1/completions', 'model': 'Qwen/Qwen3.5-35B-A3B', 'num_concurrent': 109}
2026-03-15:12:54:49 INFO [models.openai_completions:42] Remote tokenizer not supported. Using huggingface tokenizer backend.
2026-03-15:12:54:49 INFO [models.api_models:172] Using max length 2048 - 1
2026-03-15:12:54:49 INFO [models.api_models:193] Using tokenizer huggingface
2026-03-15:12:54:53 INFO [tasks:700] Selected tasks:
2026-03-15:12:54:53 INFO [tasks:691] Task: gsm8k (gsm8k/gsm8k.yaml)
2026-03-15:12:54:53 INFO [evaluator:314] gsm8k: Using gen_kwargs: {'until': ['Question:', '', '<|im_end|>'], 'do_sample': False, 'temperature': 0.0}
2026-03-15:12:54:53 INFO [api.task:311] Building contexts for gsm8k on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:04<00:00, 299.62it/s]
2026-03-15:12:54:57 INFO [evaluator:584] Running generate_until requests
Requesting API: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [02:15<00:00, 9.71it/s]
2026-03-15:12:57:22 INFO [loggers.evaluation_tracker:316] Output path not provided, skipping saving results aggregated
local-completions ({'base_url': 'http://localhost:30000/v1/completions', 'model': 'Qwen/Qwen3.5-35B-A3B', 'num_concurrent': 109}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1

Tasks  Version  Filter            n-shot  Metric       Value   Stderr
gsm8k  3        flexible-extract  5       exact_match  0.8484  ± 0.0099
                strict-match      5       exact_match  0.8362  ± 0.0102

Benchmarking and Profiling

Item                          MAIN    Packed PR  Speedup
Output throughput (tok/s)     1,474   1,738      +17.9%
Mean TPOT (ms)                85.96   72.88      -15.2%
Median TPOT (ms)              89.44   75.38      -15.7%
P99 TPOT (ms)                 133.62  115.63     -13.5%
Mean E2E (ms)                 8,468   7,150      -15.6%
Mean ITL (ms)                 83.46   70.35      -15.7%
Mean TTFT (ms)                205.02  185.24     -9.6%
Duration (s)                  54.27   46.03      -15.2%
Request throughput (req/s)    14.74   17.38      +17.9%
CUDA_VISIBLE_DEVICES=1,2 python3 -m sglang.launch_server \
  --model Qwen/Qwen3.5-35B-A3B \
  --tp 2 \
  --port 30000 \
  --max-running-requests 512

python3 -m sglang.bench_serving \
  --backend sglang --port 30000 \
  --dataset-name random \
  --random-input-len 128 \
  --random-output-len 200 \
  --num-prompts 800 \
  --max-concurrency 128

MAIN:
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 128
Successful requests:                     800
Benchmark duration (s):                  54.27
Total input tokens:                      50469
Total input text tokens:                 50469
Total generated tokens:                  80008
Total generated tokens (retokenized):    79925
Request throughput (req/s):              14.74
Input token throughput (tok/s):          929.89
Output token throughput (tok/s):         1474.15
Peak output token throughput (tok/s):    5378.00
Peak concurrent requests:                158
Total token throughput (tok/s):          2404.04
Concurrency:                             124.82
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   8467.95
Median E2E Latency (ms):                 7764.62
P90 E2E Latency (ms):                    16216.21
P99 E2E Latency (ms):                    18987.76
---------------Time to First Token----------------
Mean TTFT (ms):                          205.02
Median TTFT (ms):                        159.36
P99 TTFT (ms):                           533.69
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          85.96
Median TPOT (ms):                        89.44
P99 TPOT (ms):                           133.62
---------------Inter-Token Latency----------------
Mean ITL (ms):                           83.46
Median ITL (ms):                         18.38
P95 ITL (ms):                            312.33
P99 ITL (ms):                            337.89
Max ITL (ms):                            583.09
==================================================

PR:
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 128
Successful requests:                     800
Benchmark duration (s):                  46.03
Total input tokens:                      50469
Total input text tokens:                 50469
Total generated tokens:                  80008
Total generated tokens (retokenized):    79916
Request throughput (req/s):              17.38
Input token throughput (tok/s):          1096.38
Output token throughput (tok/s):         1738.07
Peak output token throughput (tok/s):    5197.00
Peak concurrent requests:                162
Total token throughput (tok/s):          2834.45
Concurrency:                             124.26
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   7150.29
Median E2E Latency (ms):                 6585.36
P90 E2E Latency (ms):                    13762.13
P99 E2E Latency (ms):                    15910.20
---------------Time to First Token----------------
Mean TTFT (ms):                          185.24
Median TTFT (ms):                        153.87
P99 TTFT (ms):                           478.87
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          72.88
Median TPOT (ms):                        75.38
P99 TPOT (ms):                           115.63
---------------Inter-Token Latency----------------
Mean ITL (ms):                           70.35
Median ITL (ms):                         17.65
P95 ITL (ms):                            242.23
P99 ITL (ms):                            328.77
Max ITL (ms):                            689.37
==================================================


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant optimization to the GDN decode process by refactoring the existing multi-step decode path into a more efficient, fused 'packed decode' operation. The change aims to reduce overhead associated with memory copies and multiple kernel launches, leading to substantial improvements in inference speed and token throughput. The new approach consolidates several operations into a single Triton kernel, streamlining the computation for better performance.

Highlights

  • Performance Optimization: Optimized the GDN (Gated Delta Network) decode path by introducing a 'packed decode' mechanism, reducing the computational steps from six to three.
  • Fused Kernel Implementation: Implemented a new Triton-based kernel that fuses QKV extraction, gating, and recurrent updates into a single operation, eliminating intermediate tensor copies and multiple kernel launches.
  • Significant Speedup: Achieved a notable performance improvement, with output throughput increasing by 17.9% and mean Time Per Output Token (TPOT) decreasing by 15.2%.


Changelog
  • python/sglang/srt/layers/attention/fla/fused_recurrent.py
    • Introduced fused_recurrent_gated_delta_rule_packed_decode_kernel, a new Triton JIT-compiled kernel for optimized GDN decode.
    • Added fused_recurrent_gated_delta_rule_packed_decode, a Python function that wraps the Triton kernel, handling input validation and launching the kernel.
  • python/sglang/srt/layers/attention/linear/gdn_backend.py
    • Updated GDNKernelDispatcher to include a supports_packed_decode attribute and a new packed_decode method.
    • Modified forward_decode to conditionally use the new packed_decode path if supported by the kernel dispatcher, bypassing the original multi-step decode process.
  • python/sglang/srt/layers/attention/linear/kernels/gdn_triton.py
    • Imported fused_recurrent_gated_delta_rule_packed_decode for use in the Triton kernel.
    • Set supports_packed_decode to True for TritonGDNKernel (when not on CPU/NPU).
    • Implemented the packed_decode method, which prepares inputs and calls the new fused Triton kernel.
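
Based on this changelog, the conditional dispatch in forward_decode plausibly takes the following shape (a minimal sketch: the class, attribute, and method names follow the changelog above, while the signatures and the fallback helper are assumptions):

```python
from dataclasses import dataclass

def multi_step_decode(mixed_qkv, *state):
    ...  # original 6-step path: split, reshape, gating kernels, recurrence

@dataclass
class GDNKernelDispatcher:
    # Per the changelog, True for TritonGDNKernel when not on CPU/NPU.
    supports_packed_decode: bool = False

    def packed_decode(self, mixed_qkv, *state):
        ...  # wraps fused_recurrent_gated_delta_rule_packed_decode

def forward_decode(dispatcher: GDNKernelDispatcher, mixed_qkv, *state):
    # Take the fused single-kernel path when supported; otherwise fall
    # back to the original multi-step decode.
    if dispatcher.supports_packed_decode:
        return dispatcher.packed_decode(mixed_qkv, *state)
    return multi_step_decode(mixed_qkv, *state)
```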

@yuan-luo requested review from BBuf, kaixih and xutizhou on March 15, 2026 12:51
@yuan-luo
Collaborator Author

/tag-and-rerun-ci


@gemini-code-assist Bot left a comment


Code Review

This pull request introduces a significant optimization for the GDN decode path by implementing a packed decode kernel. The new fused Triton kernel avoids overhead from memory copies and multiple kernel launches, resulting in a notable performance improvement as demonstrated by the benchmarks. The changes are well-structured, with clear feature flagging and a fallback to the old path, ensuring robustness. The code is adapted from the vllm project, and proper attribution is given. I have a couple of minor suggestions to improve code readability and maintainability.

Comment thread python/sglang/srt/layers/attention/fla/fused_recurrent.py Outdated
Comment thread python/sglang/srt/layers/attention/linear/gdn_backend.py
@yuan-luo force-pushed the support_gdn_packed_decode branch from b4659f2 to f4346f0 on March 15, 2026 13:13
@yuan-luo
Collaborator Author

/rerun-failed-ci

1 similar comment
@yuan-luo
Collaborator Author

/rerun-failed-ci

Comment thread python/sglang/srt/layers/attention/linear/gdn_backend.py Outdated
@ispobock merged commit 9c87e13 into sgl-project:main on Mar 18, 2026
252 of 275 checks passed
@yuan-luo deleted the support_gdn_packed_decode branch on March 19, 2026 14:21
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
dutsc pushed a commit to dutsc/sglang that referenced this pull request Mar 30, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026