
[Perf] Enable cuda graph for deepepHT, 5.3% throughput improvement, 4.4% TTFT improvement#29558

Merged
ProExpertProg merged 11 commits into main from
wentao-enable-cudagraph-for-deepepHT
Dec 7, 2025
Conversation

@yewentao256
Member

@yewentao256 yewentao256 commented Nov 27, 2025

Purpose

We cancelled this update in #25093 because of the cache-hit issue. That issue now appears to be fixed in main, so we can enable CUDA graphs by default for deepEPHT by splitting out the MoE ops.
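The mechanism can be sketched as follows. This is a hypothetical, simplified illustration (the function and variable names are assumptions, not vLLM's actual code): when the DeepEP high-throughput all2all backend is used with data parallelism, the MoE custom ops are added to the list of splitting ops so that they are excluded from CUDA graph capture.

```python
# Hypothetical sketch, not vLLM's actual implementation: add the MoE
# custom ops to splitting_ops so CUDA graph capture excludes them when
# the DeepEP high-throughput backend runs with data parallelism.
MOE_SPLITTING_OPS = ["vllm::moe_forward", "vllm::moe_forward_shared"]

def add_moe_splitting_ops(splitting_ops, all2all_backend, dp_size):
    """Return splitting_ops extended with the MoE ops when needed."""
    if all2all_backend == "deepep_high_throughput" and dp_size > 1:
        for op in MOE_SPLITTING_OPS:
            if op not in splitting_ops:
                splitting_ops.append(op)
    return splitting_ops
```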

Test

vllm serve deepseek-ai/DeepSeek-V3.1 -dp 8 --enable-expert-parallel --port 9256

Acc

lm_eval --model local-completions --model_args "base_url=http://127.0.0.1:9256/v1/completions,model=deepseek-ai/DeepSeek-V3.1,num_concurrent=1024" --tasks gsm8k

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9568|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.9575|±  |0.0056|

Perf

vllm bench serve --model deepseek-ai/DeepSeek-V3.1 --dataset-name random --host 127.0.0.1 --port 9256 --random-input-len 2 --random-output-len 256 --request-rate inf --num-prompts 1024

# now
============ Serving Benchmark Result ============
Successful requests:                     1024      
Failed requests:                         0         
Benchmark duration (s):                  34.20     
Total input tokens:                      1024      
Total generated tokens:                  262144    
Request throughput (req/s):              29.94     
Output token throughput (tok/s):         7665.18   
Peak output token throughput (tok/s):    8192.00   
Peak concurrent requests:                1024.00   
Total Token throughput (tok/s):          7695.12   
---------------Time to First Token----------------
Mean TTFT (ms):                          1195.77   
Median TTFT (ms):                        1248.90   
P99 TTFT (ms):                           1272.42   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          129.00    
Median TPOT (ms):                        128.99    
P99 TPOT (ms):                           129.19    
---------------Inter-token Latency----------------
Mean ITL (ms):                           129.00    
Median ITL (ms):                         129.15    
P99 ITL (ms):                            167.09    
==================================================

# main (cuda graph set to None)
============ Serving Benchmark Result ============
Successful requests:                     1024      
Failed requests:                         0         
Benchmark duration (s):                  36.03     
Total input tokens:                      1024      
Total generated tokens:                  262144    
Request throughput (req/s):              28.42     
Output token throughput (tok/s):         7275.19   
Peak output token throughput (tok/s):    8192.00   
Peak concurrent requests:                1024.00   
Total Token throughput (tok/s):          7303.61   
---------------Time to First Token----------------
Mean TTFT (ms):                          1248.72   
Median TTFT (ms):                        1256.96   
P99 TTFT (ms):                           1411.12   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          135.76    
Median TPOT (ms):                        135.77    
P99 TPOT (ms):                           136.01    
---------------Inter-token Latency----------------
Mean ITL (ms):                           135.76    
Median ITL (ms):                         135.62    
P99 ITL (ms):                            172.30    
==================================================
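As a quick sanity check, the headline numbers can be recomputed from the two runs above (mean TTFT improves about 4.2% here; the 4.4% in the title presumably comes from a different run):

```python
# Recompute the relative improvements from the two benchmark runs above.
def pct_gain(new, old):
    """Relative improvement of `new` over baseline `old`, in percent."""
    return (new - old) / old * 100

# Request throughput: higher is better.
throughput_gain = pct_gain(29.94, 28.42)                # ~5.3%
# Mean TTFT: lower is better, so improvement is the reduction vs. baseline.
ttft_reduction = (1248.72 - 1195.77) / 1248.72 * 100    # ~4.2%
```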

Unit test

(wentao) wentao@dgxB200-09:~/vllm-source$ pytest tests/compile/test_config.py
===================================== test session starts =====================================
platform linux -- Python 3.12.3, pytest-9.0.1, pluggy-1.6.0
rootdir: /home/wentao/vllm-source
configfile: pyproject.toml
plugins: forked-1.6.0, anyio-4.11.0, timeout-2.4.0
collected 28 items                                                                            

tests/compile/test_config.py ............................                               [100%]



tests/compile/test_config.py::test_VLLM_DISABLE_COMPILE_CACHE[1]
tests/compile/test_config.py::test_use_cudagraphs[NONE-0]
tests/compile/test_config.py::test_use_cudagraphs[FULL_DECODE_ONLY-1]
tests/compile/test_config.py::test_use_cudagraphs[PIECEWISE-13]
tests/compile/test_config.py::test_use_cudagraphs[FULL_AND_PIECEWISE-14]
tests/compile/test_config.py::test_stock_torch_compile
tests/compile/test_config.py::test_no_compilation
tests/compile/test_config.py::test_enforce_eager
  /home/wentao/.venv/lib/python3.12/site-packages/py/_process/forkedfunc.py:45: DeprecationWarning: This process (pid=3992609) is multi-threaded, use of fork() may lead to deadlocks in the child.
    pid = os.fork()

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================= 28 passed, 10 warnings in 113.14s (0:01:53) ========================

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request re-enables CUDA graph support for the deepep_high_throughput backend when data parallelism is used (dp_size > 1), which was previously disabled due to a cache hit issue. The change introduces logic to conditionally add MoE-related operations (vllm::moe_forward, vllm::moe_forward_shared) to the list of splitting_ops in the compilation configuration. This allows these operations to be excluded from the main CUDA graph, making them compatible with CUDA graph capture. The changes are well-targeted, guarded by appropriate conditions, and include the removal of the old code that disabled this feature. The implementation appears correct and aligns with the goal of improving performance, as demonstrated by the benchmark results in the description.

@yewentao256 yewentao256 added the `ready` label (ONLY add when PR is ready to merge/full CI is needed) Nov 27, 2025

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

if self.use_inductor_graph_partition:
    self.set_splitting_ops_for_inductor_graph_partition()
    return

P1: Split DeepEP MoE ops when using inductor partition

For the DeepEP high-throughput backend with data-parallel >1, CUDA graphs are now enabled after removing the guard in vllm/platforms/cuda.py, but this early return prevents the new MoE split logic below from running when use_inductor_graph_partition is set. In that configuration the DeepEP MoE kernels remain inside captured CUDA graphs, recreating the incompatibility that the MoE splits were meant to avoid. The inductor-partition path should also mark the MoE ops as splitting ops before returning.
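A minimal sketch of the suggested fix (the class and method names here are assumptions for illustration, not the actual vLLM config code): the inductor-partition early-return path must also register the MoE splitting ops.

```python
# Hypothetical sketch of the suggested fix: both the inductor graph
# partition path and the default path register the MoE splitting ops.
MOE_OPS = ["vllm::moe_forward", "vllm::moe_forward_shared"]

class CompilationConfigSketch:
    def __init__(self, use_inductor_graph_partition, needs_moe_split):
        self.use_inductor_graph_partition = use_inductor_graph_partition
        # True when the all2all backend is deepep_high_throughput and dp_size > 1.
        self.needs_moe_split = needs_moe_split
        self.splitting_ops = []

    def _maybe_add_moe_ops(self):
        if self.needs_moe_split:
            self.splitting_ops += [op for op in MOE_OPS
                                   if op not in self.splitting_ops]

    def set_splitting_ops_for_v1(self):
        if self.use_inductor_graph_partition:
            # The early return previously skipped the MoE split logic;
            # register the MoE ops here as well before returning.
            self._maybe_add_moe_ops()
            return
        self.splitting_ops = ["vllm::unified_attention"]  # placeholder default
        self._maybe_add_moe_ops()
```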


Collaborator

@ProExpertProg ProExpertProg left a comment


Thanks for implementing this, just make sure to handle the inductor partition case! Also could you add unit tests for the different cases?

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@ProExpertProg ProExpertProg added this to the v0.12.0 milestone Dec 2, 2025
@ProExpertProg ProExpertProg removed this from the v0.12.0 milestone Dec 3, 2025
Collaborator

@ProExpertProg ProExpertProg left a comment


Just minor comments on reorganization

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@mergify

mergify bot commented Dec 5, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @yewentao256.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 5, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@mergify mergify bot removed the needs-rebase label Dec 5, 2025
@ProExpertProg ProExpertProg enabled auto-merge (squash) December 6, 2025 00:01
@ProExpertProg ProExpertProg merged commit 17eb25e into main Dec 7, 2025
51 checks passed
@ProExpertProg ProExpertProg deleted the wentao-enable-cudagraph-for-deepepHT branch December 7, 2025 04:44
@github-project-automation github-project-automation bot moved this from In review to Done in NVIDIA Dec 7, 2025
@zhuohan123
Member

I am genuinely confused by the code structure reflected in this PR: the change is very kernel-specific, so why are all of the code changes in core vLLM?

@yewentao256
Member Author

I am genuinely confused by the code structure reflected in this PR: the change is very kernel-specific, so why are all of the code changes in core vLLM?

If we put it in a later stage, e.g. deepep_ht.py, it is too late to split the ops. This is tightly coupled to CUDA graph capture.

khluu pushed a commit that referenced this pull request Dec 18, 2025
…30910)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
khluu pushed a commit that referenced this pull request Dec 18, 2025
…30910)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
(cherry picked from commit 30bb19a)
yewentao256 added a commit that referenced this pull request Dec 18, 2025
…upport) (#30910)"

This reverts commit 30bb19a.

Signed-off-by: yewentao256 <zhyanwentao@126.com>
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Dec 22, 2025
…CG support) (vllm-project#30910)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Dec 23, 2025
### What this PR does / why we need it?

### Does this PR introduce _any_ user-facing change?
Fix vllm break:
1. [Enable cuda graph for deepepHT, 5.3% throughput improvement, 4.4%
TTFT improvement] (vllm-project/vllm#29558)
Fix Solution: Add the now-necessary `all2all_backend` parameter. The
impact of this parameter on the original `set_splitting_ops_for_v1`
implementation is only that graph mode is disabled in `vllm` if
`deepep_high_throughput` is enabled; it has no effect on the
`vllm-ascend` logic.

2.[Migrate legacy ViT MultiHeadAttention to new MMEncoderAttention
interface ] (vllm-project/vllm#30684)
Fix Solution: The reason why the GPU does not need to convert qkv to 3D
is that the GPU's flash_attention operator is compatible with 3D and 4D
(b s h d and s b ( h d)), but the NPU's flash_attention_unpad operator
only supports 3D (s b ( h d)). Therefore, we need to introduce the
reshape_qkv_to_3d operation.

3. Skip the Tencent-Hunyuan/HunyuanOCR test case, as it hits the following issue after upgrading the vllm code:
#5297

### How was this patch tested?


Co-authored-by: zxwang <1476209578@qq.com>

- vLLM version: release/v0.13.0
- vLLM main:
vllm-project/vllm@ad32e3e

---------

Signed-off-by: leo-pony <nengjunma@outlook.com>
Signed-off-by: zxwang <1476209578@qq.com>
Co-authored-by: zxwang <1476209578@qq.com>
Majid-Taheri pushed a commit to Majid-Taheri/vllm that referenced this pull request Dec 23, 2025
…CG support) (vllm-project#30910)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Ubuntu <mjtaheri68@gmail.com>
fort726 pushed a commit to fort726/vllm that referenced this pull request Jan 6, 2026
…CG support) (vllm-project#30910)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
….4% TTFT improvement (vllm-project#29558)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…CG support) (vllm-project#30910)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
…CG support) (vllm-project#30910)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026

Labels

nvidia, ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done

Development


3 participants