[Perf] Optimize Sampler Redundant Copy for Model Runner v2, 1.8% Throughput Improvement#35214

Open
yewentao256 wants to merge 7 commits into main from wentao-optimize-model-runner-v2-sampler

@yewentao256 yewentao256 commented Feb 24, 2026

Member

Purpose

Part of #35335

The sampler currently makes an expensive copy of the logits on every call, but that copy is only needed under certain special user configurations.

This PR skips the copy when it is not needed.

Test

export MODEL="Qwen/Qwen3-30B-A3B-Thinking-2507-FP8"
export VLLM_USE_V2_MODEL_RUNNER=1
vllm serve $MODEL --port 9256 --enable-expert-parallel

Accuracy

lm_eval --model local-completions --model_args "base_url=http://127.0.0.1:9256/v1/completions,model=$MODEL,num_concurrent=1024" --tasks gsm8k
# now
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.6687|±  |0.0130|
|     |       |strict-match    |     5|exact_match||0.7862|±  |0.0113|
# main
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.6672|±  |0.0130|
|     |       |strict-match    |     5|exact_match||0.7870|±  |0.0113|

Perf

vllm bench serve --model $MODEL  --dataset-name random --host 127.0.0.1 --port 9256 --random-input-len 2 --random-output-len 512 --request-rate inf --num-prompts 128 --num-warmups 16
# now
============ Serving Benchmark Result ============
Successful requests:                     128       
Failed requests:                         0         
Benchmark duration (s):                  22.10     
Total input tokens:                      256       
Total generated tokens:                  65536     
Request throughput (req/s):              5.79      
Output token throughput (tok/s):         2965.89   
Peak output token throughput (tok/s):    3072.00   
Peak concurrent requests:                128.00    
Total token throughput (tok/s):          2977.48   
---------------Time to First Token----------------
Mean TTFT (ms):                          190.68    
Median TTFT (ms):                        190.24    
P99 TTFT (ms):                           202.37    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          42.84     
Median TPOT (ms):                        42.84     
P99 TPOT (ms):                           42.91     
---------------Inter-token Latency----------------
Mean ITL (ms):                           42.84     
Median ITL (ms):                         42.63     
P99 ITL (ms):                            45.93     
==================================================
# main
============ Serving Benchmark Result ============
Successful requests:                     128       
Failed requests:                         0         
Benchmark duration (s):                  22.51     
Total input tokens:                      256       
Total generated tokens:                  65536     
Request throughput (req/s):              5.69      
Output token throughput (tok/s):         2911.32   
Peak output token throughput (tok/s):    2944.00   
Peak concurrent requests:                128.00    
Total token throughput (tok/s):          2922.70   
---------------Time to First Token----------------
Mean TTFT (ms):                          190.54    
Median TTFT (ms):                        190.16    
P99 TTFT (ms):                           201.62    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          43.65     
Median TPOT (ms):                        43.65     
P99 TPOT (ms):                           43.72     
---------------Inter-token Latency----------------
Mean ITL (ms):                           43.65     
Median ITL (ms):                         43.45     
P99 ITL (ms):                            46.95     
==================================================
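
As a quick check on the headline figure, the relative change can be computed directly from the two runs above. This is plain arithmetic on the reported numbers; the helper name `pct_change` is only for illustration and is not part of vLLM.

```python
# Figures copied from the "now" and "main" benchmark results above.
now = {"output_tok_s": 2965.89, "mean_tpot_ms": 42.84}
main = {"output_tok_s": 2911.32, "mean_tpot_ms": 43.65}


def pct_change(new: float, old: float) -> float:
    """Relative change of `new` with respect to `old`, in percent."""
    return (new - old) / old * 100.0


throughput_gain = pct_change(now["output_tok_s"], main["output_tok_s"])
tpot_change = pct_change(now["mean_tpot_ms"], main["mean_tpot_ms"])

print(f"Output token throughput: {throughput_gain:+.2f}%")
print(f"Mean TPOT:               {tpot_change:+.2f}%")
```

Output token throughput rises by roughly 1.87%, in line with the 1.8% in the title, and mean TPOT drops by about 1.86%. Both come from a single run per branch, so run-to-run variance is not accounted for here.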

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@mergify mergify Bot added the v1 label Feb 24, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a performance optimization in the Sampler by avoiding a redundant copy of the logits tensor when no processing is required. This is achieved by adding a new helper method, _needs_logits_processing, which checks if any logit modifications are necessary for the current batch. The expensive copy and subsequent processing are now conditionally executed, which should improve throughput in cases where no special sampling parameters are used. The implementation appears correct and effectively delivers the intended optimization. I have reviewed the changes and found no issues.
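
The idea described above can be sketched roughly as follows. This is not the PR's code: it uses plain Python lists in place of tensors, and `RequestParams`, its `logit_bias` field, `needs_logits_processing`, and `prepare_logits` are hypothetical stand-ins for whatever settings and helpers the real sampler uses. The point is only that the copy happens when some request in the batch will actually modify the logits.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class RequestParams:
    # Hypothetical per-request setting; a stand-in for any option that
    # would modify the logits before sampling.
    logit_bias: Dict[int, float] = field(default_factory=dict)


def needs_logits_processing(batch: List[RequestParams]) -> bool:
    """True if at least one request in the batch modifies the logits."""
    return any(p.logit_bias for p in batch)


def prepare_logits(
    logits: List[List[float]], batch: List[RequestParams]
) -> List[List[float]]:
    if not needs_logits_processing(batch):
        # Fast path: nothing will write to the logits, so skip the copy.
        return logits
    # Slow path: copy first so the caller's buffer is left untouched.
    processed = [row[:] for row in logits]
    for row, params in zip(processed, batch):
        for token_id, bias in params.logit_bias.items():
            row[token_id] += bias
    return processed


logits = [[0.0, 1.0], [2.0, 3.0]]
same = prepare_logits(logits, [RequestParams(), RequestParams()])
copied = prepare_logits(logits, [RequestParams(), RequestParams({1: 5.0})])
```

In the first call the original object is returned unchanged; in the second a modified copy is returned and the input is left as it was.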

@yewentao256 yewentao256 added the ready label (ONLY add when PR is ready to merge/full CI is needed) Feb 25, 2026

mergify Bot commented Mar 3, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @yewentao256.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 3, 2026
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 yewentao256 requested a review from njhill as a code owner March 4, 2026 19:25
@mergify mergify Bot removed the needs-rebase label Mar 4, 2026
Signed-off-by: yewentao256 <zhyanwentao@126.com>

yewentao256 commented Mar 4, 2026

Member Author

Hi @WoosukKwon, if you prefer a version with a smaller diff, see commit

Currently, we build a single per-batch logits-processing plan in the caller and remove the redundant checks from the callee sampling ops. Behavior is unchanged, readability improves, and conditions are no longer evaluated repeatedly. @WoosukKwon, could you please help review?
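
A minimal sketch of the "plan in the caller" pattern described above, assuming plain Python lists instead of tensors. `LogitsPlan`, `build_plan`, `apply_bias`, `apply_temperature`, and `process` are hypothetical names chosen for illustration, and the per-row scalar bias is a simplification; none of this is taken from the PR's diff. Temperatures are assumed to be positive.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class LogitsPlan:
    """Per-batch decisions, computed once by the caller."""
    apply_bias: bool
    apply_temperature: bool

    @property
    def any_processing(self) -> bool:
        return self.apply_bias or self.apply_temperature


def build_plan(biases: List[float], temperatures: List[float]) -> LogitsPlan:
    return LogitsPlan(
        apply_bias=any(b != 0.0 for b in biases),
        apply_temperature=any(t != 1.0 for t in temperatures),
    )


# The callee ops assume the caller already decided they are needed,
# so they carry no "is this a no-op?" checks of their own.
def apply_bias(logits: List[List[float]], biases: List[float]) -> None:
    for row, b in zip(logits, biases):
        for i in range(len(row)):
            row[i] += b


def apply_temperature(logits: List[List[float]], temps: List[float]) -> None:
    for row, t in zip(logits, temps):
        for i in range(len(row)):
            row[i] /= t


def process(logits, biases, temperatures):
    plan = build_plan(biases, temperatures)
    if not plan.any_processing:
        return logits  # no copy, no op calls
    out = [row[:] for row in logits]  # copy once, then mutate in place
    if plan.apply_bias:
        apply_bias(out, biases)
    if plan.apply_temperature:
        apply_temperature(out, temperatures)
    return out


logits = [[1.0, 2.0], [3.0, 4.0]]
untouched = process(logits, [0.0, 0.0], [1.0, 1.0])
processed = process(logits, [1.0, 0.0], [1.0, 2.0])
```

Each condition is evaluated once per batch in `build_plan`, and the same plan also decides whether the copy is made at all.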


mergify Bot commented Mar 4, 2026

Contributor

Hi @yewentao256, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint


mergify Bot commented Mar 11, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @yewentao256.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 11, 2026
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@mergify mergify Bot removed the needs-rebase label Mar 16, 2026
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Labels

ready (ONLY add when PR is ready to merge/full CI is needed), v1
