Skip to content

[Core] Whisper support torch.compile#30385

Merged
NickLucche merged 4 commits into
vllm-project:mainfrom
NickLucche:whisper-compile
Jan 19, 2026
Merged

[Core] Whisper support torch.compile#30385
NickLucche merged 4 commits into
vllm-project:mainfrom
NickLucche:whisper-compile

Conversation

@NickLucche

@NickLucche NickLucche commented Dec 10, 2025

Copy link
Copy Markdown
Member

This PR is yet another Whisper performance optimization, adding support for torch.compile during decoding step.
It follows a very similar approach to #30072 (and should also land after that to ensure best defaults) in which only the 2nd decoder steps onward are compiled.
This is due to the fact that step0 in enc-dec models computes and caches crossattn KVs, requiring encoder_output as additional input and hence generating a different graph from the other steps.

Updated profiling:
image

Considerations

I've attempted to add the "2nd decoder step" selection logic directly in the model runner (_model_forward).
I am aware that _model_forward is currently used by OOT runners (#25084), although no "official" runner interface contract is maintained to ensure compatibility (unlike connectors to name one), which makes maintaining these kinds of methods without breaking external usage quite hairy.
As this change may affect those runners I am also pinging @patrick-toulme .

Happy to change to a less invasive logic if we find a cleaner way to do it @LucasWilkinson .
Other options are adding the flag to attn_metadata and then retrieving the metadata from the support_torch_compile wrapper OR hacking the WhisperDecoder __call__ (definitely not nice).

UPDATE:
Following @ProExpertProg suggestion, I've moved to implementing the alternative option in which the skip_compiled logic is generically handled inside the compile decorator.
So no change to _model_forward is actually needed.

Related PRs #29421 #30072

Test with

Compilation is enabled by default:

vllm serve openai/whisper-large-v3-turbo

cc @DarkLight1337 @robertgshaw2-redhat

@chatgpt-codex-connector

Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

@mergify mergify Bot added the v1 label Dec 10, 2025

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces torch.compile support for the Whisper model's decoder, aimed at improving performance during the decoding phase. The implementation cleverly compiles only the decoding steps from the second step onwards, correctly identifying that the first step has a different computation graph due to cross-attention key-value cache generation. This is achieved by adding a force_eager flag to the _model_forward method in GPUModelRunner, which is conditionally set based on the presence of encoder inputs. The changes are well-designed, backward-compatible, and the generic approach in GPUModelRunner could be beneficial for other encoder-decoder models in the future. The code appears to be correct and I could not identify any issues of high or critical severity.

@DarkLight1337 DarkLight1337 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should have at least one test that uses Whisper with CUDA graph

@patrick-toulme

Copy link
Copy Markdown
Contributor

Changes look fine to me. All you are doing is adding a gated option to run eager mode in model_forward. Any downstream consumers who are subclassing just have to add that variable now. LGTM

@robertgshaw2-redhat

Copy link
Copy Markdown
Collaborator

note: this generates incorrect answers

@mergify

mergify Bot commented Dec 18, 2025

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Comment thread vllm/model_executor/models/whisper.py
@NickLucche

Copy link
Copy Markdown
Member Author

@robertgshaw2-redhat reviving this PR, are there any blockers? Happy to look at accuracy issues if any, can't spot them from usual tests

@jikunshang

Copy link
Copy Markdown
Member

a noob question: I noticed that this only enable Decoder part. Is there any blocker to enable torch.compile + Encoder part? just like mm + compile suport https://docs.vllm.ai/en/latest/design/torch_compile_multimodal/

@NickLucche

Copy link
Copy Markdown
Member Author

@jikunshang #30549

@DarkLight1337

Copy link
Copy Markdown
Member

Let's wait for @robertgshaw2-redhat to elaborate on the accuracy issues first

Comment thread vllm/v1/worker/gpu_model_runner.py Outdated
@@ -2919,6 +2920,17 @@ def _model_forward(
Returns:
Model output tensor
"""

if force_eager:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's just add force_eager to ForwardContext and read that in the compile decorator?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an option for enable_if in the support_torch_compile decorator - perhaps we can leverage that?

see

enable_if=should_torch_compile_mm_vit,
for example usage

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that flag is for optional compilation, here I need to always compile, but optionally do eager (aka not call the compiled graph)

@NickLucche

Copy link
Copy Markdown
Member Author

I've implemented @ProExpertProg suggested approach and updated the description.

@NickLucche NickLucche enabled auto-merge (squash) January 15, 2026 10:31
@github-actions github-actions Bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 15, 2026
@@ -156,7 +156,9 @@ def test_wer_correctness(
model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None
):
# TODO refactor to use `ASRDataset`
with RemoteOpenAIServer(model_name, ["--enforce-eager"]) as remote_server:
with RemoteOpenAIServer(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-cc.mode=NONE if you want to just disable compilation but keep cuda graphs

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will follow up with more cg+compilation tests (which is the default vllm serve setup)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
@NickLucche NickLucche merged commit 74c583b into vllm-project:main Jan 19, 2026
58 checks passed
gopalsarda pushed a commit to gopalsarda/vllm that referenced this pull request Jan 20, 2026
Signed-off-by: NickLucche <nlucches@redhat.com>
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Jan 27, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16d)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117e)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cc)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Jan 31, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16d)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117e)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cc)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Jan 31, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16d)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117e)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cc)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
chenchuw886 pushed a commit to chenchuw886/vllm-ascend that referenced this pull request Feb 12, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16d)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117e)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cc)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: momochenchuw <chenchuw@huawei.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16d)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117e)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cc)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16d)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117e)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cc)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16d)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117e)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cc)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16d)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117e)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cc)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
jiangyunfan1 pushed a commit to jiangyunfan1/vllm-ascend that referenced this pull request Apr 9, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16d)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117e)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cc)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
yangzhe-2026 pushed a commit to yangzhe-2026/vllm-ascend that referenced this pull request May 6, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16d)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117e)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cc)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026
Signed-off-by: NickLucche <nlucches@redhat.com>
nanxingMy pushed a commit to nanxingMy/vllm-ascend that referenced this pull request May 15, 2026
### What this PR does / why we need it?
1. ✅ Upgrade vllm commit to: 0115
(8471b27)
Modify import paths due to the refactors:
vllm-project/vllm#32245
vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21034239336/job/60490156965?pr=5913
2. ✅Upgrade vllm commit to: 0119
(9a1f16d)
Fix `WorkerProc.__init__() missing 1 required positional argument:
'is_driver_worker'` due to
vllm-project/vllm#28506
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21156263050/job/60841668755?5569
3. ✅Upgrade vllm commit to:
0120(148117e)
1. Add `skip_compiled` param in `set_forward_context` due to
vllm-project/vllm#30385
2. Modify `tests/ut/spec_decode/test_eagle_proposer.py` due to
vllm-project/vllm#24322
change `self.max_num_tokens =
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size`
3. Modify UT import paths due to the
refactors:vllm-project/vllm#32060
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21204851770/job/60999046946
4. ✅Upgrade vllm commit to:
0121(f23fb5a)
1. vLLM switched `uses_mrope` from target to draft model config, making
`positions`/`mrope_positions` mutually exclusive, breaking vllm-ascend's
direct self.positions access and tests missing
`draft_model_config.uses_mrope`.
vllm-project/vllm#32048
2. Moved bs_to_padded_graph_size from CompilationConfig to
CudagraphDispatcher due to the refactor
vllm-project/vllm#30143
3. Remove unused `maybe_setup_kv_connector` due to
vllm-project/vllm#32077
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21217728738/job/61043738834
6. ✅Upgrade vllm commit to:
0122(8ebf271)
Updating FusedMoEParallelConfig (added enable_eplb) and FusedMoEConfig
due to vllm-project/vllm#32414
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21249922546/job/61148613054
8. ✅Upgrade vllm commit to:
0123(dc917cc)
Setting temperature=0.0 due to the removal of the default temperature
value in vllm-project/vllm#32723
Test result:
https://github.com/vllm-project/vllm-ascend/actions/runs/21280796875
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
vllm-project/vllm@d682094

---------

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: wjunLu <wjunlu217@gmail.com>
Signed-off-by: nanxing <1014662416@qq.com>
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
Signed-off-by: NickLucche <nlucches@redhat.com>
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
Signed-off-by: NickLucche <nlucches@redhat.com>
0826joyce pushed a commit to 0826joyce/vllm-serving-optimization that referenced this pull request May 19, 2026
Signed-off-by: NickLucche <nlucches@redhat.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants