
[GDN] Change Attention State Layout from [N, HV, K, V] to [N, HV, V, K] #20283

Merged: ispobock merged 5 commits into sgl-project:main from antgroup:fla_swap_kv on Mar 12, 2026

Conversation

@yuan-luo (Collaborator) commented Mar 10, 2026

Motivation

To improve the memory access pattern and throughput, this PR transposes the recurrent state memory layout in GDN attention from [N, HV, K, V] to [N, HV, V, K].

The KV swap aligns the long edge of the state tile (the K dimension) with the memory-contiguous direction, significantly improving the GPU's coalesced memory access and letting each memory transaction fetch more useful data. The effect is especially noticeable in the decode path, where BV is limited to 8.

|                             | Original [K, V] | After swap [V, K] |
|-----------------------------|----------------:|------------------:|
| Tile shape (decode)         | [256, 8]        | [8, 256]          |
| Contiguous elements per row | 8               | 256               |
| Number of rows              | 256             | 8                 |

Both the GDN and KDA SSM kernels are adapted to the VK layout. The change covers the decode, extend, and target_verify APIs.
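The effect of the layout change can be illustrated with a small NumPy sketch. This is not the kernel code (the real state lives inside Triton kernels); the sizes are the decode-tile values from the table above, and N/HV are arbitrary illustrative values:

```python
import numpy as np

# Illustrative sizes: N requests, HV value heads, K = 256, V = 8 (decode tile).
N, HV, K, V = 2, 4, 256, 8

# Old layout [N, HV, K, V]: each row of the (K, V) tile has only
# V = 8 contiguous elements in memory.
state_kv = np.zeros((N, HV, K, V), dtype=np.float32)

# New layout [N, HV, V, K]: each row of the (V, K) tile has
# K = 256 contiguous elements, matching the coalesced-access direction.
state_vk = np.ascontiguousarray(state_kv.transpose(0, 1, 3, 2))

assert state_vk.shape == (N, HV, V, K)
# The long K edge is now unit-stride (contiguous) ...
assert state_vk.strides[-1] == state_vk.itemsize
# ... whereas in the old layout consecutive K entries were V elements apart.
assert state_kv.strides[-2] // state_kv.itemsize == V
```

On a GPU the same idea applies to the Triton block pointers: consecutive threads read consecutive addresses along the 256-wide K edge instead of the 8-wide V edge.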

Modifications

Accuracy Tests

GPQA shows no accuracy drop:

➜  bench_script python3 -m sglang.test.run_eval --port 30000 --eval-name gpqa --num-examples 198 --max-tokens 4096 --repeat 8
ChatCompletionSampler initialized with self.system_message=None self.temperature=0.0 self.max_tokens=4096 self.reasoning_effort=None self.extra_body=None
(the line above repeats 8×, once per eval repeat)
100%|…| 198/198 [10:34<00:00,  3.20s/it]
(7 more repeats, finishing between 10:58 and 11:40 each)
====================
Repeat: 8, mean: 0.515
Scores: ['0.530', '0.470', '0.525', '0.540', '0.500', '0.510', '0.515', '0.530']
====================
[METRIC] gpqa_mean_score=0.5151515151515151 labels={"model": "Qwen/Qwen3-Next-80B-A3B-Instruct", "eval": "gpqa", "repeat": 8}
Writing report to /tmp/gpqa_Qwen_Qwen3-Next-80B-A3B-Instruct.html
{'chars': np.float64(6724.373737373738), 'chars:std': np.float64(4060.5329962771793), 'score:std': np.float64(0.4990808815757758), 'scores': ['0.530', '0.470', '0.525', '0.540', '0.500', '0.510', '0.515', '0.530'], 'mean_score': np.float64(0.5151515151515151)}
Writing results to /tmp/gpqa_Qwen_Qwen3-Next-80B-A3B-Instruct.json

GSM8K shows no accuracy drop:

➜  bench_script lm_eval --model local-completions --tasks gsm8k   --model_args base_url=http://localhost:30000/v1/completions,model=Qwen/Qwen3-Next-80B-A3B-Instruct,num_concurrent=109;
2026-03-10:13:25:42 INFO     [_cli.run:376] Selected Tasks: ['gsm8k']
2026-03-10:13:25:42 WARNING  [evaluator:181] pretrained=None appears to be an instruct or chat variant but chat template is not applied. Recommend setting `apply_chat_template`
        (optionally `fewshot_as_multiturn`).
2026-03-10:13:25:42 INFO     [evaluator:211] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234 | Setting fewshot manual seed to 1234
2026-03-10:13:25:42 INFO     [evaluator:236] Initializing local-completions model, with arguments: {'base_url': 'http://localhost:30000/v1/completions', 'model': 'Qwen/Qwen3-Next-80B-A3B-Instruct', 'num_concurrent': 109}
2026-03-10:13:25:42 INFO     [models.openai_completions:42] Remote tokenizer not supported. Using huggingface tokenizer backend.
2026-03-10:13:25:42 INFO     [models.api_models:172] Using max length 2048 - 1
2026-03-10:13:25:42 INFO     [models.api_models:193] Using tokenizer huggingface
2026-03-10:13:25:46 INFO     [tasks:700] Selected tasks:
2026-03-10:13:25:46 INFO     [tasks:691] Task: gsm8k (gsm8k/gsm8k.yaml)
2026-03-10:13:25:46 INFO     [evaluator:314] gsm8k: Using gen_kwargs: {'until': ['Question:', '</s>', '<|im_end|>'], 'do_sample': False, 'temperature': 0.0}
2026-03-10:13:25:46 INFO     [api.task:311] Building contexts for gsm8k on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:04<00:00, 295.64it/s]
2026-03-10:13:25:50 INFO     [evaluator:584] Running generate_until requests
Requesting API: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [02:43<00:00,  8.05it/s]
fatal: not a git repository (or any of the parent directories): .git
2026-03-10:13:28:44 INFO     [loggers.evaluation_tracker:316] Output path not provided, skipping saving results aggregated
local-completions ({'base_url': 'http://localhost:30000/v1/completions', 'model': 'Qwen/Qwen3-Next-80B-A3B-Instruct', 'num_concurrent': 109}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8567|±  |0.0097|
|     |       |strict-match    |     5|exact_match|↑  |0.8188|±  |0.0106|

Benchmarking and Profiling

Server:

CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m sglang.launch_server \
  --model Qwen/Qwen3-Next-80B-A3B-Instruct \
  --tp 4 \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 4 \
  --speculative-algo NEXTN \
  --disable-radix-cache

Benchmark:

TTFT speedup: (14993 − 13842) / 14993 ≈ 7.7%
E2E speedup: (23203 − 21247) / 23203 ≈ 8.4%
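As a quick sanity check on the arithmetic, using the mean TTFT and E2E latencies from the serving benchmark results pasted in this PR:

```python
# Mean latencies (ms) from the serving benchmark: main branch vs. this PR.
ttft_main, ttft_pr = 14993.99, 13842.98
e2e_main, e2e_pr = 23203.47, 21247.10

ttft_speedup = (ttft_main - ttft_pr) / ttft_main
e2e_speedup = (e2e_main - e2e_pr) / e2e_main

print(f"TTFT speedup: {ttft_speedup:.1%}")  # ~7.7%
print(f"E2E speedup:  {e2e_speedup:.1%}")   # ~8.4%
```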

python3 -m sglang.bench_serving   --backend sglang   --host 127.0.0.1 --port 30000 --dataset-name random   --random-input-len 8000 --random-output 1500 --dataset-path /data/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 200
Main:
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 not set
Successful requests:                     200
Benchmark duration (s):                  39.92
Total input tokens:                      788704
Total input text tokens:                 788704
Total generated tokens:                  153723
Total generated tokens (retokenized):    153714
Request throughput (req/s):              5.01
Input token throughput (tok/s):          19757.93
Output token throughput (tok/s):         3850.93
Peak output token throughput (tok/s):    6826.00
Peak concurrent requests:                200
Total token throughput (tok/s):          23608.86
Concurrency:                             116.25
Accept length:                           3.74
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   23203.47
Median E2E Latency (ms):                 23106.15
P90 E2E Latency (ms):                    36472.25
P99 E2E Latency (ms):                    39378.64
---------------Time to First Token----------------
Mean TTFT (ms):                          14993.99
Median TTFT (ms):                        15089.11
P99 TTFT (ms):                           33932.34
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          11.43
Median TPOT (ms):                        10.85
P99 TPOT (ms):                           21.08
---------------Inter-Token Latency----------------
Mean ITL (ms):                           10.69
Median ITL (ms):                         4.95
P95 ITL (ms):                            33.07
P99 ITL (ms):                            61.39
Max ITL (ms):                            415.92
==================================================

PR:
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 not set
Successful requests:                     200
Benchmark duration (s):                  36.21
Total input tokens:                      788704
Total input text tokens:                 788704
Total generated tokens:                  153723
Total generated tokens (retokenized):    153691
Request throughput (req/s):              5.52
Input token throughput (tok/s):          21780.55
Output token throughput (tok/s):         4245.16
Peak output token throughput (tok/s):    6716.00
Peak concurrent requests:                200
Total token throughput (tok/s):          26025.71
Concurrency:                             117.35
Accept length:                           3.75
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   21247.10
Median E2E Latency (ms):                 21214.98
P90 E2E Latency (ms):                    33039.39
P99 E2E Latency (ms):                    35386.80
---------------Time to First Token----------------
Mean TTFT (ms):                          13842.98
Median TTFT (ms):                        13954.41
P99 TTFT (ms):                           30856.84
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          10.20
Median TPOT (ms):                        9.93
P99 TPOT (ms):                           17.83
---------------Inter-Token Latency----------------
Mean ITL (ms):                           9.65
Median ITL (ms):                         4.91
P95 ITL (ms):                            30.33
P99 ITL (ms):                            56.00
Max ITL (ms):                            344.34
==================================================

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@yuan-luo (Collaborator, Author) commented:

/tag-and-rerun-ci

@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on a critical optimization within the GDN attention mechanism by reordering the dimensions of the recurrent state. By swapping the K and V dimensions, the system achieves more efficient memory access on GPUs, directly translating to notable improvements in both the time it takes to generate the first token and the overall inference latency. This enhancement is particularly beneficial in decode scenarios where memory access patterns are crucial for performance.

Highlights

  • Memory Layout Optimization: The recurrent state memory layout in GDN attention has been changed from [N, HV, K, V] to [N, HV, V, K] to improve memory access patterns and throughput.
  • Performance Improvement: This change significantly enhances GPU's coalesced memory access, leading to a 7% speedup in Time to First Token (TTFT) and an 8% speedup in End-to-End (E2E) latency during benchmarking.
  • Kernel Adjustments: The underlying Triton kernels (chunk_delta_h, chunk_o, fused_recurrent) were updated to correctly handle the new [V, K] memory layout, including changes to tensor initialization, memory offsets, block pointers, and matrix multiplication operations.
  • Docstring and Example Updates: Docstrings and example code snippets were modified to reflect the new state tensor shape, ensuring consistency and clarity for developers.


Changelog
  • python/sglang/srt/layers/attention/fla/chunk.py
    • Updated docstrings for initial_state and final_state to reflect the new [N, H, V, K] shape.
    • Modified the return type annotation for final_state to match the new shape.
  • python/sglang/srt/layers/attention/fla/chunk_delta_h.py
    • Adjusted the initialization shape of b_h tensors from [64, BV] to [BV, 64].
    • Updated memory offset calculations for h, h0, and ht to use V * K instead of K * V.
    • Modified tl.make_block_ptr calls to align with the (V, K) tensor layout.
    • Transposed b_h and b_k in tl.dot operations to match the new dimension order.
    • Updated the exp function application from [:, None] to [None, :].
    • Changed the creation shape of the h tensor from (B, NT, H, K, V) to (B, NT, H, V, K).
  • python/sglang/srt/layers/attention/fla/chunk_o.py
    • Updated the memory offset calculation for h to use V * K.
    • Modified tl.make_block_ptr for p_h to reflect the (V, K) tensor layout.
    • Transposed b_h in tl.dot operations.
  • python/sglang/srt/layers/attention/fla/fused_recurrent.py
    • Modified mask_h calculation to swap mask_k and mask_v indexing.
    • Changed b_h tensor initialization from [BK, BV] to [BV, BK].
    • Updated memory offset calculations for p_h0 and p_ht to align with the [V, K] layout.
    • Adjusted exp function application from [:, None] to [None, :].
    • Modified tl.sum arguments for b_v update to reflect the new dimension order.
    • Updated b_h and b_o calculation logic to match the [V, K] layout.
    • Changed the creation shape of final_state from (N, HV, K, V) to (N, HV, V, K).
    • Updated docstrings and example h0 shape to reflect the new [N, HV, V, K] layout.
  • python/sglang/srt/layers/attention/fla/kda.py
    • Changed the creation shape of final_state from (T, HV, K, V) to (T, HV, V, K).
Activity
  • The author provided benchmarking results showing a 7% TTFT speedup and an 8% E2E speedup.
  • The pull request includes a checklist for code formatting, unit tests, documentation, and benchmarking, with some items yet to be checked.

@gemini-code-assist (Bot) left a comment:

Code Review

This pull request successfully refactors the attention state memory layout from [N, HV, K, V] to [N, HV, V, K]. This change is well-motivated by the goal of improving memory access patterns and throughput, especially for GPU coalesced memory access, as evidenced by the provided benchmarking results showing improvements in TTFT and E2E latency. The modifications across the Triton kernels (chunk_delta_h.py, chunk_o.py, fused_recurrent.py) correctly adapt the tensor shapes, strides, block pointers, and matrix multiplication operations to the new layout. The docstring updates accurately reflect these internal changes.

Comment thread python/sglang/srt/layers/attention/fla/chunk_delta_h.py
Comment thread python/sglang/srt/layers/attention/fla/chunk_delta_h.py
Comment thread python/sglang/srt/layers/attention/fla/fused_recurrent.py
Comment thread python/sglang/srt/layers/attention/fla/fused_recurrent.py
@kaixih (Collaborator) left a comment:

I have exercised the prefill path and all my tests pass. My test is here.

Left one comment in the MTP path.

Comment thread python/sglang/srt/layers/attention/fla/fused_recurrent.py Outdated
@yuan-luo (Collaborator, Author) commented:

Fixed the KDA CI failure. Both KDA and GDN now use the VK layout.

➜  sglang_dev2 git:(fla_swap_kv) ✗ CUDA_VISIBLE_DEVICES=4,5,6,7 python ./test/registered/attention/test_kda_kernels.py -q -v
test_kda_fused_sigmoid_gating_recurrent (__main__.TestKDAFusedSigmoidGatingRecurrent.test_kda_fused_sigmoid_gating_recurrent) ... abs_diff_out=tensor(0., device='cuda:0', dtype=torch.bfloat16), abs_diff_state=tensor(4.4703e-08, device='cuda:0')
ok

----------------------------------------------------------------------
Ran 1 test in 4.319s

OK

@yuan-luo (Collaborator, Author) commented:

Kimi Linear model test passed.

➜  sglang_dev2 git:(fla_swap_kv) ✗ CUDA_VISIBLE_DEVICES=4,5,6,7 python ./test/registered/models/test_kimi_linear_models.py
.....
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:07<00:00, 28.24it/s]
Accuracy: 0.890
Invalid: 0.000
Latency: 7.123 s
Output throughput: 2739.025 token/s
metrics={'accuracy': np.float64(0.89), 'invalid': np.float64(0.0), 'latency': 7.122608389006928, 'output_throughput': 2739.024657049838}
.
----------------------------------------------------------------------
Ran 1 test in 108.185s

OK

@yuan-luo yuan-luo requested a review from hanming-lu March 11, 2026 05:00
@yuan-luo (Collaborator, Author) commented:

/rerun-failed-ci

@yuan-luo (Collaborator, Author) commented:

/tag-and-rerun-ci

@yuan-luo (Collaborator, Author) commented:

Verified that this PR works correctly in mamba extra_buffer mode.

➜  sglang_dev2 git:(fla_swap_kv) ✗ CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m sglang.launch_server \
  --model Qwen/Qwen3-Next-80B-A3B-Instruct \
  --tp 4 \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 4 \
  --speculative-algo NEXTN \
  --mamba-scheduler-strategy extra_buffer --port 30000
[2026-03-11 06:54:58] INFO utils.py:148: Note: detected 224 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-03-11 06:54:58] INFO utils.py:151: Note: NumExpr detected 224 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-03-11 06:54:58] INFO utils.py:164: NumExpr defaulting to 16 threads.
/usr/local/lib/python3.12/dist-packages/sglang/launch_server.py:51: UserWarning: 'python -m sglang.launch_server' is still supported, but 'sglang serve' is the recommended entrypoint.
  Example: sglang serve --model-path <model> [options]
...
[2026-03-11 06:55:54] INFO:     Application startup complete.
[2026-03-11 06:55:54] INFO:     Uvicorn running on socket ('127.0.0.1', 30000) (Press CTRL+C to quit)
[2026-03-11 06:55:55] INFO:     127.0.0.1:58886 - "GET /model_info HTTP/1.1" 200 OK
[2026-03-11 06:55:55 TP0] Prefill batch, #new-seq: 1, #new-token: 6, #cached-token: 0, full token usage: 0.00, mamba usage: 0.00, #running-req: 0, #queue-req: 0, input throughput (token/s): 0.00, cuda graph: False
[2026-03-11 06:55:56] INFO:     127.0.0.1:58892 - "POST /generate HTTP/1.1" 200 OK
[2026-03-11 06:55:56] The server is fired up and ready to roll!
[2026-03-11 06:56:29 TP0] Prefill batch, #new-seq: 1, #new-token: 33, #cached-token: 0, full token usage: 0.00, mamba usage: 0.00, #running-req: 0, #queue-req: 0, input throughput (token/s): 0.18, cuda graph: False
[2026-03-11 06:56:30 TP0] Decode batch, #running-req: 1, #full token: 135, full token usage: 0.00, mamba num: 2, mamba usage: 0.00, accept len: 2.73, accept rate: 0.68, cuda graph: True, gen throughput (token/s): 1.38, #queue-req: 0
[2026-03-11 06:56:30] INFO:     127.0.0.1:50076 - "POST /v1/chat/completions HTTP/1.1" 200 OK
➜  bench_script python test_openai.py
ChatCompletion(id='d67c6a6331bd45ea8fc8564237039846', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Sure! Here are three countries and their capitals:\n\n1. **France** – Paris  \n2. **Japan** – Tokyo  \n3. **Brazil** – Brasília  \n\n### How I Ranked Them:\nI ranked these countries **alphabetically by country name**:\n\n- **Brazil** (B)  \n- **France** (F)  \n- **Japan** (J)  \n\nThis is a neutral, objective ranking method — no value judgments about size, population, economy, or cultural influence. Alphabetical order ensures fairness and consistency. If you’d like them ranked by population, GDP, or something else, just let me know!', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=151645)], created=1773212452, model='default', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=129, prompt_tokens=33, total_tokens=162, completion_tokens_details=None, prompt_tokens_details=None, reasoning_tokens=0), metadata={'weight_version': 'default'})

@yuan-luo (Collaborator, Author) commented:

GSM8K shows no drop in mamba extra_buffer mode:

➜  bench_script lm_eval --model local-completions --tasks gsm8k   --model_args base_url=http://localhost:30000/v1/completions,model=Qwen/Qwen3-Next-80B-A3B-Instruct,num_concurrent=109;
2026-03-11:07:05:18 INFO     [_cli.run:376] Selected Tasks: ['gsm8k']
2026-03-11:07:05:18 WARNING  [evaluator:181] pretrained=None appears to be an instruct or chat variant but chat template is not applied. Recommend setting `apply_chat_template`
        (optionally `fewshot_as_multiturn`).
2026-03-11:07:05:18 INFO     [evaluator:211] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234 | Setting fewshot manual seed to 1234
2026-03-11:07:05:18 INFO     [evaluator:236] Initializing local-completions model, with arguments: {'base_url': 'http://localhost:30000/v1/completions', 'model': 'Qwen/Qwen3-Next-80B-A3B-Instruct', 'num_concurrent': 109}
2026-03-11:07:05:18 INFO     [models.openai_completions:42] Remote tokenizer not supported. Using huggingface tokenizer backend.
2026-03-11:07:05:18 INFO     [models.api_models:172] Using max length 2048 - 1
2026-03-11:07:05:18 INFO     [models.api_models:193] Using tokenizer huggingface
2026-03-11:07:05:22 INFO     [tasks:700] Selected tasks:
2026-03-11:07:05:22 INFO     [tasks:691] Task: gsm8k (gsm8k/gsm8k.yaml)
2026-03-11:07:05:22 INFO     [evaluator:314] gsm8k: Using gen_kwargs: {'until': ['Question:', '</s>', '<|im_end|>'], 'do_sample': False, 'temperature': 0.0}
2026-03-11:07:05:22 INFO     [api.task:311] Building contexts for gsm8k on rank 0...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:04<00:00, 296.29it/s]
2026-03-11:07:05:26 INFO     [evaluator:584] Running generate_until requests
Requesting API: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [02:25<00:00,  9.08it/s]
fatal: not a git repository (or any of the parent directories): .git
2026-03-11:07:08:01 INFO     [loggers.evaluation_tracker:316] Output path not provided, skipping saving results aggregated
local-completions ({'base_url': 'http://localhost:30000/v1/completions', 'model': 'Qwen/Qwen3-Next-80B-A3B-Instruct', 'num_concurrent': 109}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8529|±  |0.0098|
|     |       |strict-match    |     5|exact_match|↑  |0.8143|±  |0.0107|

@yuan-luo (Collaborator, Author) commented:

Added a GPQA test with mamba extra_buffer enabled; accuracy shows no drop:

➜  bench_script python3 -m sglang.test.run_eval --port 30000 --eval-name gpqa --num-examples 198 --max-tokens 4096 --repeat 8
ChatCompletionSampler initialized with self.system_message=None self.temperature=0.0 self.max_tokens=4096 self.reasoning_effort=None self.extra_body=None
(the line above repeats 8×, once per eval repeat)
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 198/198 [10:34<00:00,  3.20s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 198/198 [10:58<00:00,  3.32s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 198/198 [11:12<00:00,  3.39s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 198/198 [11:22<00:00,  3.45s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 198/198 [11:24<00:00,  3.46s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 198/198 [11:34<00:00,  3.51s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 198/198 [11:36<00:00,  3.52s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 198/198 [11:40<00:00,  3.54s/it]
====================
Repeat: 8, mean: 0.515
Scores: ['0.530', '0.470', '0.525', '0.540', '0.500', '0.510', '0.515', '0.530']
====================
[METRIC] gpqa_mean_score=0.5151515151515151 labels={"model": "Qwen/Qwen3-Next-80B-A3B-Instruct", "eval": "gpqa", "repeat": 8}
Writing report to /tmp/gpqa_Qwen_Qwen3-Next-80B-A3B-Instruct.html
{'chars': np.float64(6724.373737373738), 'chars:std': np.float64(4060.5329962771793), 'score:std': np.float64(0.4990808815757758), 'scores': ['0.530', '0.470', '0.525', '0.540', '0.500', '0.510', '0.515', '0.530'], 'mean_score': np.float64(0.5151515151515151)}
Writing results to /tmp/gpqa_Qwen_Qwen3-Next-80B-A3B-Instruct.json

@yuan-luo
Collaborator Author

Added a test case covering the new VK state layout.

➜  sglang_dev2 git:(fla_swap_kv) ✗ python -m pytest test_chunk_gated_delta_rule.py
=============================================================================================================== test session starts ===============================================================================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /sgl-workspace/sglang_dev2
plugins: asyncio-1.3.0, anyio-4.12.1, hydra-core-1.3.2, typeguard-4.4.4
asyncio: mode=Mode.STRICT, debug=False, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 28 items

test_chunk_gated_delta_rule.py ............................                                                                                                                                                                                 [100%]

=============================================================================================================== 28 passed in 4.30s ================================================================================================================
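The layout change being tested can be sketched in a few lines of NumPy. This is an illustrative model of the memory-access argument from the PR description, not the actual Triton kernel code; the tile sizes (K = 256, BV = 8 for decode) are taken from the table above, and the array names are hypothetical.

```python
import numpy as np

# Hypothetical decode-time tile sizes from the PR description:
# K = 256 (head-key dim), BV = 8 (V-block size in decode).
N, HV, K, BV = 1, 4, 256, 8

# Original [N, HV, K, V] layout: a state tile has 256 rows, and each
# row holds only BV = 8 contiguous elements along the last axis.
state_kv = np.zeros((N, HV, K, BV), dtype=np.float32)

# Swapped [N, HV, V, K] layout: the same tile now has 8 rows of
# K = 256 contiguous elements, so reads along the innermost axis
# touch long contiguous runs (coalesced loads on a GPU).
state_vk = np.zeros((N, HV, BV, K), dtype=np.float32)

print(state_kv.shape[-1])  # 8 contiguous elements per row
print(state_vk.shape[-1])  # 256 contiguous elements per row

# Same total storage either way; only the stride pattern differs.
assert state_kv.size == state_vk.size
```

The element count per tile is unchanged; the swap only rearranges which dimension is unit-stride, which is what the test above verifies numerically against the reference [K, V] implementation.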

@yuan-luo
Collaborator Author

/rerun-failed-ci

2 similar comments

@ispobock ispobock merged commit 649d6f2 into sgl-project:main Mar 12, 2026
334 of 357 checks passed
liubiyongge pushed a commit to liubiyongge/sglang that referenced this pull request Mar 13, 2026
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
IzacharyI added a commit to IzacharyI/sglang that referenced this pull request Apr 23, 2026
IzacharyI added a commit to IzacharyI/sglang that referenced this pull request Apr 24, 2026