
[MP] Enable layout desc in MP lookup and prefetch #2662

Merged
YaoJiayi merged 4 commits into LMCache:dev from ApostaC:local-dev/lookup-with-layout on Mar 2, 2026

Conversation

@ApostaC
Contributor

@ApostaC ApostaC commented Mar 1, 2026

Summary

  • Pass layout_desc to storage_manager.submit_prefetch_task() in both MPCacheEngine.lookup and BlendEngine.cb_lookup_pre_computed, preparing for future PrefetchController implementation
  • Add gpu_context_meta dict to MPCacheEngine and _cb_gpu_context_meta to BlendEngine to track (model_name, world_size) per registered GPU instance, enabling layout desc derivation during lookup/prefetch
  • Update REGISTER_KV_CACHE and CB_REGISTER_KV_CACHE protocols to carry model_name and world_size
  • Change LOOKUP protocol from list[KeyType] to single KeyType payload
  • Extract get_layout_desc() helper to replace inline layout construction in store()
  • Update vllm_multi_process_adapter.py to pass model_name/world_size during registration and use single-key LOOKUP
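The metadata-tracking pattern the bullets above describe can be sketched as follows. This is a minimal illustration, not LMCache's actual code: `LayoutDesc` stands in for the real `MemoryLayoutDesc`, `EngineSketch` for `MPCacheEngine`, and the real signatures differ.

```python
# Hypothetical sketch of (model_name, world_size) tracking per registered
# GPU instance; real LMCache types and signatures differ.
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class LayoutDesc:
    """Stand-in for lmcache's MemoryLayoutDesc."""
    model_name: str
    world_size: int


class EngineSketch:
    def __init__(self) -> None:
        # Filled on REGISTER_KV_CACHE, cleared on unregister.
        self.gpu_context_meta: dict[str, tuple[str, int]] = {}

    def register_kv_cache(self, instance_id: str, model_name: str,
                          world_size: int) -> None:
        self.gpu_context_meta[instance_id] = (model_name, world_size)

    def unregister_kv_cache(self, instance_id: str) -> None:
        self.gpu_context_meta.pop(instance_id, None)

    def lookup(self, instance_id: str, key: str) -> Optional[LayoutDesc]:
        """Derive the layout desc for the prefetch task, if metadata exists."""
        meta = self.gpu_context_meta.get(instance_id)
        if meta is None:
            return None  # no registered GPU context; no layout hint to pass
        model_name, world_size = meta
        layout_desc = LayoutDesc(model_name, world_size)
        # The real engine would now call something like:
        # storage_manager.submit_prefetch_task(key, layout_desc=layout_desc)
        return layout_desc
```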

Test plan

  • pytest tests/v1/multiprocess/test_mq.py — 10/10 passed
  • pytest tests/v1/multiprocess/test_cache_server.py — 8/8 passed
  • pytest tests/v1/multiprocess/test_blend_server.py — 21/21 passed

Signed-off-by: ApostaC <yihua98@uchicago.edu>
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the caching mechanism by integrating memory layout descriptions into prefetch operations. It achieves this by introducing model-specific metadata tracking for registered GPU instances, allowing for dynamic derivation of memory layouts. The changes also involve updating communication protocols to carry this new metadata and streamlining lookup requests, ensuring a more robust and efficient distributed caching system.

Highlights

  • Layout Description Integration: Enabled the passing of layout_desc to storage_manager.submit_prefetch_task() within MPCacheEngine.lookup and BlendEngine.cb_lookup_pre_computed, preparing for future PrefetchController implementations.
  • Model Metadata Tracking: Introduced gpu_context_meta in MPCacheEngine and _cb_gpu_context_meta in BlendEngine to track (model_name, world_size) for registered GPU instances, facilitating layout description derivation during lookup and prefetch operations.
  • Protocol Updates: Modified REGISTER_KV_CACHE and CB_REGISTER_KV_CACHE protocols to include model_name and world_size in their payloads, and streamlined the LOOKUP protocol to accept a single KeyType instead of a list.
  • Code Refactoring: Extracted a get_layout_desc() helper function to centralize layout description construction, replacing inline logic in store().
  • Adapter Alignment: Updated vllm_multi_process_adapter.py to correctly pass model_name and world_size during KV cache registration and to utilize the new single-key LOOKUP protocol.
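The LOOKUP protocol change highlighted above (list[KeyType] → single KeyType) can be illustrated with a toy payload; these are stand-in dataclasses, not lmcache's real protocol classes.

```python
# Illustrative only: the shape of the LOOKUP payload change. KeyType and
# lookup_all here are stand-ins, not the real lmcache definitions.
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class KeyType:
    chunk_hash: str


# Before: one LOOKUP request carried list[KeyType].
# After: each request carries a single KeyType, so callers that need a
# batch issue one request per key.
def lookup_all(keys: list[KeyType],
               send: Callable[[KeyType], int]) -> list[int]:
    """Client-side loop replacing the old batched LOOKUP."""
    return [send(key) for key in keys]
```

This mirrors the `lookup_all` helper the changelog below says was added to the tests after the batched helper was removed.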


Changelog
  • lmcache/integration/vllm/vllm_multi_process_adapter.py
    • Updated maybe_submit_lookup_request to send a single KeyType for LOOKUP requests.
    • Modified _create_key to set worker_id to None for scheduler adapter keys.
    • Extended register_kv_caches to pass model_name and world_size during REGISTER_KV_CACHE requests.
  • lmcache/v1/distributed/storage_manager.py
    • Added layout_desc parameter to the submit_prefetch_task method.
  • lmcache/v1/multiprocess/blend_server.py
    • Added _cb_gpu_context_meta dictionary to store model name and world size per CB GPU context.
    • Updated cb_register_kv_cache to accept model_name and world_size and store them in metadata.
    • Modified cb_unregister_kv_cache to remove corresponding metadata.
    • Implemented logic in cb_lookup_pre_computed to derive MemoryLayoutDesc using stored metadata and pass it to the storage manager.
  • lmcache/v1/multiprocess/protocols/blend.py
    • Updated CB_REGISTER_KV_CACHE protocol definition to include str (model_name) and int (world_size) in its payload classes.
  • lmcache/v1/multiprocess/protocols/engine.py
    • Updated REGISTER_KV_CACHE protocol definition to include str (model_name) and int (world_size) in its payload classes.
    • Changed LOOKUP protocol definition to accept a single KeyType instead of list[KeyType].
  • lmcache/v1/multiprocess/server.py
    • Added a new helper function get_layout_desc for creating MemoryLayoutDesc objects.
    • Introduced gpu_context_meta dictionary to MPCacheEngine for tracking model name and world size per GPU context.
    • Updated register_kv_cache to accept and store model_name and world_size.
    • Modified unregister_kv_cache to remove corresponding metadata.
    • Refactored store method to use the new get_layout_desc helper.
    • Updated lookup method signature to accept a single IPCCacheEngineKey.
    • Implemented logic in lookup to derive MemoryLayoutDesc using stored metadata and pass it to the storage manager.
  • tests/v1/multiprocess/test_blend_server.py
    • Updated CB_REGISTER_KV_CACHE calls in various tests to pass dummy model_name and world_size.
    • Adjusted LOOKUP request payloads to send a single key instead of a list of keys.
  • tests/v1/multiprocess/test_cache_server.py
    • Removed the lookup_keys helper function.
    • Added a new lookup_all helper function to iterate and perform individual LOOKUP requests.
    • Updated REGISTER_KV_CACHE calls in tests to include dummy model_name and world_size.
    • Replaced direct LOOKUP calls with the new lookup_all helper to align with the single-key protocol.
  • tests/v1/multiprocess/test_mq.py
    • Updated REGISTER_KV_CACHE test payloads to include model_name and world_size.
    • Modified LOOKUP tests to send a single KeyType payload and adjusted expected responses accordingly.
  • tests/v1/multiprocess/test_mq_handler_helpers.py
    • Updated register_kv_cache_handler to accept model_name and world_size parameters.
    • Modified lookup_handler to accept a single KeyType and return a fixed value of 1 for testing purposes.
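The `get_layout_desc()` extraction mentioned in the changelog can be sketched like this; the field list of the real `MemoryLayoutDesc` is not shown in this PR page, so the one below is hypothetical.

```python
# Hypothetical get_layout_desc helper; the real MemoryLayoutDesc in
# lmcache/v1 has more fields than shown here.
from dataclasses import dataclass


@dataclass(frozen=True)
class MemoryLayoutDesc:
    model_name: str
    world_size: int


def get_layout_desc(model_name: str, world_size: int) -> MemoryLayoutDesc:
    # One construction site shared by store() and lookup(), replacing
    # the inline construction that previously lived only in store().
    return MemoryLayoutDesc(model_name=model_name, world_size=world_size)
```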
Activity
  • The pull request was generated with Claude Code.
  • All provided tests (pytest tests/v1/multiprocess/test_mq.py, pytest tests/v1/multiprocess/test_cache_server.py, pytest tests/v1/multiprocess/test_blend_server.py) have passed successfully.

@ApostaC ApostaC requested review from KuntaiDu and YaoJiayi March 1, 2026 02:54
@ApostaC ApostaC added the mp Buildkite trigger for multi-processing mode test label Mar 1, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request effectively enables the use of layout_desc in multiprocess lookup and prefetch operations. It achieves this by incorporating model_name and world_size metadata during GPU context registration, which is then used to derive the memory layout description. The simplification of the LOOKUP protocol to handle a single key is a good cleanup, and all related components, including tests, have been updated accordingly. The changes are logical and well-executed. I have a couple of suggestions to enhance performance and reduce code duplication.

Comment thread lmcache/v1/multiprocess/blend_server.py
Comment thread lmcache/v1/multiprocess/blend_server.py
Comment thread lmcache/v1/multiprocess/server.py
@ApostaC ApostaC added the full Run comprehensive tests on this PR label Mar 1, 2026
kv_caches: KVCache,
model_name: str,
world_size: int,
) -> None:
Contributor


Any rationales behind adding model_name and world_size?

Contributor


Other part (changing key to single instead of multiple) LGTM

Contributor Author


It needs this to figure out how to determine the "layout spec" when allocating the prefetch buffer in L1.
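The point in the reply above can be made concrete with a toy sizing function. All names and numbers below are invented for illustration; the real L1 prefetch-buffer allocation in LMCache is not shown in this PR.

```python
# Hedged illustration: the prefetch buffer size in the L1 tier depends on
# per-model KV bytes and the tensor-parallel world size, which is why the
# server needs (model_name, world_size) at lookup time. The table and
# numbers below are made up for the example.
PER_TOKEN_KV_BYTES = {"demo-7b": 1024}  # hypothetical model table


def prefetch_buffer_bytes(model_name: str, world_size: int,
                          num_tokens: int) -> int:
    # Each tensor-parallel rank holds 1/world_size of the KV cache, so a
    # per-rank prefetch buffer can shrink accordingly.
    total = PER_TOKEN_KV_BYTES[model_name] * num_tokens
    return total // world_size
```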

Contributor

@KuntaiDu KuntaiDu left a comment


LGTM!

@YaoJiayi YaoJiayi enabled auto-merge (squash) March 2, 2026 02:31
@YaoJiayi YaoJiayi merged commit 9c6368c into LMCache:dev Mar 2, 2026
25 checks passed
hlin99 pushed a commit to hlin99/LMCache that referenced this pull request Mar 2, 2026
* Enable layout desc in prefetch

Signed-off-by: ApostaC <yihua98@uchicago.edu>

* fix unit tests

Signed-off-by: ApostaC <yihua98@uchicago.edu>

* remove unused worker id in schedulr adapter

Signed-off-by: ApostaC <yihua98@uchicago.edu>

---------

Signed-off-by: ApostaC <yihua98@uchicago.edu>
Co-authored-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>
oferki pushed a commit to oferki/LMCache that referenced this pull request Mar 3, 2026
oferki pushed a commit to oferki/LMCache that referenced this pull request Mar 3, 2026
mauryaavinash95 pushed a commit to mauryaavinash95/LMCache that referenced this pull request Mar 7, 2026
shaoxiawjc pushed a commit to shaoxiawjc/LMCache that referenced this pull request Mar 11, 2026
realAaronWu pushed a commit to realAaronWu/LMCache that referenced this pull request Mar 20, 2026
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026