
Fix to support MLA multiple TP failed to read issue #2697

Merged
maobaolong merged 3 commits into LMCache:dev from maobaolong:mp_mla_tp on Mar 9, 2026

Fix to support mla multiple tp failed to read issue#2697
maobaolong merged 3 commits intoLMCache:devfrom
maobaolong:mp_mla_tp

Conversation


maobaolong (Collaborator) commented Mar 5, 2026

What this PR does / why we need it:

Fix the issue that occurs when running vLLM and LMCache in MP mode with DeepSeek MLA at TP=2.

[Screenshot attached]
  • After the fix, send the same request twice to verify the store-and-retrieve flow:
[2026-03-05 15:32:00,163] LMCache INFO: Stored 256 tokens in 0.006 seconds (server.py:294:__main__)
(APIServer pid=5070) INFO 03-05-2026-15:32:00.192 [serving.py:1354] Finished request: model=vllm_cpu_offload, request_id=896ea0c7d96c9b31, input_tokens=343, output_tokens=10, cached_tokens=0, total_tokens=353, ttft=0.099s, total_time=0.099s, decode_gen_speed=0.00tokens/s, finish_reason=length

[2026-03-05 15:32:20,788] LMCache INFO: Prefetch request completed (L1+L2): 1/1 prefix hits (1 L1, 0 L2) in 0.0 ms (request_id=-1) (storage_manager.py:357:lmcache.v1.distributed.storage_manager)
[2026-03-05 15:32:20,791] LMCache INFO: Retrieved 256 tokens in 0.001 seconds (server.py:393:__main__)
[2026-03-05 15:32:20,792] LMCache INFO: Retrieved 256 tokens in 0.000 seconds (server.py:393:__main__)

Special notes for your reviewers:

If applicable:

  • this PR contains user-facing changes (docs added)
  • this PR contains unit tests

@gemini-code-assist (Contributor)

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a crucial enhancement to the caching system, enabling robust support for scenarios where multiple consumers (e.g., workers in a tensor parallel setup) need to concurrently access the same cached object. By adding a count parameter to the read reservation and release mechanisms, the system can now correctly manage multiple read locks per object, preventing read failures and ensuring proper resource cleanup in complex distributed environments. This change improves the reliability and functionality of the cache for advanced model architectures.

Highlights

  • Multi-Reader Support: Introduced a count parameter to the reserve_read and finish_read methods in L1Manager to allow acquiring and releasing multiple read locks for a single object key, addressing issues with multiple tensor parallel (TP) workers (see the sketch after this list).
  • Storage Manager Integration: Propagated the count parameter (as num_readers) through StorageManager's submit_prefetch_task and finish_read_prefetched methods to ensure multi-reader functionality is supported at the storage layer.
  • Server-Side Logic Update: Modified the server.py logic to correctly calculate the num_readers based on GPU contexts and pass this information to the storage manager during lookup and lock release operations.
  • Comprehensive Testing: Added new test cases in test_distributed_storage_manager.py and test_l1_manager.py to thoroughly validate the new multi-reader functionality, including scenarios for partial releases and temporary object lifecycle.
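
The highlights above amount to reference counting on each object's read lock. Below is a minimal, self-contained sketch of the idea; reserve_read, finish_read, and count are names from this PR, but the internals (a plain per-key counter behind a mutex) are illustrative only, not LMCache's actual implementation (the review discussion below also later renames count to extra_count with a default of 0).

    import threading
    from collections import defaultdict

    ObjectKey = str  # stand-in for LMCache's real ObjectKey type


    class MiniL1Manager:
        """Illustrative multi-reader read-lock accounting."""

        def __init__(self) -> None:
            self._mutex = threading.Lock()
            self._read_locks: dict[ObjectKey, int] = defaultdict(int)

        def reserve_read(self, keys: list[ObjectKey], count: int = 1) -> None:
            # Acquire `count` read locks per key, one per TP worker that
            # will retrieve the same object.
            with self._mutex:
                for key in keys:
                    self._read_locks[key] += count

        def finish_read(self, keys: list[ObjectKey], count: int = 1) -> None:
            # Release `count` read locks; the entry becomes evictable only
            # once every reader has finished.
            with self._mutex:
                for key in keys:
                    self._read_locks[key] -= count
                    if self._read_locks[key] <= 0:
                        del self._read_locks[key]  # all readers done


    m = MiniL1Manager()
    m.reserve_read(["chunk0"], count=2)  # two TP workers will read chunk0
    m.finish_read(["chunk0"], count=1)   # first worker done: still locked
    m.finish_read(["chunk0"], count=1)   # second worker done: evictable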


Changelog
  • lmcache/v1/distributed/l1_manager.py
    • Added a count parameter with a default value of 1 to the reserve_read method, allowing multiple read locks to be acquired for a key.
    • Modified the reserve_read method to loop count times when acquiring read locks.
    • Added a count parameter with a default value of 1 to the finish_read method, enabling the release of multiple read locks.
    • Modified the finish_read method to loop count times when releasing read locks.
  • lmcache/v1/distributed/storage_manager.py (the parameter flow is sketched after this changelog)
    • Added a count parameter with a default value of 1 to the finish_read_prefetched method and passed it to the underlying _l1_manager.finish_read call.
    • Added a num_readers parameter with a default value of 1 to the submit_prefetch_task method.
    • Passed the num_readers parameter as count to the _l1_manager.reserve_read call within submit_prefetch_task.
    • Passed the num_readers parameter as count to the _l1_manager.finish_read call for skipped keys in submit_prefetch_task.
  • lmcache/v1/multiprocess/server.py
    • Introduced a num_readers variable in the lookup method to count GPU contexts matching the model and world size.
    • Updated the submit_prefetch_task call in lookup to pass the calculated num_readers.
    • Added logic to free_lookup_locks to calculate num_readers based on model_name and world_size from the key.
    • Updated the finish_read_prefetched call in free_lookup_locks to pass the calculated num_readers as count.
  • tests/v1/distributed/test_distributed_storage_manager.py
    • Added a new test class TestStorageManagerMultiReader to cover multi-reader scenarios.
    • Implemented test_prefetch_with_num_readers to verify that submit_prefetch_task acquires the specified number of locks.
    • Implemented test_finish_read_prefetched_partial_count to confirm that partial releases leave locks held.
    • Implemented test_prefetch_skipped_keys_released_with_count to ensure non-prefix L1 hits release locks correctly.
    • Implemented test_num_readers_default_is_one to confirm default behavior matches single-reader functionality.
  • tests/v1/distributed/test_l1_manager.py
    • Added test_reserve_read_with_count_acquires_multiple_locks to verify reserve_read with count acquires multiple locks.
    • Added test_reserve_read_count_default_is_one to confirm default reserve_read behavior.
    • Added test_finish_read_with_count_releases_multiple_locks to verify finish_read with count releases multiple locks.
    • Added test_finish_read_count_partial_release to test partial lock release scenarios.
    • Added test_finish_read_count_deletes_temporary_object to ensure temporary objects are deleted when all locks are released by count.
    • Added test_finish_read_count_temp_not_deleted_until_zero to verify temporary objects survive partial count releases.
    • Added test_multi_reader_lifecycle_with_count to simulate a full multi-reader lifecycle.
    • Added test_temp_object_multi_reader_deletion to test temporary object deletion in multi-reader context.
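
For orientation, the parameter flows server → StorageManager → L1Manager roughly as below. This is a hedged sketch with stub classes and simplified signatures (the real methods take more arguments); only the num_readers/count plumbing mirrors the changelog above.

    class StubL1Manager:
        def __init__(self) -> None:
            self.calls = []

        def reserve_read(self, keys, count=1):
            self.calls.append(("reserve_read", tuple(keys), count))

        def finish_read(self, keys, count=1):
            self.calls.append(("finish_read", tuple(keys), count))


    class StubStorageManager:
        def __init__(self, l1_manager) -> None:
            self._l1_manager = l1_manager

        def submit_prefetch_task(self, keys, num_readers=1):
            # num_readers is forwarded to the L1 layer as `count`
            self._l1_manager.reserve_read(keys, count=num_readers)

        def finish_read_prefetched(self, keys, count=1):
            self._l1_manager.finish_read(keys, count=count)


    l1 = StubL1Manager()
    sm = StubStorageManager(l1)
    sm.submit_prefetch_task(["chunk0"], num_readers=2)  # e.g. MLA with TP=2
    sm.finish_read_prefetched(["chunk0"], count=2)
    assert l1.calls == [
        ("reserve_read", ("chunk0",), 2),
        ("finish_read", ("chunk0",), 2),
    ]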

@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces support for multiple readers on a single cache key by adding a count parameter to lock acquisition and release functions, a crucial feature for supporting tensor parallelism scenarios where multiple workers need to access the same data. However, the current implementation of lock management is susceptible to race conditions and Denial of Service (DoS) attacks. Specifically, the dynamic calculation of locks to release based on GPU contexts can lead to premature cache entry deletion or resource leaks, and the use of unvalidated counts in loops while holding global locks presents a DoS risk. It is recommended to track acquired lock counts per request and implement input validation for lock counts. Additionally, there is a suggestion to refactor some duplicated logic in server.py to improve code maintainability.

maobaolong added the bug (Something isn't working) and mp_mode (Buildkite trigger for multi-processing mode test) labels on Mar 5, 2026
@ApostaC (Contributor) left a comment

The big problem is the num_reader calculation. A minor issue is whether to use count or extra_count.
Otherwise LGTM!

Comment thread lmcache/v1/distributed/l1_manager.py Outdated
def reserve_read(
self,
keys: list[ObjectKey],
count: int = 1,
ApostaC (Contributor):
Can we make this extra_count = 0, so that there won't be any misuse of this interface such as reserve_read(keys, count=0)?
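
A toy illustration of why the suggested default helps: with extra_count = 0, a plain call means "just the caller's lock", and the confusing reserve_read(keys, count=0) (acquire zero locks?) can no longer be written. This is not the real signature, just the shape of the argument.

    def reserve_read(keys: list[str], extra_count: int = 0) -> int:
        # The caller always holds one read lock; extra_count adds one lock
        # for each additional TP worker reading the same object.
        total_locks = 1 + extra_count
        return total_locks


    assert reserve_read(["chunk0"]) == 1                 # single reader
    assert reserve_read(["chunk0"], extra_count=1) == 2  # MLA with TP=2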

maobaolong (Collaborator, Author):

Thanks for this suggestion, it is better.

Comment thread lmcache/v1/distributed/l1_manager.py Outdated
def finish_read(
self,
keys: list[ObjectKey],
count: int = 1,
ApostaC (Contributor):

Do we need to have a configurable count for finish_read?

ApostaC (Contributor):

Oh, I see the need! Ignore my comment above.
But I would like to have extra_count, similar to the above.

Comment thread lmcache/v1/multiprocess/server.py Outdated
Comment on lines +422 to +427
num_readers += 1
if layout_desc is None:
layout_desc = get_layout_desc(
self.gpu_contexts[gpu_id],
self.chunk_size,
)
ApostaC (Contributor):

We cannot count num_readers this way here. For example, if LMCache is connected to 4 vLLM instances, each running TP=2, the current implementation will count num_readers as 8.

The best way is to update the protocol and pass the TP size when calling lookup. @maobaolong It may need a PR on the vLLM side, modifying lmcache_mp_connector.py (the LMCacheMPSchedulerAdapter needs to be initialized with the TP information). I can help review the PR there; it should be a very simple one.


maobaolong (Collaborator, Author):

@ApostaC Is this failure related to this PR?

It is from https://buildkite.com/lmcache/k3-correctness-test/builds/66

[INFO] Using GPU(s): 0
[INFO] Converting ShareGPT dataset to OpenAI format...
Converted /root/correctness/.ShareGPT_V3_unfiltered_cleaned_split.json -> ./shareGPT_dataset.json
[INFO] Starting BASE vLLM server on port 8000...
[INFO] Waiting for Base server readiness...
Waiting for vLLM on port 8000 (timeout=180s)...
vLLM failed to start on port 8000 within 180s
[ERROR] Base vLLM failed to start
[INFO] Stopping vLLM process (PID: 1011)
[INFO] Collecting logs into build_019cbe12-8c8c-4d0d-87bc-c52db6134f5c.log
usage: vllm serve [model_tag] [options]
vllm serve: error: argument --compilation-config/-cc: 1 validation error for CompilationConfig
level
  Unexpected keyword argument [type=unexpected_keyword_argument, input_value=0, input_type=int]
    For further information visit https://errors.pydantic.dev/2.12/v/unexpected_keyword_argument


ApostaC commented Mar 5, 2026

@maobaolong The K3 test is the new environment we just set up. The failure should not block merging this PR.

@ApostaC (Contributor) left a comment

Otherwise LGTM!

start=start,
end=end,
request_id=request_id,
tp_size=self.tp_size,
ApostaC (Contributor):

I would suggest we have the tp_size in the lookup interface instead of putting it here. The reason is that the IPCCacheEngineKey already has world_size, and having a tp_size here may confuse people.

My proposal:

  1. Edit lmcache/v1/multiprocess/protocols/engine.py and add the tp_size to the argument list of LOOKUP operation.
  2. In the adapter (this file), pass the tp_size when calling send_lmcache_request for LOOKUP and FREE_LOOKUP_LOCKS. E.g.,
        send_lmcache_request(
            self.mq_client,
            RequestType.LOOKUP,
            [key, self.tp_size],
        )
  3. In server.py (lookup and free_lookup_locks), check the tp_size argument and the MLA flag, and compute the extra_count accordingly (see the sketch below).
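
A hedged sketch of step 3, assuming the server can learn an MLA flag from the layout (is_mla below is a hypothetical name; the merged code, quoted later in this thread, instead infers the situation from tp_size vs world_size):

    def compute_extra_count(tp_size: int, is_mla: bool) -> int:
        # Only MLA shares one object across TP workers; non-MLA models keep
        # per-rank objects and need just the caller's single lock.
        if is_mla and tp_size > 1:
            return tp_size - 1
        return 0


    assert compute_extra_count(tp_size=2, is_mla=True) == 1
    assert compute_extra_count(tp_size=2, is_mla=False) == 0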

maobaolong (Collaborator, Author) commented Mar 6, 2026:

Great idea, addressed.

But I did nothing for the free_lookup_locks operation, since we use lookup_lock_counts to get the extra readers rather than the argument.

MAX_READ_LOCK_COUNT = 128


def _validate_extra_count(extra_count: int) -> int:
ApostaC (Contributor):

👍 I like this
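
The quoted diff cuts off before the body. A hedged completion of the validator might look like the following; whether the real code clamps or raises is an assumption.

    MAX_READ_LOCK_COUNT = 128  # bound guarding against DoS via huge counts


    def _validate_extra_count(extra_count: int) -> int:
        # Reject negatives and cap the count so a malformed request cannot
        # hold the global lock while looping an unbounded number of times.
        if extra_count < 0:
            raise ValueError(f"extra_count must be >= 0, got {extra_count}")
        return min(extra_count, MAX_READ_LOCK_COUNT)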

Comment on lines +230 to +231
for _ in range(total):
entry.read_lock.lock()
ApostaC (Contributor):

Can we have a TODO/an issue here? It would be more performant to support a count argument in TTLLock.lock() and TTLLock.unlock(). Since TTLLock is implemented in C++ with std::atomic, a Python for loop here may impact performance.
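
A Python stand-in for the suggested count-aware API (the real TTLLock is a C++ extension built on std::atomic, so names and internals below are purely illustrative): one bulk increment replaces `count` Python-level lock() calls.

    import threading


    class CountingLock:
        """Stand-in for a hypothetical TTLLock.lock(count)/unlock(count)."""

        def __init__(self) -> None:
            self._mutex = threading.Lock()
            self._readers = 0

        def lock(self, count: int = 1) -> None:
            # Single bulk increment; in C++ this would be one std::atomic
            # fetch_add instead of `count` loop iterations.
            with self._mutex:
                self._readers += count

        def unlock(self, count: int = 1) -> None:
            with self._mutex:
                self._readers -= count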

maobaolong (Collaborator, Author):

TODO comment added.

Comment thread lmcache/v1/multiprocess/server.py Outdated
# tp > world_size → extra_readers = tp - 1 (all
# TP workers retrieve the same ObjectKey).
tp = key.tp_size if key.tp_size > 1 else world_size
extra_readers = tp - 1 if tp > world_size else 0
ApostaC (Contributor):

The TP problem only occurs for MLA; different TP workers share the same object only in MLA.
For non-MLA models, different TP workers have their own "object" and therefore only need to lock once.

ApostaC (Contributor):

One thing we may need to double-check: there is some code on the vLLM side that passes "world_size = 1" for MLA models even with TP, and all the LMCache logic relies on this for correct MLA + TP handling (lookup, store, and retrieve). For example, see the implementation of ipc_keys_to_object_keys, which explodes the keys based on world_size and expects it to be 1 in the MLA case.

We probably want better semantics for world_size in the future.
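
A toy illustration of the world_size semantics described above; explode_keys is a hypothetical stand-in for ipc_keys_to_object_keys, and the key format is invented. For MLA, vLLM reports world_size=1, so one shared key comes out; for non-MLA TP=2, one key per rank.

    def explode_keys(key: str, world_size: int) -> list[str]:
        # One object key per TP rank; MLA collapses to a single shared key
        # because vLLM reports world_size=1 for it.
        return [f"{key}@rank{r}" for r in range(world_size)]


    assert explode_keys("chunk0", 1) == ["chunk0@rank0"]  # MLA: shared object
    assert explode_keys("chunk0", 2) == ["chunk0@rank0", "chunk0@rank1"]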

maobaolong (Collaborator, Author):

@ApostaC Updated the comments.

maobaolong (Collaborator, Author):

@ApostaC Thanks for the previous review; addressed your suggestions, PTAL.

maobaolong closed this Mar 6, 2026
maobaolong reopened this Mar 6, 2026
@ApostaC (Contributor) left a comment

Otherwise LGTM, as we discussed offline. Giving an approval here.

Comment on lines 505 to 508
def free_lookup_locks(
self,
key: IPCCacheEngineKey,
) -> None:
ApostaC (Contributor):

Let's have tp_size passed here so we don't need lookup_lock_counts

maobaolong (Collaborator, Author):

done

@ApostaC (Contributor) left a comment

Sorry for the back and forth, but can we do the following:

  • Unify the naming of extra_readers (used in finish_read_prefetched) vs extra_count (used in other functions). Let's have extra_count in all the places.

Another thing is that there are unit tests using the lookup and free_lookup_locks interfaces. We need to update them, otherwise the UTs will fail (a hypothetical call-site update is sketched after this list). These include:

  • test_mq_handler_helpers.py
  • test_mq.py
  • test_cache_server.py
  • test_blend_server.py
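
Hypothetically, each affected call site changes along the lines below, mirroring the protocol example earlier in the thread; the stubs and exact payload shape are assumptions.

    from enum import Enum, auto


    class RequestType(Enum):  # stand-in for the real protocol enum
        LOOKUP = auto()
        FREE_LOOKUP_LOCKS = auto()


    def send_lmcache_request(client, request_type, payload):
        client.append((request_type, payload))  # stub transport


    mq_client: list = []
    key, tp_size = "chunk0", 2
    # Updated call site: tp_size now rides along with the key so the
    # server can size the read-lock count.
    send_lmcache_request(mq_client, RequestType.LOOKUP, [key, tp_size])
    assert mq_client == [(RequestType.LOOKUP, ["chunk0", 2])]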

maobaolong added the full (Run comprehensive tests on this PR) label on Mar 9, 2026
tp_size: int,
world_size: int,
) -> int:
"""Compute extra count for MLA multi-reader locking.
Collaborator:
Thanks for the comments!
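
Piecing together the quoted fragment and the earlier diff, the helper plausibly reads as below; the function name and exact surrounding code are partly assumed, and validation against MAX_READ_LOCK_COUNT happens separately in the merged code.

    def _compute_extra_count(tp_size: int, world_size: int) -> int:
        """Compute extra count for MLA multi-reader locking.

        For MLA, vLLM reports world_size=1 while tp_size may be larger, so
        all TP workers retrieve the same ObjectKey and need tp - 1 extra
        read locks; for non-MLA models tp == world_size and no extra locks
        are needed.
        """
        tp = tp_size if tp_size > 1 else world_size
        return tp - 1 if tp > world_size else 0


    assert _compute_extra_count(tp_size=2, world_size=1) == 1  # MLA, TP=2
    assert _compute_extra_count(tp_size=2, world_size=2) == 0  # non-MLA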

)
break

layout_desc = self._find_layout_desc(model_name, world_size)
Collaborator:

This is only a refactor.

maobaolong (Collaborator, Author):

yeah

maobaolong enabled auto-merge (squash) on March 9, 2026, 06:11
@chunxiaozheng (Collaborator) left a comment

After discussing offline, LGTM!

maobaolong merged commit a973aa8 into LMCache:dev on Mar 9, 2026 (33 of 38 checks passed)
maobaolong (Collaborator, Author):

@ApostaC @chunxiaozheng Thanks for your review!

shaoxiawjc pushed a commit to shaoxiawjc/LMCache that referenced this pull request on Mar 11, 2026:

* Fix to support mla multiple tp failed to read issue
* Fix style
* Fix ut and renaming

Signed-off-by: baoloongmao <baoloongmao@tencent.com>
Signed-off-by: shaoxiawjc <wjc2800@163.com>

realAaronWu pushed a commit to realAaronWu/LMCache that referenced this pull request on Mar 20, 2026 (same commit message).
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request on Apr 2, 2026 (same commit message).