
[Bug] Fix torch Compilation Cache Hit Error#25093

Merged
simon-mo merged 3 commits into vllm-project:main from
neuralmagic:wye-remove-deepep-HT-support-for-piecewise-cudagraph
Sep 18, 2025
Conversation

@yewentao256 (Member) commented Sep 17, 2025

Purpose

Fixes #24915

The root cause of this issue is that (runtime_shape, graph_index, backend_name) is not a strong enough cache key for a compiled CUDA graph. In a complicated situation like DeepEP HT, we split the piecewise graph into moe_forward and moe_forward_shared, and the two subgraphs produce a wrong cache hit.

I am not sure it is a good idea to refactor the cache key system thoroughly just to support piecewise graphs for HT (it is probably not worth it), so for now we simply drop that support. If we encounter other scenarios where the cache key proves insufficient, we should revisit and redesign the cache system.
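To illustrate the failure mode with a minimal sketch (a hypothetical dict-backed cache, not vLLM's actual implementation): two different subgraphs that share (runtime_shape, graph_index, backend_name) collide, and the second lookup silently returns the first subgraph's compiled artifact:

```python
# Hypothetical sketch of the collision; the cache shape and values here are
# illustrative, not vLLM's real code.
cache: dict[tuple, object] = {}

def get_or_compile(runtime_shape, graph_index, backend_name, compile_fn):
    """Naive cache keyed only by (runtime_shape, graph_index, backend_name)."""
    key = (runtime_shape, graph_index, backend_name)
    if key not in cache:
        cache[key] = compile_fn()  # first request compiles and stores
    return cache[key]              # later requests may return the WRONG graph

# Stand-ins for the two *different* subgraphs split out of the piecewise graph.
moe_forward = get_or_compile(
    None, 10, "inductor", lambda: "compiled moe_forward (13 inputs)")
moe_forward_shared = get_or_compile(
    None, 10, "inductor", lambda: "compiled moe_forward_shared (12 inputs)")

# Same key -> wrong cache hit: the shared variant got moe_forward's artifact.
assert moe_forward_shared == "compiled moe_forward (13 inputs)"
```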

Note: the rough idea for refactoring the cache system is to add a fourth key component derived from the signature of the subgraph, like the helper below. It works well locally:

import hashlib

from torch import fx

def compute_subgraph_signature(graph: fx.GraphModule) -> str:
    # Fingerprint the subgraph by its sequence of ops and call targets.
    parts: list[str] = []
    for node in graph.graph.nodes:
        parts.append(f"{node.op}:{node.target}")
    sig = "|".join(parts)
    return hashlib.md5(sig.encode(), usedforsecurity=False).hexdigest()[:16]
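The fourth key component would then simply extend the existing tuple. A hypothetical sketch (again, not vLLM's actual code) of how the extended key keeps the two split subgraphs apart:

```python
# Hypothetical extension of the cache key: the existing 3-tuple plus a
# structural fingerprint of the subgraph (as computed by a helper like
# compute_subgraph_signature above). Signature values here are made up.
def make_cache_key(runtime_shape, graph_index, backend_name, subgraph_sig):
    return (runtime_shape, graph_index, backend_name, subgraph_sig)

# moe_forward and moe_forward_shared share shape/index/backend, but their
# fx graphs differ, so their signatures (and therefore their keys) differ.
key_a = make_cache_key(None, 10, "inductor", "a3f9c2d4e5b60718")
key_b = make_cache_key(None, 10, "inductor", "7b1d0e9f2c4a5368")
assert key_a != key_b          # no more wrong cache hit
assert key_a[:3] == key_b[:3]  # the old 3-tuple alone could not tell them apart
```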

Test

Originally: wrong graph cache hit on the second run

(EngineCore_DP6 pid=2464687)     hidden_states = self.model(input_ids, positions, intermediate_tensors,
(EngineCore_DP6 pid=2464687)                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP6 pid=2464687)   File "/data/vllm-community-homes/vllm-user-6/vllm/vllm/compilation/decorators.py", line 305, in __call__
(EngineCore_DP6 pid=2464687)     output = self.compiled_callable(*args, **kwargs)
(EngineCore_DP6 pid=2464687)              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP6 pid=2464687)   File "/data/vllm-community-homes/vllm-user-6/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 655, in _fn
(EngineCore_DP6 pid=2464687)     return fn(*args, **kwargs)
(EngineCore_DP6 pid=2464687)            ^^^^^^^^^^^^^^^^^^^
(EngineCore_DP6 pid=2464687)   File "/data/vllm-community-homes/vllm-user-6/vllm/vllm/model_executor/models/deepseek_v2.py", line 767, in forward
(EngineCore_DP6 pid=2464687)     def forward(
(EngineCore_DP6 pid=2464687)   File "/data/vllm-community-homes/vllm-user-6/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(EngineCore_DP6 pid=2464687)     return self._call_impl(*args, **kwargs)
(EngineCore_DP6 pid=2464687)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP6 pid=2464687)   File "/data/vllm-community-homes/vllm-user-6/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(EngineCore_DP6 pid=2464687)     return forward_call(*args, **kwargs)
(EngineCore_DP6 pid=2464687)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP6 pid=2464687)   File "/data/vllm-community-homes/vllm-user-6/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
(EngineCore_DP6 pid=2464687)     return fn(*args, **kwargs)
(EngineCore_DP6 pid=2464687)            ^^^^^^^^^^^^^^^^^^^
(EngineCore_DP6 pid=2464687)   File "/data/vllm-community-homes/vllm-user-6/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 830, in call_wrapped
(EngineCore_DP6 pid=2464687)     return self._wrapped_call(self, *args, **kwargs)
(EngineCore_DP6 pid=2464687)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP6 pid=2464687)   File "/data/vllm-community-homes/vllm-user-6/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 406, in __call__
(EngineCore_DP6 pid=2464687)     raise e
(EngineCore_DP6 pid=2464687)   File "/data/vllm-community-homes/vllm-user-6/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py", line 393, in __call__
(EngineCore_DP6 pid=2464687)     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
(EngineCore_DP6 pid=2464687)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP6 pid=2464687)   File "/data/vllm-community-homes/vllm-user-6/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
(EngineCore_DP6 pid=2464687)     return self._call_impl(*args, **kwargs)
(EngineCore_DP6 pid=2464687)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP6 pid=2464687)   File "/data/vllm-community-homes/vllm-user-6/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
(EngineCore_DP6 pid=2464687)     return forward_call(*args, **kwargs)
(EngineCore_DP6 pid=2464687)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP6 pid=2464687)   File "<eval_with_key>.240", line 723, in forward
(EngineCore_DP6 pid=2464687)     submod_10 = self.submod_10(submod_9, s0, getitem_22, l_self_modules_layers_modules_4_modules_input_layernorm_parameters_weight_, l_self_modules_layers_modules_4_modules_self_attn_modules_mla_attn_modules_fused_qkv_a_proj_parameters_weight_, l_self_modules_layers_modules_4_modules_self_attn_modules_mla_attn_modules_fused_qkv_a_proj_parameters_weight_scale_inv_, l_self_modules_layers_modules_4_modules_self_attn_modules_mla_attn_modules_q_a_layernorm_parameters_weight_, l_self_modules_layers_modules_4_modules_self_attn_modules_mla_attn_modules_q_b_proj_parameters_weight_, l_self_modules_layers_modules_4_modules_self_attn_modules_mla_attn_modules_q_b_proj_parameters_weight_scale_inv_, l_self_modules_layers_modules_4_modules_self_attn_modules_mla_attn_modules_kv_a_layernorm_parameters_weight_, l_self_modules_layers_modules_0_modules_self_attn_modules_mla_attn_modules_rotary_emb_buffers_cos_sin_cache_, l_positions_);  submod_9 = getitem_22 = l_self_modules_layers_modules_4_modules_input_layernorm_parameters_weight_ = l_self_modules_layers_modules_4_modules_self_attn_modules_mla_attn_modules_fused_qkv_a_proj_parameters_weight_ = l_self_modules_layers_modules_4_modules_self_attn_modules_mla_attn_modules_fused_qkv_a_proj_parameters_weight_scale_inv_ = l_self_modules_layers_modules_4_modules_self_attn_modules_mla_attn_modules_q_a_layernorm_parameters_weight_ = l_self_modules_layers_modules_4_modules_self_attn_modules_mla_attn_modules_q_b_proj_parameters_weight_ = l_self_modules_layers_modules_4_modules_self_attn_modules_mla_attn_modules_q_b_proj_parameters_weight_scale_inv_ = l_self_modules_layers_modules_4_modules_self_attn_modules_mla_attn_modules_kv_a_layernorm_parameters_weight_ = None
(EngineCore_DP6 pid=2464687)                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP6 pid=2464687)   File "/data/vllm-community-homes/vllm-user-6/vllm/vllm/compilation/cuda_graph.py", line 119, in __call__
(EngineCore_DP6 pid=2464687)     return self.runnable(*args, **kwargs)
(EngineCore_DP6 pid=2464687)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP6 pid=2464687)   File "/data/vllm-community-homes/vllm-user-6/vllm/vllm/compilation/cuda_piecewise_backend.py", line 90, in __call__
(EngineCore_DP6 pid=2464687)     return self.compiled_graph_for_general_shape(*args)
(EngineCore_DP6 pid=2464687)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP6 pid=2464687)   File "/data/vllm-community-homes/vllm-user-6/vllm/vllm/compilation/compiler_interface.py", line 518, in compiled_graph
(EngineCore_DP6 pid=2464687)     graph_output = inductor_compiled_graph(list_args)
(EngineCore_DP6 pid=2464687)                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP6 pid=2464687)   File "/data/vllm-community-homes/vllm-user-6/.venv/lib/python3.12/site-packages/torch/_inductor/output_code.py", line 460, in __call__
(EngineCore_DP6 pid=2464687)     return self.current_callable(inputs)
(EngineCore_DP6 pid=2464687)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP6 pid=2464687)   File "/data/vllm-community-homes/vllm-user-6/.cache/vllm/torch_compile_cache/827a4e48c2/rank_0_6/inductor_cache/u3/cu35tgdepztiquv25xuh6ksxtdembn2u7jrhps4jbvwlcnvv4bzl.py", line 441, in call
(EngineCore_DP6 pid=2464687)     arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1 = args
(EngineCore_DP6 pid=2464687)     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP1 pid=2464682)     return self.compiled_graph_for_general_shape(*args)
(EngineCore_DP6 pid=2464687) ValueError: not enough values to unpack (expected 13, got 12)
(EngineCore_DP1 pid=2464682)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP1 pid=2464682)   File "/data/vllm-community-homes/vllm-user-6/vllm/vllm/compilation/compiler_interface.py", line 518, in compiled_graph
(EngineCore_DP1 pid=2464682)     graph_output = inductor_compiled_graph(list_args)
(EngineCore_DP1 pid=2464682)                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP1 pid=2464682)   File "/data/vllm-community-homes/vllm-user-6/.venv/lib/python3.12/site-packages/torch/_inductor/output_code.py", line 460, in __call__
(EngineCore_DP1 pid=2464682)     return self.current_callable(inputs)
(EngineCore_DP1 pid=2464682)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP1 pid=2464682)   File "/data/vllm-community-homes/vllm-user-6/.cache/vllm/torch_compile_cache/827a4e48c2/rank_0_1/inductor_cache/5x/c5xlzojszbnaoiyl4la7syjq6lh2rfdc3q3emyo3ase2b6w6hizw.py", line 441, in call
(EngineCore_DP1 pid=2464682)     arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1 = args
(EngineCore_DP1 pid=2464682)     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP1 pid=2464682) ValueError: not enough values to unpack (expected 13, got 12)
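The unpack failure at the bottom of the trace is exactly what a wrong cache hit produces: generated code that destructures 13 inputs receives the 12 inputs of the other subgraph. A minimal, illustrative-only reproduction of that error class:

```python
# Illustrative only: mimic the generated inductor wrapper's argument
# unpacking to reproduce the same class of ValueError seen in the trace.
def call(args):
    a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12 = args  # expects 13
    return a0

try:
    call(list(range(12)))  # the other subgraph only supplies 12 inputs
    raised = False
except ValueError as e:
    raised = True
    message = str(e)

assert raised and "not enough values to unpack" in message
```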

Now:

(APIServer pid=1527118) INFO:     Started server process [1527118]
(APIServer pid=1527118) INFO:     Waiting for application startup.
(APIServer pid=1527118) INFO:     Application startup complete.

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request addresses a critical bug where an insufficient cache key for compiled CUDA graphs caused cache collisions and errors when using the deepep_high_throughput backend. The fix correctly disables CUDA graphs for this specific configuration, preventing the crash. This is a solid, pragmatic solution for the immediate problem. I've added one high-severity suggestion to add a TODO comment to track the technical debt of this temporary fix, ensuring the long-term goal of re-enabling this performance feature with a more robust caching mechanism is not lost.

@ProExpertProg (Collaborator) left a comment

I don't understand, is this a torch compile caching issue or is it a CUDAGraph issue? These are (somewhat) orthogonal features. I don't know of any cudagraph caching. Also I think we should be able to disable CUDAGraphs but keep compilation (maybe it just can't be piecewise).

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 (Member, Author) replied:

> I don't understand, is this a torch compile caching issue or is it a CUDAGraph issue? These are (somewhat) orthogonal features. I don't know of any cudagraph caching. Also I think we should be able to disable CUDAGraphs but keep compilation (maybe it just can't be piecewise).

@ProExpertProg This is a compile caching issue, but splitting_ops is used for both torch compile and CUDA graphs.

Yeah, we can keep compilation as it is and just disable CUDA graphs; fixed now.

@ProExpertProg (Collaborator) left a comment

This seems fine, but if we want more performance we could also just disable inductor compile caching (increases startup time but would give the best performance).
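For reference, a sketch of that alternative; the exact flag name is an assumption on my part and should be verified against vllm/envs.py:

```shell
# Assumed environment knob (verify against vllm/envs.py): disable vLLM's
# torch.compile artifact cache so every start recompiles from scratch,
# trading longer startup for zero risk of a stale or wrong cache hit.
export VLLM_DISABLE_COMPILE_CACHE=1
```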

Signed-off-by: yewentao256 <zhyanwentao@126.com>
@yewentao256 (Member, Author) replied:

> This seems fine, but if we want more performance we could also just disable inductor compile caching (increases startup time but would give the best performance).

Yes, but it seems piecewise CUDA graph for HT mainly benefits decoding; for prefill we don't see much performance improvement.

@ProExpertProg (Collaborator) left a comment

Sounds good

@ProExpertProg ProExpertProg enabled auto-merge (squash) September 18, 2025 14:39
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 18, 2025
@simon-mo simon-mo disabled auto-merge September 18, 2025 19:38
@simon-mo simon-mo merged commit d2a30a2 into vllm-project:main Sep 18, 2025
49 of 51 checks passed
@yewentao256 yewentao256 deleted the wye-remove-deepep-HT-support-for-piecewise-cudagraph branch September 18, 2025 20:03
Comment on lines +195 to +197
# TODO: Piecewise Cuda graph might be enabled
# if torch compile cache key issue fixed
# See https://github.com/vllm-project/vllm/pull/25093
A Collaborator commented:

is this a bug? can you file an issue if so?

@yewentao256 (Member, Author) replied:

This is not a bug; the cache key is just not strong enough to support splitting. I don't think the refactor is worth doing just to support HT piecewise CUDA graph, so let's leave the TODO there.

ywang96 pushed a commit to ywang96/vllm that referenced this pull request Sep 19, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
debroy-rh pushed a commit to debroy-rh/vllm that referenced this pull request Sep 19, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
ABC12345anouys pushed a commit to ABC12345anouys/vllm that referenced this pull request Sep 25, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: charlifu <charlifu@amd.com>
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
Signed-off-by: yewentao256 <zhyanwentao@126.com>

Labels

ready ONLY add when PR is ready to merge/full CI is needed

Development

Successfully merging this pull request may close these issues.

[Bug]: Consistent vLLM crash when trying to reuse previous torch compilation with DeepSeek R1 on B200

4 participants