
[PERF] Decouple projections from GDN custom op#27512

Merged
simon-mo merged 4 commits intovllm-project:mainfrom
CentML:vadim/refac-gdn
Nov 4, 2025
Conversation


@vadiklyutiy vadiklyutiy commented Oct 25, 2025

Purpose

This PR is a refactoring of GDN (Gated Delta Net).

The main goal is to enable wider use of torch.compile.

  1. Separated the GDN attention forward pass into three distinct pieces: input projection, core attention, and output projection. Previously, the projections lived inside the GDN custom op and were not covered by torch.compile.
  2. Added an RMSNormGated class that implements a torch-native gated RMSNorm and used it for GDN. torch.compile generates good code for RMSNormGated, even better than the custom Triton kernel used before.
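
To illustrate the idea, here is a minimal NumPy sketch of the decoupled forward pass and a gated RMSNorm. The names (`gated_rms_norm`, `gdn_forward_sketch`, `core_attn`) and the exact gating formula (SiLU on the gate) are illustrative assumptions, not vLLM's actual API; the point is that only the core-attention step remains an opaque custom op, while the projections and norm become plain tensor ops a compiler can trace and fuse.

```python
import numpy as np

def gated_rms_norm(x, gate, weight, eps=1e-6):
    """Gated RMSNorm sketch: RMS-normalize x over the last dim,
    scale by weight, then multiply by SiLU(gate).
    (The gating formula is an assumption for illustration.)"""
    var = np.mean(x * x, axis=-1, keepdims=True)
    x_norm = x / np.sqrt(var + eps) * weight
    silu_gate = gate / (1.0 + np.exp(-gate))  # SiLU(g) = g * sigmoid(g)
    return x_norm * silu_gate

def gdn_forward_sketch(x, w_in, w_out, norm_weight, core_attn):
    """Three-piece GDN forward: input projection -> core attention
    (custom op) -> gated norm + output projection. Only `core_attn`
    stays outside the compiled region."""
    projected = x @ w_in                       # input projection (compilable)
    core_in, gate = np.split(projected, 2, axis=-1)
    core_out = core_attn(core_in)              # custom op: opaque to the compiler
    normed = gated_rms_norm(core_out, gate, norm_weight)
    return normed @ w_out                      # output projection (compilable)
```

With the projections and norm outside the custom op boundary, torch.compile can cover them, which is where the reported gains come from.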

Functional Test Result

lm_eval
Before

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8491|±  |0.0099|
|     |       |strict-match    |     5|exact_match|↑  |0.8059|±  |0.0109|

After

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8544|±  |0.0097|
|     |       |strict-match    |     5|exact_match|↑  |0.8127|±  |0.0107|

Perf Test Result

Server

VLLM_ATTENTION_BACKEND=FLASH_ATTN VLLM_USE_FLASHINFER_MOE_FP16=1 vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct -tp 4 --enable-expert-parallel --no-enable-prefix-caching --async-scheduling --max_cudagraph_capture_size=2048

Prefill

vllm bench serve --backend vllm --model Qwen/Qwen3-Next-80B-A3B-Instruct --endpoint /v1/completions --dataset-name random --random-input 8192 --random-output 1 --max-concurrency 512 --num-prompt 512 --ignore-eos

Before: Total Token throughput (tok/s): 104098.78
After: Total Token throughput (tok/s): 105270.70
Speedup: 1.1%

Decode1

vllm bench serve --backend vllm --model Qwen/Qwen3-Next-80B-A3B-Instruct --endpoint /v1/completions --dataset-name random --random-input 32 --random-output 1024 --max-concurrency 512 --num-prompt 512 --ignore-eos

Before: Output token throughput (tok/s): 19212.17
After: Output token throughput (tok/s): 22384.37
Speedup: 16.5%

Decode2

vllm bench serve --backend vllm --model Qwen/Qwen3-Next-80B-A3B-Instruct --endpoint /v1/completions --dataset-name random --random-input 32 --random-output 1024 --max-concurrency 1024 --num-prompt 1024 --ignore-eos

Before: Output token throughput (tok/s): 28821.37
After: Output token throughput (tok/s): 30298.90
Speedup: 5.1%

Decode3
Server

VLLM_ATTENTION_BACKEND=FLASH_ATTN VLLM_USE_FLASHINFER_MOE_FP16=1 vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct -tp 4 --enable-expert-parallel --no-enable-prefix-caching --async-scheduling

(without increasing --max_cudagraph_capture_size)

vllm bench serve --backend vllm --model Qwen/Qwen3-Next-80B-A3B-Instruct --endpoint /v1/completions --dataset-name random --random-input 32 --random-output 1024 --max-concurrency 1024 --num-prompt 1024 --ignore-eos

Before: Output token throughput (tok/s): 16586.93
After: Output token throughput (tok/s): 18953.92
Speedup: 14.3%
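
The speedup percentages above follow directly from the before/after throughput ratios; a quick check:

```python
# Verify the reported speedups from the before/after throughputs above.
pairs = {
    "prefill": (104098.78, 105270.70),
    "decode1": (19212.17, 22384.37),
    "decode2": (28821.37, 30298.90),
    "decode3": (16586.93, 18953.92),
}
for name, (before, after) in pairs.items():
    speedup_pct = (after / before - 1.0) * 100.0
    print(f"{name}: {speedup_pct:.1f}%")
```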


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the Gated Delta Net (GDN) attention mechanism to improve torch.compile compatibility and performance. By decoupling the input/output projections from the core custom operator and introducing a native PyTorch RMSNormGated layer, the changes yield significant decode throughput improvements. The refactoring is well-executed and the code is clear. I have one high-severity suggestion regarding a local import in a performance-critical path, which should be moved to the top level of the module to adhere to best practices and avoid potential overhead.

Comment thread vllm/model_executor/layers/layernorm.py

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread vllm/model_executor/models/qwen3_next.py

ZJY0516 commented Oct 27, 2025

CC @heheda12345

@vadiklyutiy
Collaborator Author

@ALL
Could you please take a look at this PR?


mergify bot commented Oct 30, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @vadiklyutiy.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 30, 2025

@ProExpertProg ProExpertProg left a comment


LGTM but I'll let someone more familiar with Qwen3 approve

Comment thread vllm/model_executor/layers/layernorm.py Outdated
@vadiklyutiy
Collaborator Author

LGTM but I'll let someone more familiar with Qwen3 approve

Is there anyone familiar with Qwen3-Next besides @sighingnow?


mgoin commented Oct 31, 2025

cc @tlrmchlsmth

@mergify mergify bot removed the needs-rebase label Oct 31, 2025

@heheda12345 heheda12345 left a comment


LGTM!

@heheda12345
Collaborator

@codex review

@chatgpt-codex-connector

Codex Review: Didn't find any major issues. Delightful!


@heheda12345 heheda12345 enabled auto-merge (squash) October 31, 2025 06:07
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 31, 2025

@youkaichao youkaichao left a comment


looks good! cc @zhiyuan1i we should be able to do similar optimization for kimi linear.


ZJY0516 commented Oct 31, 2025

looks good! cc @zhiyuan1i we should be able to do similar optimization for kimi linear.

I have made a pr #27871 for kimi.

@youkaichao youkaichao disabled auto-merge October 31, 2025 13:42
@vadiklyutiy vadiklyutiy force-pushed the vadim/refac-gdn branch 2 times, most recently from 39cd84b to 090a44b Compare November 1, 2025 00:05
@vadiklyutiy
Collaborator Author

Regarding the CI failures: they don't appear to be related to this PR, and many recent commits show similar failures.
Is it OK to merge this PR?

@heheda12345
Collaborator

@codex review

@chatgpt-codex-connector

Codex Review: Didn't find any major issues. Hooray!


@heheda12345 heheda12345 enabled auto-merge (squash) November 1, 2025 22:53
auto-merge was automatically disabled November 3, 2025 09:07

Head branch was pushed to by a user without write access

@vadiklyutiy vadiklyutiy force-pushed the vadim/refac-gdn branch 2 times, most recently from 70d2048 to c3197c6 Compare November 3, 2025 12:21
@vadiklyutiy
Collaborator Author

The CI results are weird; I can't reproduce them locally.
I'll try disabling some parts to localize the issue.

@heheda12345
Collaborator

Can you fix the pre-commit failures and rebase your PR on main? There have been many CI fixes on the main branch these days.

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
@vadiklyutiy
Collaborator Author

Can you fix the pre-commit and rebase your PR from main? There are many CI fixes on main branch these days.

Rebased. Still the same and still can't reproduce locally :(

@simon-mo simon-mo merged commit 5fd8f02 into vllm-project:main Nov 4, 2025
54 of 58 checks passed
vadiklyutiy added a commit to CentML/vllm that referenced this pull request Nov 4, 2025
…7512)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
@vadiklyutiy vadiklyutiy deleted the vadim/refac-gdn branch November 5, 2025 00:41
ZhengHongming888 pushed a commit to ZhengHongming888/vllm that referenced this pull request Nov 8, 2025
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>

Labels

qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed


8 participants