Skip to content

[PERF] Decouple projections from GDN custom op. Attempt 2#28083

Merged
simon-mo merged 2 commits intovllm-project:mainfrom
CentML:vadim/refac-gdn2
Nov 6, 2025
Merged

[PERF] Decouple projections from GDN custom op. Attempt 2#28083
simon-mo merged 2 commits intovllm-project:mainfrom
CentML:vadim/refac-gdn2

Conversation

@vadiklyutiy
Copy link
Copy Markdown
Collaborator

@vadiklyutiy vadiklyutiy commented Nov 5, 2025

Purpose

The second attempt for #27512

This PR is refactoring of GDN.

The main goal is to allow wider using of torch.compile.

  1. Separated forward pass of GDN attention into three distinct pieces: Input Projection, Core Attention, Output Projection. Before projections was in the GDN custom op and were not covered by torch.compile.
  2. Added RMSNormGated class that implements torch native gated rmsnorm and use it for GDN. torch.compile creates a good code for RMSNormGated even better than custom triton kernel used before.

Functional Test Result

lm_eval
Before

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8491|±  |0.0099|
|     |       |strict-match    |     5|exact_match|↑  |0.8059|±  |0.0109|

After

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8544|±  |0.0097|
|     |       |strict-match    |     5|exact_match|↑  |0.8127|±  |0.0107|

Perf Test Result

Server

VLLM_ATTENTION_BACKEND=FLASH_ATTN VLLM_USE_FLASHINFER_MOE_FP16=1 vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct -tp 4 --enable-expert-parallel --no-enable-prefix-caching --async-scheduling --max_cudagraph_capture_size=2048

Prefill

vllm bench serve --backend vllm --model Qwen/Qwen3-Next-80B-A3B-Instruct --endpoint /v1/completions --dataset-name random --random-input 8192 --random-output 1 --max-concurrency 512 --num-prompt 512 --ignore-eos

Before: Total Token throughput (tok/s): 104098.78
After: Total Token throughput (tok/s): 105270.70
Speedup: 1.1%

Decode1

vllm bench serve --backend vllm --model Qwen/Qwen3-Next-80B-A3B-Instruct --endpoint /v1/completions --dataset-name random --random-input 32 --random-output 1024 --max-concurrency 512 --num-prompt 512 --ignore-eos

Before: Output token throughput (tok/s): 19212.17
After: Output token throughput (tok/s): 22384.37
Speedup: 16.5%

Decode2

vllm bench serve --backend vllm --model Qwen/Qwen3-Next-80B-A3B-Instruct --endpoint /v1/completions --dataset-name random --random-input 32 --random-output 1024 --max-concurrency 1024 --num-prompt 1024 --ignore-eos

Before: Output token throughput (tok/s): 28821.37
After: Output token throughput (tok/s): 30298.90
Speed up: 5.1%

Decode3
Server

VLLM_ATTENTION_BACKEND=FLASH_ATTN VLLM_USE_FLASHINFER_MOE_FP16=1 vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct -tp 4 --enable-expert-parallel --no-enable-prefix-caching --async-scheduling

(without increasing --max_cudagraph_capture_size)

vllm bench serve --backend vllm --model Qwen/Qwen3-Next-80B-A3B-Instruct --endpoint /v1/completions --dataset-name random --random-input 32 --random-output 1024 --max-concurrency 1024 --num-prompt 1024 --ignore-eos

Before: Output token throughput (tok/s): 16586.93
After: Output token throughput (tok/s): 18953.92
Speed up: 14.3%

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the Gated Delta Net (GDN) attention to improve torch.compile performance by decoupling input and output projections from the core custom operator. The changes are well-implemented and the introduction of a native RMSNormGated layer is a good addition. I have one suggestion to optimize the new RMSNormGated layer's performance.

out = x_normed * self.weight
else:
# Group RMS norm
from einops import rearrange
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Importing einops inside the forward_native method can introduce a performance overhead, as the import statement will be executed every time this method is called. It's better to move this import to the top of the file to ensure it's only executed once when the module is loaded.

