
[MP] Support MP Server restart#2713

Merged
maobaolong merged 6 commits intoLMCache:devfrom
maobaolong:mp_server_restart
Mar 10, 2026
Conversation

@maobaolong
Collaborator

@maobaolong maobaolong commented Mar 7, 2026

What this PR does / why we need it:

This PR introduces an automatic re-registration capability for KV caches on the client side and enhances error handling on the server side, ensuring that operations can resume seamlessly after a server interruption. These changes make the system more resilient and reliable in dynamic environments.

Special notes for your reviewers:

If applicable:

  • this PR contains user facing changes - docs added
  • this PR contains unit tests

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly improves the robustness of the LMCache system by addressing the scenario where the multi-process server restarts. It introduces an automatic re-registration capability for KV caches on the client side and enhances error handling on the server side, ensuring that operations can resume seamlessly after a server interruption. These changes make the system more resilient and reliable in dynamic environments.

Highlights

  • Automatic KV Cache Re-registration: Implemented an automatic re-registration mechanism for KV caches in the vllm_multi_process_adapter to ensure resilience after an LMCache server restart. If a store or retrieve request fails due to an unregistered instance, the adapter will attempt to re-register its KV caches.
  • Improved Server-Side Error Handling: Modified the LMCache server's store and retrieve methods to gracefully handle requests from unregistered instances. Instead of asserting, the server now logs a warning and returns a failure, allowing the client to detect the server's state change.
  • Robust Future Handling: Enhanced the CudaFuture class to correctly manage futures that complete without a CUDA event, such as when a server returns an empty handle due to an unregistered instance. This prevents crashes and allows for proper error propagation.
  • New Integration Tests: Added a comprehensive test suite (test_server_restart.py) to validate the server restart resilience, covering both protocol-level failure handling and full integration scenarios with automatic re-registration.
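The server-side change in the second highlight can be sketched as follows. This is an illustrative mock, not LMCache's actual server code: the class and method names (`MPServerSketch`, `register`, the `(handle, success)` return shape) are assumptions based on the description above, where an `assert` on registration is replaced by a warning plus a `b"", False` failure return.

```python
import logging

logger = logging.getLogger("lmcache.mp_server")


class MPServerSketch:
    """Illustrative server that degrades gracefully for unregistered instances."""

    def __init__(self):
        # instance_id -> registered KV cache handles
        self.registered_instances = {}

    def register(self, instance_id, kv_handles):
        self.registered_instances[instance_id] = kv_handles

    def store(self, instance_id, key, payload):
        # Before this PR the server asserted registration; now it warns and
        # returns an empty handle plus a failure flag the client can detect.
        if instance_id not in self.registered_instances:
            logger.warning(
                "store() from unregistered instance %s; returning failure",
                instance_id,
            )
            return b"", False
        # ... normal store path elided ...
        return b"event-handle", True
```

On failure the client sees `(b"", False)` and can treat it as a signal that the server restarted and lost its registration state.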


Changelog
  • lmcache/integration/vllm/vllm_multi_process_adapter.py
    • Added a registered boolean flag to track registration status.
    • Refactored KV cache registration logic into a new _do_register method for reusability.
    • Introduced logic in get_finished to detect failed store/retrieve requests and trigger re-registration if necessary.
  • lmcache/v1/multiprocess/futures.py
    • Modified _on_raw_future_complete to handle cases where no CUDA event is returned, setting is_done_ flag.
    • Updated wait and query methods to check is_done_ for futures that complete without a CUDA event.
    • Replaced assert with if checks for self.event_ to prevent errors when no CUDA event is present.
  • lmcache/v1/multiprocess/server.py
    • Changed assert statements in store and retrieve methods to if conditions, returning b"", False and logging a warning if an instance is not registered.
  • tests/v1/multiprocess/test_server_restart.py
    • Added a new test file to cover server restart scenarios.
    • Included a protocol-level test to verify store failure on unregistered instances.
    • Implemented an integration test to confirm successful re-registration and subsequent operations after a server restart.
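The client-side flow summarized in the changelog can be sketched as below. All names (`MultiProcessAdapterSketch`, `_do_register`, `get_finished`, the `(request_id, success)` result shape, and the client's `register_kv_caches` method) follow the descriptions above but are illustrative stand-ins, not the exact adapter API.

```python
class MultiProcessAdapterSketch:
    """Illustrative adapter that re-registers its KV caches on failure."""

    def __init__(self, client):
        self.client = client
        self.registered = False  # tracks registration status, as in the changelog

    def _do_register(self):
        # Reusable registration path: push KV cache handles to the server.
        self.client.register_kv_caches()
        self.registered = True

    def get_finished(self, results):
        # results: (request_id, success) pairs reported back by the server.
        finished = []
        for req_id, success in results:
            if success:
                finished.append(req_id)
            else:
                # A failed store/retrieve suggests the server restarted and
                # lost our registration; mark unregistered and re-register.
                self.registered = False
                self._do_register()
        return finished
```

A restarted server thus only costs one failed request per worker before normal operation resumes.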
Activity
  • The pull request description is boilerplate and does not contain specific activity details.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for LMCache server restarts by making the client-side adapter more resilient. The changes include modifying the server to gracefully handle requests for unregistered instances, updating the client-side futures to handle failure responses, implementing auto-re-registration logic in the vLLM adapter, and adding a comprehensive integration test. The overall approach is solid. I have a couple of suggestions to further improve the robustness of the re-registration logic by adding timeouts and better exception handling to prevent the worker from hanging or crashing if the server is temporarily unavailable. I also have a minor suggestion to improve the clarity of one of the new tests.

Comment thread lmcache/integration/vllm/vllm_multi_process_adapter.py Outdated
Comment thread lmcache/integration/vllm/vllm_multi_process_adapter.py Outdated
Comment thread tests/v1/multiprocess/test_server_restart.py Outdated
Collaborator

@DongDongJu DongDongJu left a comment


Hello @maobaolong, thanks for the work.
Generally, LGTM.
One concern: the need_reregister decision logic is too simple right now, since retrieve can return other kinds of failure too (currently it only returns True/False).
But we can handle that in a follow-up PR.
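The reviewer's follow-up idea could look something like this: return an explicit status instead of a bare boolean, so the client re-registers only on the failure mode that actually indicates a server restart. The enum and function names here are hypothetical, not part of LMCache.

```python
import enum


class RetrieveStatus(enum.Enum):
    """Hypothetical richer result for retrieve, replacing a bare True/False."""
    OK = 0
    NOT_REGISTERED = 1   # server restarted and lost our registration
    KEY_MISSING = 2      # ordinary cache miss; no re-registration needed
    INTERNAL_ERROR = 3   # transient server-side error


def need_reregister(status: RetrieveStatus) -> bool:
    # Only a lost registration should trigger the re-register path.
    return status is RetrieveStatus.NOT_REGISTERED
```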

@maobaolong maobaolong added the full Run comprehensive tests on this PR label Mar 10, 2026
Signed-off-by: baoloongmao <baoloongmao@tencent.com>
Signed-off-by: baoloongmao <baoloongmao@tencent.com>
Signed-off-by: baoloongmao <baoloongmao@tencent.com>
Signed-off-by: baoloongmao <baoloongmao@tencent.com>
@maobaolong maobaolong enabled auto-merge (squash) March 10, 2026 07:41
@maobaolong
Collaborator Author

@DongDongJu Thanks for your review and the good suggestion; I addressed it.
@chunxiaozheng Would you like to take a look at this PR?

Signed-off-by: baoloongmao <baoloongmao@tencent.com>
Signed-off-by: baoloongmao <baoloongmao@tencent.com>
Collaborator

@chunxiaozheng chunxiaozheng left a comment


LGTM!

@maobaolong maobaolong merged commit 55622e8 into LMCache:dev Mar 10, 2026
30 of 34 checks passed
Contributor

@ApostaC ApostaC left a comment


@maobaolong this one actually incurs a few problems... The default value of LMCACHE_REGISTER_TIMEOUT is too small. And right now, this may have some design conflict with #2692

Can we revert this PR temporarily? To support server restart, a better solution is to make the mq_client be able to reconnect.

cc @chunxiaozheng @DongDongJu

Comment on lines +138 to +142
if self.is_done_.is_set():
# Completed without a CUDA event (e.g. server
# returned failure with empty handle).
return True

Contributor

@ApostaC ApostaC Mar 10, 2026


There is a race condition here.

1. Thread A (_on_raw_future_complete): runs self.is_done_.set()
2. Thread B (wait): runs if self.event_: and finds self.event_ is None
3. Thread B (wait): runs if self.is_done_.is_set(): return True
4. Thread A (_on_raw_future_complete): runs self.event_ = torch.cuda.Event.from_ipc_handle(self.device_, event_bytes), which initializes the real CUDA event

This race condition can cause correctness issues or unneeded re-registration, which has undefined behavior.
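One way to close the race described above is to publish the event *before* signaling completion, and to guard both fields with a lock so a reader can never observe is_done_ set while event_ is still being assigned. This is an illustrative sketch under those assumptions, not LMCache's actual CudaFuture; the FakeEvent-style `.query()` call stands in for the CUDA event API.

```python
import threading


class CudaFutureSketch:
    """Illustrative race-free ordering: event_ is published before is_done_."""

    def __init__(self):
        self.event_ = None
        self.is_done_ = threading.Event()
        self._lock = threading.Lock()

    def _on_raw_future_complete(self, event_obj):
        with self._lock:
            # Publish the event first (it may legitimately be None when the
            # server returned a failure with an empty handle), then signal.
            self.event_ = event_obj
            self.is_done_.set()

    def query(self):
        with self._lock:
            if not self.is_done_.is_set():
                return False
            if self.event_ is None:
                # Completed without a CUDA event (e.g. server failure).
                return True
            event = self.event_
        # Safe outside the lock: event_ is assigned exactly once.
        return event.query()
```

Because both writes happen under the lock and in that order, the interleaving from the comment above (done observed, event still None, event assigned later) can no longer occur.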

],
)
future.result()
timeout = float(os.getenv("LMCACHE_REGISTER_TIMEOUT", "5.0"))
Contributor


This default timeout is too short. The LMCache side needs to initialize the CUDA stream and CuPy, which may take ~2 seconds for a single TP worker.
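Until the default is raised, the timeout from the quoted line can be widened through the environment variable it reads. The value 30.0 below is an assumed, generous choice for illustration, not a recommendation from this thread.

```python
import os

# Set a more generous registration timeout before workers read it.
# "30.0" is an assumed value; tune it for your deployment.
os.environ.setdefault("LMCACHE_REGISTER_TIMEOUT", "30.0")

# Mirrors the snippet quoted above: the 5.0 s fallback only applies
# when the variable is unset.
timeout = float(os.getenv("LMCACHE_REGISTER_TIMEOUT", "5.0"))
```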

ApostaC added a commit that referenced this pull request Mar 10, 2026
This reverts commit 55622e8.

Signed-off-by: ApostaC <yihua98@uchicago.edu>
ApostaC added a commit that referenced this pull request Mar 10, 2026
This reverts commit 55622e8.

Signed-off-by: ApostaC <yihua98@uchicago.edu>
shaoxiawjc pushed a commit to shaoxiawjc/LMCache that referenced this pull request Mar 11, 2026
* [MP] Support MP Server restart

Signed-off-by: baoloongmao <baoloongmao@tencent.com>

* Address comment

Signed-off-by: baoloongmao <baoloongmao@tencent.com>

* Address comment from dongdongju

Signed-off-by: baoloongmao <baoloongmao@tencent.com>

* Fix related failed UT.

Signed-off-by: baoloongmao <baoloongmao@tencent.com>

---------

Signed-off-by: baoloongmao <baoloongmao@tencent.com>
Signed-off-by: shaoxiawjc <wjc2800@163.com>
shaoxiawjc pushed a commit to shaoxiawjc/LMCache that referenced this pull request Mar 11, 2026
This reverts commit 55622e8.

Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: shaoxiawjc <wjc2800@163.com>
realAaronWu pushed a commit to realAaronWu/LMCache that referenced this pull request Mar 20, 2026
* [MP] Support MP Server restart

Signed-off-by: baoloongmao <baoloongmao@tencent.com>

* Address comment

Signed-off-by: baoloongmao <baoloongmao@tencent.com>

* Address comment from dongdongju

Signed-off-by: baoloongmao <baoloongmao@tencent.com>

* Fix related failed UT.

Signed-off-by: baoloongmao <baoloongmao@tencent.com>

---------

Signed-off-by: baoloongmao <baoloongmao@tencent.com>
Signed-off-by: Aaron Wu <aaron.wu@dell.com>
realAaronWu pushed a commit to realAaronWu/LMCache that referenced this pull request Mar 20, 2026
This reverts commit 55622e8.

Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: Aaron Wu <aaron.wu@dell.com>
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026
* [MP] Support MP Server restart

Signed-off-by: baoloongmao <baoloongmao@tencent.com>

* Address comment

Signed-off-by: baoloongmao <baoloongmao@tencent.com>

* Address comment from dongdongju

Signed-off-by: baoloongmao <baoloongmao@tencent.com>

* Fix related failed UT.

Signed-off-by: baoloongmao <baoloongmao@tencent.com>

---------

Signed-off-by: baoloongmao <baoloongmao@tencent.com>
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026
This reverts commit 55622e8.

Signed-off-by: ApostaC <yihua98@uchicago.edu>

Labels

full Run comprehensive tests on this PR
