
[MP][Bugfix] fix vllm-side lookup logical issue and cuda stream deadlock problem#2733

Merged
ApostaC merged 3 commits into LMCache:dev from ApostaC:local-dev/mp-fix-hang
Mar 11, 2026

Conversation

@ApostaC
Contributor

@ApostaC ApostaC commented Mar 11, 2026

What this PR does / why we need it:

Fixes two bugs in the multi-process (MP) mode that together cause hangs:

  1. vLLM-side lookup result caching: After a prefetch job completes, the server removes it from its job map. Subsequent QUERY_PREFETCH_STATUS calls for that job ID return None (previously 0), causing the scheduler to think the lookup is still in progress and poll forever. This PR:

    • Caches finished lookup results on the vLLM side (_finished_lookup_jobs) so repeated queries return the cached count instead of re-querying the server (a minimal sketch follows this list).
    • Changes the server to return None (instead of 0) for unknown job IDs, making the semantics consistent: None = not ready / unknown, int = done.
    • Cleans up cached results in cleanup_lookup_result.
  2. CUDA stream deadlock from launch_host_func: Telemetry logging callbacks scheduled via cupy_stream.launch_host_func need the GIL, but paged_kv_transfer and lmcache_memcpy_async hold the GIL while waiting on the CUDA stream — creating a deadlock. This PR adds py::call_guard<py::gil_scoped_release>() to both C++ bindings so the GIL is released during execution, allowing host callbacks to acquire it. This also removes the _can_log_store workaround that was masking the issue.
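
To make the first fix concrete, here is a minimal sketch of the vLLM-side caching pattern. The names check_lookup_result, cleanup_lookup_result, and _finished_lookup_jobs come from this PR; the class shape, signatures, and the communicator call are simplified assumptions, not the actual LMCache implementation.

# Hedged sketch of the vLLM-side lookup-result caching (not the actual code).
from typing import Optional

class LookupAdapterSketch:
    def __init__(self, communicator):
        self.communicator = communicator
        # job_id -> matched chunk count, recorded once the server reports "done"
        self._finished_lookup_jobs: dict = {}

    def check_lookup_result(self, job_id) -> Optional[int]:
        # Repeated queries hit the local cache, so a job the server has
        # already forgotten still resolves to its final count.
        if job_id in self._finished_lookup_jobs:
            return self._finished_lookup_jobs[job_id]
        # Server semantics after this PR: None = not ready / unknown, int = done.
        result = self.communicator.query_prefetch_status(job_id)
        if result is not None:
            self._finished_lookup_jobs[job_id] = result
        return result

    def cleanup_lookup_result(self, job_id) -> None:
        # Drop the cached entry once the scheduler no longer needs it.
        self._finished_lookup_jobs.pop(job_id, None)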

Special notes for your reviewers:

  • The GIL release guard is safe for these two functions because they operate on GPU memory via CUDA streams and do not touch Python objects (see the standalone illustration after this list).
  • The _can_log_store flag was a partial workaround that prevented store telemetry until after the first retrieve; it is no longer needed with the proper GIL fix.
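
The note above can be illustrated with a standalone CuPy snippet (not code from this PR): a callback scheduled with cupy_stream.launch_host_func can only fire once it can take the GIL, so whatever waits on that stream must release the GIL while blocking.

# Standalone CuPy illustration (not LMCache code), assuming a CUDA-capable GPU.
import cupy as cp

def log_done(arg):
    # Runs on a CUDA driver thread once the stream reaches this point; it must
    # acquire the GIL before it can execute any Python code (e.g. telemetry).
    print(f"stream work finished: {arg}")

stream = cp.cuda.Stream(non_blocking=True)
with stream:
    buf = cp.zeros(1 << 20, dtype=cp.uint8)  # enqueue some GPU work
stream.launch_host_func(log_done, "demo")

# CuPy's synchronize() releases the GIL while it blocks, so log_done can run.
# A C++ binding that waited on the same stream while holding the GIL would
# deadlock, which is what the py::gil_scoped_release call_guard prevents.
stream.synchronize()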

If applicable:

  • this PR contains user-facing changes - docs added
  • this PR contains unit tests

ApostaC added 2 commits March 11, 2026 01:45
Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: ApostaC <yihua98@uchicago.edu>
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses two critical bugs in the multi-process mode that were causing system hangs. It introduces a caching mechanism for vLLM prefetch lookup results to ensure that completed jobs are correctly recognized, preventing infinite polling. Additionally, it resolves a CUDA stream deadlock by properly managing Python's Global Interpreter Lock during GPU memory transfer operations, allowing host callbacks to execute without contention. These changes significantly improve the stability and reliability of the multi-process system.

Highlights

  • vLLM-side lookup result caching: Implemented caching for finished prefetch job results on the vLLM side to prevent the scheduler from polling indefinitely for completed jobs. The server now consistently returns None for unknown job IDs, clarifying lookup status.
  • CUDA stream deadlock fix: Resolved a deadlock issue in multi-process mode by adding py::call_guard<py::gil_scoped_release>() to the paged_kv_transfer and lmcache_memcpy_async C++ bindings, ensuring the Global Interpreter Lock (GIL) is released during GPU operations and allowing host callbacks to execute.


Changelog
  • csrc/pybind.cpp
    • Added py::call_guard<py::gil_scoped_release>() to the paged_kv_transfer function definition.
    • Added py::call_guard<py::gil_scoped_release>() to the lmcache_memcpy_async function definition.
  • lmcache/integration/vllm/vllm_multi_process_adapter.py
    • Introduced a new dictionary, _finished_lookup_jobs, to cache completed prefetch job results.
    • Modified check_lookup_result to return cached results if available and to store new results upon completion.
    • Updated cleanup_lookup_result to also remove corresponding job IDs from the _finished_lookup_jobs cache.
  • lmcache/v1/multiprocess/server.py
    • Removed the _can_log_store flag and its associated initialization.
    • Removed the _can_log_store condition from telemetry logging calls within the store method.
    • Modified query_prefetch_status to return None instead of 0 for unknown or already completed prefetch job IDs (see the sketch just below).
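
For illustration, a minimal sketch of the server-side return semantics described in the last bullet. Only the method name query_prefetch_status is taken from this PR; the _prefetch_jobs structure and its fields are assumptions made up for this sketch.

# Hedged sketch of the query semantics after this PR (not the actual server).
from typing import Optional

class PrefetchServerSketch:
    def __init__(self):
        # job_id -> {"done": bool, "matched_chunks": int} for known jobs.
        self._prefetch_jobs: dict = {}

    def query_prefetch_status(self, job_id) -> Optional[int]:
        job = self._prefetch_jobs.get(job_id)
        if job is None or not job["done"]:
            # Unknown or still running: return None rather than 0, so a real
            # result of "0 chunks matched" cannot be mistaken for "not ready".
            return None
        return job["matched_chunks"]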


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request effectively addresses two critical bugs causing hangs in the multi-process mode, including a CUDA stream deadlock and an infinite polling issue. While these fixes enhance stability, the change for the server to return None for unknown job IDs, intended to improve API consistency, inadvertently introduces a Denial of Service vulnerability. Conflating the 'not found' state with the 'in progress' state (both returning None), especially when combined with predictable job IDs, could allow an attacker to trigger infinite polling loops in the scheduler. It is critical to address this by using unpredictable IDs and distinct return values for different job states. Additionally, there are a couple of minor suggestions to improve overall code quality and documentation clarity.

Comment thread lmcache/v1/multiprocess/server.py
Comment on lines +214 to +216
if job_id in self._finished_lookup_jobs:
    # Return cached result if the job is already finished
    return self._finished_lookup_jobs[job_id] * self.chunk_size


Severity: medium

This check for a cached result can be made more concise and efficient by using dict.get() to avoid two separate dictionary lookups (in followed by []). Using an assignment expression (the walrus operator :=, available in Python 3.8+) makes this pattern particularly clean and is generally preferred for this type of check.

Suggested change
- if job_id in self._finished_lookup_jobs:
-     # Return cached result if the job is already finished
-     return self._finished_lookup_jobs[job_id] * self.chunk_size
+ if (cached_result := self._finished_lookup_jobs.get(job_id)) is not None:
+     # Return cached result if the job is already finished
+     return cached_result * self.chunk_size

Comment on lines +647 to +648
Chunk count (int) when done, None if still in progress
or the job ID is unknown.


Severity: medium

The updated docstring is slightly ambiguous. A clearer phrasing would improve readability and avoid any potential misinterpretation of when None is returned.

Suggested change
- Chunk count (int) when done, None if still in progress
- or the job ID is unknown.
+ The number of matched chunks (int) if the prefetch job is complete.
+ Returns None if the job is still in progress or the job ID is unknown.

@ApostaC ApostaC added the full (Run comprehensive tests on this PR) label Mar 11, 2026
Signed-off-by: ApostaC <yihua98@uchicago.edu>
Contributor

@sammshen sammshen left a comment


lgtm

Contributor

@KuntaiDu KuntaiDu left a comment


LGTM

Contributor

@KuntaiDu KuntaiDu left a comment


L. G. T. M.

@ApostaC ApostaC merged commit dfc914c into LMCache:dev Mar 11, 2026
27 of 28 checks passed
shaoxiawjc pushed a commit to shaoxiawjc/LMCache that referenced this pull request Mar 11, 2026
…ock problem (LMCache#2733)

* Fix the vLLM-side logical bug

* [fix] deadlock problem caused by launch_host_func

Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: shaoxiawjc <wjc2800@163.com>
realAaronWu pushed a commit to realAaronWu/LMCache that referenced this pull request Mar 20, 2026
…ock problem (LMCache#2733)

* Fix the vLLM-side logical bug

* [fix] deadlock problem caused by launch_host_func

Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: Aaron Wu <aaron.wu@dell.com>
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026
…ock problem (LMCache#2733)

* Fix the vLLM-side logical bug

* [fix] deadlock problem caused by launch_host_func

Signed-off-by: ApostaC <yihua98@uchicago.edu>

Labels

full: Run comprehensive tests on this PR


3 participants