@simon-mo simon-mo added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 5, 2025
@vadiklyutiy vadiklyutiy self-assigned this Nov 5, 2025
@vadiklyutiy vadiklyutiy marked this pull request as draft November 5, 2025 01:52
@vadiklyutiy vadiklyutiy marked this pull request as ready for review November 5, 2025 15:14
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines 436 to +487
hidden_states: torch.Tensor,
output: torch.Tensor,
):
return torch.ops.vllm.gdn_attention(
hidden_states,
output,
"""
Forward pass with three parts:
1. Input projection
2. Core attention (custom op)
3. Output projection
"""
num_tokens = hidden_states.size(0)

# ============================================================
# Part 1: Input Projection
# ============================================================
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
projected_states_ba, _ = self.in_proj_ba(hidden_states)
query, key, value, z, b, a = self.fix_query_key_value_ordering(
projected_states_qkvz, projected_states_ba
)
query, key, value = map(
lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value)
)
mixed_qkv = torch.cat((query, key, value), dim=-1)

# ============================================================
# Part 2: Core Attention (Custom Op)
# ============================================================
core_attn_out = torch.zeros(
(num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)

torch.ops.vllm.gdn_attention_core(
mixed_qkv,
b,
a,
core_attn_out,
self.prefix,
)

def _forward(
# ============================================================
# Part 3: Output Projection
# ============================================================
z_shape_og = z.shape
# Reshape input data into 2D tensor
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(z_shape_og)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve grad connection when writing core attention output

The refactor now preallocates core_attn_out with torch.zeros and _forward_core copies the computed attention values into this tensor (core_attn_out[:num_actual_tokens] = …). Because core_attn_out is a detached leaf that never participates in any differentiable operation, the subsequent normalization and output projection operate on a tensor that does not require gradients, so no gradient can flow back to mixed_qkv, b, or a. In the previous implementation core_attn_out was the result of the computations themselves and therefore carried a grad_fn. With the new code, any training or fine‑tuning pass will receive zero gradients through the GDN block. Consider returning the computed tensor from _forward_core (or constructing core_attn_out from the computation rather than copying into a zero buffer) so the graph remains connected.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Member

@youkaichao youkaichao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idea looks good, please make sure tests are good too.

@vadiklyutiy
Copy link
Copy Markdown
Collaborator Author

idea looks good, please make sure tests are good too.

sure, I am working on it

@ProExpertProg
Copy link
Copy Markdown
Collaborator

Please fix DCO, also what was the reason for the revert originally?

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
@vadiklyutiy
Copy link
Copy Markdown
Collaborator Author

what was the reason for the revert originally?

#27512 had a bit wired CI fails. They didn't look as related to my changes: one test failed on CUDAART init, another out of memory. I couldn’t reproduce them locally, even with a configuration very close to the CI environment. I asked for any idea that can help. Simon, was a bit too kind, took into account that CI is a bit unstable last several day, and merged the PR. But several next commit's CI also failed on the same tests and we decided to revert.

@vadiklyutiy
Copy link
Copy Markdown
Collaborator Author

picking here and there I found the reason of CI fails. It isn't root cause but good enough IMO.

This is logs with fails
https://buildkite.com/vllm/ci/builds/37607/steps/canvas

after applying e7a9260
fails disappeared
https://buildkite.com/vllm/ci/builds/37696/steps/canvas

e7a9260 moved import from global scope to function scope. By some reason it caused fails. Hope it is a acceptable workaround.

@simon-mo simon-mo merged commit b6a248b into vllm-project:main Nov 6, 2025
59 checks passed
@youkaichao
Copy link
Copy Markdown
Member

e7a9260 moved import from global scope to function scope. By some reason it caused fails. Hope it is a acceptable workaround.

likely the code unconditionaly initializes cuda runtime, which we need to be careful about.

@vadiklyutiy vadiklyutiy deleted the vadim/refac-gdn2 branch November 7, 2025 11:38
ZhengHongming888 pushed a commit to ZhengHongming888/vllm that referenced this pull request Nov 8, 2025
…ct#28083)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
wangxiyuan added a commit to vllm-project/vllm-ascend that referenced this pull request Nov 26, 2025
Bump vLLM version to v0.11.2

What's broken and changed by vLLM:
1. structured_output is broken by
vllm-project/vllm#26866
2. get_mrope_input_positions is broken by
vllm-project/vllm#28399
3. graph mode is broken by
vllm-project/vllm#25110 we'll upgrade torch to
2.8 to fix the problem later
4. embedding is broken by
vllm-project/vllm#27583
5. `get_attn_backend_cls` and attention backend is broken are broken by
vllm-project/vllm#28534
6. spec decode is broken by
vllm-project/vllm#28771
7. sp feature is broken by
vllm-project/vllm#27126
8. mtp is broken by vllm-project/vllm#27922
9. lora is broken by vllm-project/vllm#21068
10. execute_model is broken by
vllm-project/vllm#26866
11. `VLLM_DISABLE_SHARED_EXPERTS_STREAM` env is broken by
vllm-project/vllm#28159
12. kv cahe is broken by vllm-project/vllm#27753
13. dp is broken by vllm-project/vllm#25110

 
What's broken and changed by ourself:
1. qwen vl is broken by vllm-project/vllm#28455
We'll remove model files in the future to avoid this kind of error
2. Engine core is broken by
vllm-project/vllm#23691 We'll remove the patch
file in the future.
3. Ascend scheduler is broken by
vllm-project/vllm#28733 We'll remove ascend
scheudler later.
4. qwen3-next is broken by
vllm-project/vllm#28083 We'll remove model files
in the future to avoid this kind of error
5. qwen vl is broken by vllm-project/vllm#27764.
We'll remove model files in the future

Known issue:
1. ray doesn't work 
2. the accuracy of qwen3-next is not correct
3. qwen3-vl is broken
4. prefix cache+ ascend scheduler + deepseek v2 lite is broken.

Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: hfadzxy <starmoon_zhang@163.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: 22dimensions <waitingwind@foxmail.com>
Co-authored-by: shen-shanshan <467638484@qq.com>


- vLLM version: v0.11.2

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: hfadzxy <starmoon_zhang@163.com>
Signed-off-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: hfadzxy <starmoon_zhang@163.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Kurumi5210 pushed a commit to lidenghui1110/vllm-ascend that referenced this pull request Nov 26, 2025
Bump vLLM version to v0.11.2

What's broken and changed by vLLM:
1. structured_output is broken by
vllm-project/vllm#26866
2. get_mrope_input_positions is broken by
vllm-project/vllm#28399
3. graph mode is broken by
vllm-project/vllm#25110 we'll upgrade torch to
2.8 to fix the problem later
4. embedding is broken by
vllm-project/vllm#27583
5. `get_attn_backend_cls` and attention backend is broken are broken by
vllm-project/vllm#28534
6. spec decode is broken by
vllm-project/vllm#28771
7. sp feature is broken by
vllm-project/vllm#27126
8. mtp is broken by vllm-project/vllm#27922
9. lora is broken by vllm-project/vllm#21068
10. execute_model is broken by
vllm-project/vllm#26866
11. `VLLM_DISABLE_SHARED_EXPERTS_STREAM` env is broken by
vllm-project/vllm#28159
12. kv cahe is broken by vllm-project/vllm#27753
13. dp is broken by vllm-project/vllm#25110

What's broken and changed by ourself:
1. qwen vl is broken by vllm-project/vllm#28455
We'll remove model files in the future to avoid this kind of error
2. Engine core is broken by
vllm-project/vllm#23691 We'll remove the patch
file in the future.
3. Ascend scheduler is broken by
vllm-project/vllm#28733 We'll remove ascend
scheudler later.
4. qwen3-next is broken by
vllm-project/vllm#28083 We'll remove model files
in the future to avoid this kind of error
5. qwen vl is broken by vllm-project/vllm#27764.
We'll remove model files in the future

Known issue:
1. ray doesn't work
2. the accuracy of qwen3-next is not correct
3. qwen3-vl is broken
4. prefix cache+ ascend scheduler + deepseek v2 lite is broken.

Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: hfadzxy <starmoon_zhang@163.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: 22dimensions <waitingwind@foxmail.com>
Co-authored-by: shen-shanshan <467638484@qq.com>

- vLLM version: v0.11.2

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: hfadzxy <starmoon_zhang@163.com>
Signed-off-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: hfadzxy <starmoon_zhang@163.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Signed-off-by: Kurumi5210 <Jaychou1620@Gmail.com>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
…ct#28083)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Nov 29, 2025
Bump vLLM version to v0.11.2

What's broken and changed by vLLM:
1. structured_output is broken by
vllm-project/vllm#26866
2. get_mrope_input_positions is broken by
vllm-project/vllm#28399
3. graph mode is broken by
vllm-project/vllm#25110 we'll upgrade torch to
2.8 to fix the problem later
4. embedding is broken by
vllm-project/vllm#27583
5. `get_attn_backend_cls` and attention backend is broken are broken by
vllm-project/vllm#28534
6. spec decode is broken by
vllm-project/vllm#28771
7. sp feature is broken by
vllm-project/vllm#27126
8. mtp is broken by vllm-project/vllm#27922
9. lora is broken by vllm-project/vllm#21068
10. execute_model is broken by
vllm-project/vllm#26866
11. `VLLM_DISABLE_SHARED_EXPERTS_STREAM` env is broken by
vllm-project/vllm#28159
12. kv cahe is broken by vllm-project/vllm#27753
13. dp is broken by vllm-project/vllm#25110

 
What's broken and changed by ourself:
1. qwen vl is broken by vllm-project/vllm#28455
We'll remove model files in the future to avoid this kind of error
2. Engine core is broken by
vllm-project/vllm#23691 We'll remove the patch
file in the future.
3. Ascend scheduler is broken by
vllm-project/vllm#28733 We'll remove ascend
scheudler later.
4. qwen3-next is broken by
vllm-project/vllm#28083 We'll remove model files
in the future to avoid this kind of error
5. qwen vl is broken by vllm-project/vllm#27764.
We'll remove model files in the future

Known issue:
1. ray doesn't work 
2. the accuracy of qwen3-next is not correct
3. qwen3-vl is broken
4. prefix cache+ ascend scheduler + deepseek v2 lite is broken.

Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: hfadzxy <starmoon_zhang@163.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: 22dimensions <waitingwind@foxmail.com>
Co-authored-by: shen-shanshan <467638484@qq.com>


- vLLM version: v0.11.2

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: hfadzxy <starmoon_zhang@163.com>
Signed-off-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: hfadzxy <starmoon_zhang@163.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Meihan-chen pushed a commit to Meihan-chen/vllm-ascend that referenced this pull request Dec 5, 2025
Bump vLLM version to v0.11.2

What's broken and changed by vLLM:
1. structured_output is broken by
vllm-project/vllm#26866
2. get_mrope_input_positions is broken by
vllm-project/vllm#28399
3. graph mode is broken by
vllm-project/vllm#25110 we'll upgrade torch to
2.8 to fix the problem later
4. embedding is broken by
vllm-project/vllm#27583
5. `get_attn_backend_cls` and attention backend is broken are broken by
vllm-project/vllm#28534
6. spec decode is broken by
vllm-project/vllm#28771
7. sp feature is broken by
vllm-project/vllm#27126
8. mtp is broken by vllm-project/vllm#27922
9. lora is broken by vllm-project/vllm#21068
10. execute_model is broken by
vllm-project/vllm#26866
11. `VLLM_DISABLE_SHARED_EXPERTS_STREAM` env is broken by
vllm-project/vllm#28159
12. kv cahe is broken by vllm-project/vllm#27753
13. dp is broken by vllm-project/vllm#25110

 
What's broken and changed by ourself:
1. qwen vl is broken by vllm-project/vllm#28455
We'll remove model files in the future to avoid this kind of error
2. Engine core is broken by
vllm-project/vllm#23691 We'll remove the patch
file in the future.
3. Ascend scheduler is broken by
vllm-project/vllm#28733 We'll remove ascend
scheudler later.
4. qwen3-next is broken by
vllm-project/vllm#28083 We'll remove model files
in the future to avoid this kind of error
5. qwen vl is broken by vllm-project/vllm#27764.
We'll remove model files in the future

Known issue:
1. ray doesn't work 
2. the accuracy of qwen3-next is not correct
3. qwen3-vl is broken
4. prefix cache+ ascend scheduler + deepseek v2 lite is broken.

Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: hfadzxy <starmoon_zhang@163.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: 22dimensions <waitingwind@foxmail.com>
Co-authored-by: shen-shanshan <467638484@qq.com>


- vLLM version: v0.11.2

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: hfadzxy <starmoon_zhang@163.com>
Signed-off-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: hfadzxy <starmoon_zhang@163.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Clorist33 pushed a commit to Clorist33/vllm-ascend that referenced this pull request Dec 9, 2025
Bump vLLM version to v0.11.2

What's broken and changed by vLLM:
1. structured_output is broken by
vllm-project/vllm#26866
2. get_mrope_input_positions is broken by
vllm-project/vllm#28399
3. graph mode is broken by
vllm-project/vllm#25110 we'll upgrade torch to
2.8 to fix the problem later
4. embedding is broken by
vllm-project/vllm#27583
5. `get_attn_backend_cls` and attention backend is broken are broken by
vllm-project/vllm#28534
6. spec decode is broken by
vllm-project/vllm#28771
7. sp feature is broken by
vllm-project/vllm#27126
8. mtp is broken by vllm-project/vllm#27922
9. lora is broken by vllm-project/vllm#21068
10. execute_model is broken by
vllm-project/vllm#26866
11. `VLLM_DISABLE_SHARED_EXPERTS_STREAM` env is broken by
vllm-project/vllm#28159
12. kv cahe is broken by vllm-project/vllm#27753
13. dp is broken by vllm-project/vllm#25110

What's broken and changed by ourself:
1. qwen vl is broken by vllm-project/vllm#28455
We'll remove model files in the future to avoid this kind of error
2. Engine core is broken by
vllm-project/vllm#23691 We'll remove the patch
file in the future.
3. Ascend scheduler is broken by
vllm-project/vllm#28733 We'll remove ascend
scheudler later.
4. qwen3-next is broken by
vllm-project/vllm#28083 We'll remove model files
in the future to avoid this kind of error
5. qwen vl is broken by vllm-project/vllm#27764.
We'll remove model files in the future

Known issue:
1. ray doesn't work
2. the accuracy of qwen3-next is not correct
3. qwen3-vl is broken
4. prefix cache+ ascend scheduler + deepseek v2 lite is broken.

Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: hfadzxy <starmoon_zhang@163.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: 22dimensions <waitingwind@foxmail.com>
Co-authored-by: shen-shanshan <467638484@qq.com>

- vLLM version: v0.11.2

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: hfadzxy <starmoon_zhang@163.com>
Signed-off-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: hfadzxy <starmoon_zhang@163.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Signed-off-by: tanqingshan (A) <50050625@china.huawei.com>
Clorist33 pushed a commit to Clorist33/vllm-ascend that referenced this pull request Dec 10, 2025
Bump vLLM version to v0.11.2

What's broken and changed by vLLM:
1. structured_output is broken by
vllm-project/vllm#26866
2. get_mrope_input_positions is broken by
vllm-project/vllm#28399
3. graph mode is broken by
vllm-project/vllm#25110 we'll upgrade torch to
2.8 to fix the problem later
4. embedding is broken by
vllm-project/vllm#27583
5. `get_attn_backend_cls` and attention backend is broken are broken by
vllm-project/vllm#28534
6. spec decode is broken by
vllm-project/vllm#28771
7. sp feature is broken by
vllm-project/vllm#27126
8. mtp is broken by vllm-project/vllm#27922
9. lora is broken by vllm-project/vllm#21068
10. execute_model is broken by
vllm-project/vllm#26866
11. `VLLM_DISABLE_SHARED_EXPERTS_STREAM` env is broken by
vllm-project/vllm#28159
12. kv cahe is broken by vllm-project/vllm#27753
13. dp is broken by vllm-project/vllm#25110

 
What's broken and changed by ourself:
1. qwen vl is broken by vllm-project/vllm#28455
We'll remove model files in the future to avoid this kind of error
2. Engine core is broken by
vllm-project/vllm#23691 We'll remove the patch
file in the future.
3. Ascend scheduler is broken by
vllm-project/vllm#28733 We'll remove ascend
scheudler later.
4. qwen3-next is broken by
vllm-project/vllm#28083 We'll remove model files
in the future to avoid this kind of error
5. qwen vl is broken by vllm-project/vllm#27764.
We'll remove model files in the future

Known issue:
1. ray doesn't work 
2. the accuracy of qwen3-next is not correct
3. qwen3-vl is broken
4. prefix cache+ ascend scheduler + deepseek v2 lite is broken.

Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: hfadzxy <starmoon_zhang@163.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: 22dimensions <waitingwind@foxmail.com>
Co-authored-by: shen-shanshan <467638484@qq.com>


- vLLM version: v0.11.2

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: hfadzxy <starmoon_zhang@163.com>
Signed-off-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: hfadzxy <starmoon_zhang@163.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants