
feat(kv_cache): enable asymmetric store/retrieve storages in PD backend#2509

Merged
deng451e merged 28 commits into LMCache:dev from hlin99:PD_save_decode
Mar 16, 2026

Conversation

@hlin99
Contributor

@hlin99 hlin99 commented Jan 29, 2026

Remove the restriction that prevented using save_decode_cache and remote_backend simultaneously in Prefill-Decode (PD) separation scenarios.

This change introduces pd_retrieve_locations and pd_store_location parameters to decouple the KV cache retrieval and storage logic. This enables an asymmetric cache flow:

  1. Prefill nodes transmit KV cache to Decode nodes via the PDBackend.
  2. Decode nodes write back their generated KV cache to a remote backend for subsequent prefill reuse.
  3. In multi-turn dialogue scenarios, subsequent prefill requests retrieve historical KV cache from the remote backend, significantly increasing Prefix Cache hit rates and reducing TTFT.

This decoupling provides greater flexibility for cross-instance cache management and improves overall pipeline efficiency in distributed inference.


Workflow:

  1. Prefill -> Decode (PDBackend): Initial KV transfer for the current turn.
  2. Decode -> Remote (Store): Decode saves updated KV to NFS for persistence.
  3. Remote -> Prefill (Retrieve): Next-turn prefill pulls from Remote, drastically increasing Prefix Cache hit rate for multi-turn dialogues.


[ Compute Layer ]
    +----------------------+                 +------------------+
    |  Prefill Node        | ===============>|  Decode Node     |
    | (Hit-Remote & GenKV) |  (1) PDBackend  | (Hit-PD & GenKV) |
    +-------^--------------+                 +-------+----------+
            |                                   |
            :                                   :
------------|-----------------------------------|------------
[ Storage Layer ]                               |
            |                                   | (2) pd_store_location
            | (3) pd_retrieve_locations         |     (Decode -> Pool)
            |     (Pool -> Prefill)             |
            |                                   v
    +-------+--------------------------------------------+
    |             Distributed Storage Pool               |
    |   [Node A]    [Node B]    [Node C]    [Node D]     |
    |   <=======  (Object Storage / NFS / DFS)  =======> |
    +----------------------------------------------------+

Signed-off-by: Tony Lin <tony.lin@intel.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello @hlin99, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Prefill-Decode (PD) backend's flexibility by introducing new configuration options that allow for asymmetric management of the Key-Value (KV) cache. By decoupling the storage and retrieval mechanisms, the system can now leverage remote backends more effectively, particularly for multi-turn dialogue scenarios. This change is designed to improve overall pipeline efficiency, boost Prefix Cache hit rates, and reduce the Time To First Token (TTFT) by facilitating the reuse of historical KV cache across different stages and instances.

Highlights

  • Asymmetric KV Cache Management: Introduced pd_retrieve_locations and pd_store_location configuration parameters to decouple KV cache retrieval and storage logic in the Prefill-Decode (PD) backend.
  • Removed Restrictions: Eliminated previous limitations that prevented the simultaneous use of save_decode_cache and remote_backend in PD separation scenarios.
  • Improved Multi-Turn Dialogue Efficiency: Enabled a workflow where Decode nodes can write generated KV cache to a remote backend, and subsequent Prefill requests can retrieve historical KV cache from this remote backend, significantly increasing Prefix Cache hit rates and reducing Time To First Token (TTFT).


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a significant feature to enable asymmetric save/remote storage in the PD backend, decoupling KV cache retrieval and storage logic. The changes in lmcache/v1/config.py correctly add the new pd_retrieve_locations and pd_store_location configurations and remove previous assertions that restricted this functionality. In lmcache/v1/storage_backend/storage_manager.py, these new configurations are integrated to serve as default locations for various cache operations. The implementation is logical and well-aligned with the detailed description. I have one suggestion for improvement in storage_manager.py to handle configuration initialization more robustly. Overall, this is a solid contribution that enhances flexibility for distributed inference.

Comment thread lmcache/v1/storage_backend/storage_manager.py Outdated
Signed-off-by: Tony Lin <tony.lin@intel.com>
@hlin99 hlin99 changed the title from feat(kv_cache): enable asymmetric save/remote storage in PD backend to feat(kv_cache): enable asymmetric save/restore storage in PD backend Jan 29, 2026
@hlin99 hlin99 changed the title from feat(kv_cache): enable asymmetric save/restore storage in PD backend to feat(kv_cache): enable asymmetric save/retrieve storages in PD backend Jan 29, 2026
@hlin99 hlin99 changed the title from feat(kv_cache): enable asymmetric save/retrieve storages in PD backend to feat(kv_cache): enable asymmetric store/retrieve storages in PD backend Jan 29, 2026
@hlin99
Contributor Author

hlin99 commented Feb 3, 2026

Some additional info from multi-round chat (screenshots omitted):

logs:

1st round

[APP]
App INFO: current_context_ids len = 0, new_ids len = 13, send_ids len = 13
[Prefill]
No hit - LMCache INFO: Reqid: cmpl-451add6f-2193-4d44-a147-7752745a6483-0, Total tokens 13, LMCache hit tokens: 0, need to load: 0 (vllm_v1_adapter.py:1619:lmcache.integration.vllm.vllm_v1_adapter)
Store KV in PD backends - LMCache INFO: Stored 13 out of total 13 tokens. size: 0.0084 gb, cost 192.3073 ms, throughput: 0.0436 GB/s; offload_time: 29.0620 ms, put_time: 163.2453 ms (cache_engine.py:459:lmcache.v1.cache_engine)
[Decode]
Hit all tokens from PD backends - LMCache INFO: Retrieved 13 out of 13 required tokens (from 13 total tokens). size: 0.0084 gb, cost 22.9585 ms, throughput: 0.3649 GB/s; (cache_engine.py:755:lmcache.v1.cache_engine)
Processing and store KV in remote backends - LMCache INFO: Storing KV cache for 256 out of 1792 tokens (skip_leading_tokens=1536) for request cmpl-451add6f-2193-4d44-a147-7752745a6483-0 (vllm_v1_adapter.py:1385:lmcache.integration.vllm.vllm_v1_adapter)

2nd round
[APP]
App INFO: current_context_ids len = 2033, new_ids len = 13, send_ids len = 2046
[Prefill]
Hit 1792 tokens@256 chunk size boundary - LMCache INFO: Retrieved 1792 out of 1792 required tokens (from 1792 total tokens). size: 0.0000 gb, cost 68.0281 ms, throughput: 0.0000 GB/s; (cache_engine.py:755:lmcache.v1.cache_engine)
Store KV in PD backends - LMCache INFO: Stored 2046 out of total 2046 tokens. size: 0.0670 gb, cost 188.3272 ms, throughput: 0.3559 GB/s; offload_time: 172.5690 ms, put_time: 15.7581 ms (cache_engine.py:459:lmcache.v1.cache_engine)
[Decode]
Hit all tokens from PD backends - LMCache INFO: Reqid: cmpl-4eaf4805-58c9-451f-9b48-a1dd4c4b3b6c-0, Total tokens 2046, LMCache hit tokens: 2046, need to load: 125 (vllm_v1_adapter.py:1619:lmcache.integration.vllm.vllm_v1_adapter)
Processing and store KV in remote backends - LMCache INFO: Storing KV cache for 256 out of 3840 tokens (skip_leading_tokens=3584) for request cmpl-4eaf4805-58c9-451f-9b48-a1dd4c4b3b6c-0 (vllm_v1_adapter.py:1385:lmcache.integration.vllm.vllm_v1_adapter)

3rd round

[APP]
App INFO: current_context_ids len = 4066, new_ids len = 13, send_ids len = 4079
[Prefill]
Hit 3840 tokens@256 chunk size boundary - LMCache INFO: Reqid: cmpl-4deb72bb-7cb5-4757-993e-26086b0c808f-0, Total tokens 4079, LMCache hit tokens: 3840, need to load: 3840 (vllm_v1_adapter.py:1619:lmcache.integration.vllm.vllm_v1_adapter)
Store KV in PD backends - LMCache INFO: Stored 4079 out of total 4079 tokens. size: 0.1340 gb, cost 336.0984 ms, throughput: 0.3988 GB/s; offload_time: 317.9360 ms, put_time: 18.1624 ms (cache_engine.py:459:lmcache.v1.cache_engine)
[Decode]
Hit all tokens from PD backends - LMCache INFO: Reqid: cmpl-4deb72bb-7cb5-4757-993e-26086b0c808f-0, Total tokens 4079, LMCache hit tokens: 4079, need to load: 110 (vllm_v1_adapter.py:1619:lmcache.integration.vllm.vllm_v1_adapter)
Processing and store KV in remote backends - Storing KV cache for 256 out of 5888 tokens (skip_leading_tokens=5632) for request cmpl-4deb72bb-7cb5-4757-993e-26086b0c808f-0 (vllm_v1_adapter.py:1385:lmcache.integration.vllm.vllm_v1_adapter)

4th round

[APP]
App INFO: current_context_ids len = 6099, new_ids len = 13, send_ids len = 6112
[Prefill]
Hit 5888 tokens@256 chunk size boundary - LMCache INFO: Reqid: cmpl-928ef16e-6a9d-409e-978d-44907e50c3d1-0, Total tokens 6112, LMCache hit tokens: 5888, need to load: 5888 (vllm_v1_adapter.py:1619:lmcache.integration.vllm.vllm_v1_adapter)
Store KV in PD backends - LMCache INFO: Stored 6112 out of total 6112 tokens. size: 0.2010 gb, cost 475.6424 ms, throughput: 0.4227 GB/s; offload_time: 455.6727 ms, put_time: 19.9696 ms (cache_engine.py:459:lmcache.v1.cache_engine)
[Decode]
Hit all tokens from PD backends - LMCache INFO: Reqid: cmpl-928ef16e-6a9d-409e-978d-44907e50c3d1-0, Total tokens 6112, LMCache hit tokens: 6112, need to load: 95 (vllm_v1_adapter.py:1619:lmcache.integration.vllm.vllm_v1_adapter)
Processing and store KV in remote backends - Storing KV cache for 256 out of 7936 tokens (skip_leading_tokens=7680) for request cmpl-928ef16e-6a9d-409e-978d-44907e50c3d1-0 (vllm_v1_adapter.py:1385:lmcache.integration.vllm.vllm_v1_adapter)
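The prefill hit counts in these logs line up exactly with the 256-token chunk boundary mentioned above. A quick sanity check of that arithmetic, assuming only fully filled chunks can be re-hit (which is what the logged numbers show):

```python
# With a chunk size of 256, only fully populated chunks are re-hit, so a
# context of N tokens yields a prefix hit of floor(N / 256) * 256 tokens.

CHUNK_SIZE = 256

def expected_hit(total_tokens: int, chunk_size: int = CHUNK_SIZE) -> int:
    return (total_tokens // chunk_size) * chunk_size

# (total tokens, logged prefill hit) from the 2nd-4th rounds above:
for total, logged_hit in [(2046, 1792), (4079, 3840), (6112, 5888)]:
    assert expected_hit(total) == logged_hit
```

This also explains why each decode round stores exactly 256 tokens to the remote backend: only the newly completed chunk past `skip_leading_tokens` is flushed.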

@sammshen sammshen added the full Run comprehensive tests on this PR label Feb 5, 2026
@sammshen
Contributor

sammshen commented Feb 5, 2026

did you turn SAVE_DECODE_CACHE on for the decoder? that's a prerequisite for this solution to be helpful

@sammshen
Contributor

sammshen commented Feb 5, 2026

can you run a minimal test (using the same request twice with local_cpu: false) showing prefill is retrieving decoded tokens from remote?

@hlin99
Contributor Author

hlin99 commented Feb 5, 2026

did you turn SAVE_DECODE_CACHE on for the decoder? that's a prerequisite for this solution to be helpful

Yes, SAVE_DECODE_CACHE is turned on. I also added a config example in the PR; please refer to https://github.com/LMCache/LMCache/pull/2509/changes#diff-071ef93157304173f68c621a88bf5a14394f9bd60c54bc502d061e2aee87b0dd

The only config that is not in the PR: the decode kv_role needs to be kv_both, rather than the current kv_consumer.

The PD backend is responsible for P->D. As long as another backend can be used for D->P, this solution will work. The choice of that other backend depends on the HW hierarchy, but to keep it simple, a remote backend with a high-speed network and a synchronization mechanism (if it's distributed storage) is definitely a good choice.
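To make the setup concrete, here is a hypothetical sketch of the two sides' configs. The key names (`enable_pd`, `kv_role`, `save_decode_cache`, `remote_url`, `pd_store_location`, `pd_retrieve_locations`) are assumptions pieced together from this thread, not a verified LMCache config schema; the config example linked in the PR is the authoritative version.

```python
# Hypothetical prefill/decode config sketch for the asymmetric flow.
# All key names are assumptions from this discussion, not verified
# against the actual LMCache config schema.

prefill_config = {
    "enable_pd": True,
    "kv_role": "kv_producer",              # prefill produces KV for decode
    "remote_url": "<remote-backend-url>",  # placeholder, not a real endpoint
    "pd_retrieve_locations": ["PDBackend", "RemoteBackend"],
}

decode_config = {
    "enable_pd": True,
    "kv_role": "kv_both",        # per the thread: kv_both, not kv_consumer
    "save_decode_cache": True,   # prerequisite called out by the reviewers
    "remote_url": "<remote-backend-url>",
    "pd_store_location": "RemoteBackend",  # decode writes KV back to remote
}
```

The decode side carries the two settings this thread calls out as prerequisites: `save_decode_cache` on, and `kv_role` set to `kv_both` instead of `kv_consumer`.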

@hlin99
Contributor Author

hlin99 commented Feb 6, 2026

can you run a minimal test (using the same request twice with local_cpu: false) showing prefill is retrieving decoded tokens from remote?

Hi Samuel,

I ran a similar test and collected logs for a multi-round conversation with LocalCPUBackend enabled.
Below are the observations based on log analysis. Let me know if this data is sufficient.
If you’d also like to verify the CPU-off case, I can re-run it.


Environment

  • Local CPU backend: ON
  • Scenario: multi-round conversation
  • Focus: prefill / decode cache behavior across rounds

Round 1 (cold start)

  • New conversation, no prefill cache hit
  • 14 tokens generated and stored into PDBackend (then propagated to other backends)

Summary:

  • Prefill hit: 0
  • Tokens stored: 14
  • PDBackend populated after processing

Prefill logs:
LMCache INFO: <batched_contains> hit_chunks: 0, backend_name: RemoteBackend, chunk_hashes: [-4977671964338549035] (storage_manager.py:872:lmcache.v1.storage_backend.storage_manager)
(EngineCore_DP0 pid=115197) [2026-02-06 04:18:17,530] LMCache INFO: Reqid: cmpl-713f358a-f394-45fc-b129-b50bd9d0e92d-0, Total tokens 14, LMCache hit tokens: 0, need to load: 0 (vllm_v1_adapter.py:1612:lmcache.integration.vllm.vllm_v1_adapter)
LMCache INFO: <batched_put> backend_name: PDBackend, chunk_hashes: [-4977671964338549035] (storage_manager.py:413:lmcache.v1.storage_backend.storage_manager)
LMCache INFO: <batched_put> backend_name: LocalCPUBackend, chunk_hashes: [-4977671964338549035] (storage_manager.py:413:lmcache.v1.storage_backend.storage_manager)
LMCache INFO: <batched_put> backend_name: RemoteBackend, chunk_hashes: [-4977671964338549035] (storage_manager.py:413:lmcache.v1.storage_backend.storage_manager)
LMCache INFO: Stored 14 out of total 14 tokens. size: 0.0084 gb, cost 172.1476 ms, throughput: 0.0487 GB/s; offload_time: 6.1441 ms, put_time: 166.0035 ms (cache_engine.py:453:lmcache.v1.cache_engine)

Decode:

  • Retrieved 14 tokens from PDBackend
  • Generated tokens stored to remote backends

Decode logs:
LMCache INFO: <batched_contains> hit_chunks: 1, backend_name: PDBackend, chunk_hashes: [-4977671964338549035] (storage_manager.py:872:lmcache.v1.storage_backend.storage_manager)
LMCache INFO: Retrieved 14 out of 14 required tokens (from 14 total tokens). size: 0.0000 gb, cost 28.6021 ms, throughput: 0.0000 GB/s; (cache_engine.py:749:lmcache.v1.cache_engine)


Round 2

  • Prefill hits 1280 / 1424 tokens
  • All hits come from RemoteBackend
  • Hit chunks are migrated to LocalCPUBackend per LMCache logic

Summary:

  • Prefill hit: 1280 / 1424
  • Source backend: RemoteBackend
  • Action: RemoteBackend → LocalCPUBackend migration

Prefill logs:
LMCache INFO: <batched_contains> hit_chunks: 0, backend_name: LocalCPUBackend, chunk_hashes: [2758913087439748001, 2012401185824684720, -199437085030961, -128674787004829214, 4166364048967685432, -229250364596124979] (storage_manager.py:872:lmcache.v1.storage_backend.storage_manager)
LMCache INFO: <batched_contains> hit_chunks: 5, backend_name: RemoteBackend, chunk_hashes: [2758913087439748001, 2012401185824684720, -199437085030961, -128674787004829214, 4166364048967685432, -229250364596124979] (storage_manager.py:872:lmcache.v1.storage_backend.storage_manager)
LMCache INFO: Reqid: cmpl-f2cf9645-42b8-49d5-919d-2a4429a4a44a-0, Total tokens 1424, LMCache hit tokens: 1280, need to load: 1280 (vllm_v1_adapter.py:1612:lmcache.integration.vllm.vllm_v1_adapter)
LMCache INFO: <batched_get> backend_name: RemoteBackend, chunk_hashes: [2758913087439748001, 2012401185824684720, -199437085030961, -128674787004829214, 4166364048967685432] (storage_manager.py:495:lmcache.v1.storage_backend.storage_manager)
LMCache INFO: Storing 5 objects from RemoteBackend to LocalCPUBackend (storage_manager.py:503:lmcache.v1.storage_backend.storage_manager)


Round 3

  • Prefill hits 1536 / 1805 tokens
    • 1790 decoded tokens
    • 15 new prompt tokens
  • Hits come from LocalCPUBackend and RemoteBackend
  • Remaining remote chunks are migrated to LocalCPUBackend

Summary:

  • Prefill hit: 1536 / 1805
  • Source backends: LocalCPUBackend + RemoteBackend
  • Action: Final migration to LocalCPUBackend

Prefill logs:
LMCache INFO: <batched_contains> hit_chunks: 5, backend_name: LocalCPUBackend, chunk_hashes: [2758913087439748001, 2012401185824684720, -199437085030961, -128674787004829214, 4166364048967685432, 1669638256940260353, 1673383803642939430, 3021582200251864517] (storage_manager.py:872:lmcache.v1.storage_backend.storage_manager)
LMCache INFO: <batched_contains> hit_chunks: 1, backend_name: RemoteBackend, chunk_hashes: [1669638256940260353, 1673383803642939430, 3021582200251864517] (storage_manager.py:872:lmcache.v1.storage_backend.storage_manager)
LMCache INFO: Reqid: cmpl-c9940446-c42f-449e-baa0-8e0e97e3daeb-0, Total tokens 1805, LMCache hit tokens: 1536, need to load: 1536 (vllm_v1_adapter.py:1612:lmcache.integration.vllm.vllm_v1_adapter)
LMCache INFO: <batched_get> backend_name: LocalCPUBackend, chunk_hashes: [2758913087439748001, 2012401185824684720, -199437085030961, -128674787004829214, 4166364048967685432] (storage_manager.py:495:lmcache.v1.storage_backend.storage_manager)
LMCache INFO: <batched_get> backend_name: RemoteBackend, chunk_hashes: [1669638256940260353] (storage_manager.py:495:lmcache.v1.storage_backend.storage_manager)
LMCache INFO: Storing 1 objects from RemoteBackend to LocalCPUBackend (storage_manager.py:503:lmcache.v1.storage_backend.storage_manager)
LMCache INFO: Retrieved 1536 out of 1536 required tokens (from 1536 total tokens). size: 0.0000 gb, cost 35.6410 ms, throughput: 0.0000 GB/s; (cache_engine.py:749:lmcache.v1.cache_engine)


Conclusion

  • Cache warm-up behaves as expected across conversation rounds
  • Tokens correctly migrate from D to P via RemoteBackend
  • Prefill hit ratio increases consistently in later rounds

@hlin99
Contributor Author

hlin99 commented Feb 6, 2026

Hi @sammshen, I tried to turn off the CPU backend and leave only the remote backend as you suggested, but eventually I understood that the CPU backend must be on together with the remote backend; force-turning CPU off hits a separate issue (#2556), which I fixed in #2557.

I think a separate PR for a different issue is preferred, but if you would like to merge it into this one, let me know.

Thanks a lot for your time reviewing PRs!

@hlin99
Contributor Author

hlin99 commented Feb 13, 2026

hi @sammshen, thanks for your time reviewing. Are there any other questions or concerns I need to address?

@sammshen
Contributor

sammshen commented Feb 13, 2026

@feixiangpeng PTAL

@deng451e
Collaborator

Hi @hlin99 ,

Thank you for the proposal — I appreciate the idea and the effort behind it. I do have a few concerns regarding the current design and its impact on end-to-end performance.
While this approach improves cache hit rates on the prefill nodes for multi-round conversations, it does not necessarily translate into overall performance gains.

  1. Additional KV traversal overhead
    Storing KV (by the decoder) and retrieving it (by the prefiller) through a remote storage backend introduces extra latency, as it adds additional KV cache traversal hops in the critical path.
  2. Increasing memory-boundedness with longer histories
    As the conversation history grows, the ratio of new KV to compute versus reused KV cache decreases, making these prefill workloads increasingly memory-bound. Batching such memory-bound prefill tasks together with compute-bound prefill tasks undermines the original goal of PD disaggregation.

In a distributed serving setup, it may be more effective to adaptively dispatch these memory-bound prefill requests directly to decoder nodes, where they can be better batched with similar workloads and where cache hits can be populated locally without remote traversal overhead.

This could preserve both cache efficiency and the intended performance benefits of disaggregated execution.

@hlin99
Contributor Author

hlin99 commented Feb 14, 2026

hi @deng451e, thanks for the comments. Let me address your concerns:

  1. Agreed that storing KV cache on decode does introduce latency; however, considering that the decode phase always takes seconds or minutes to complete (depending on the scenario), the millisecond-scale cost of storing KV is relatively trivial. Prefill only needs to store the additional KV on top of decode, so compared with the compute time and the KV transfer over NIXL, I don't think it will affect TTFT much.

Additionally, in your PD system, do you turn on prefix caching on prefill?

  2. You're right that as the conversation history grows, the ratio of new KV to compute versus reused KV cache decreases; that's exactly the scenario where this solution improves TTFT a lot. Dispatching such requests directly to a decode node is possible, but it introduces a lot of engineering problems, i.e. a) prefill on decode will increase ITL, which is another dimension of performance; b) routing the request to the right decode node is itself a challenge in a system with several decode nodes, where massive requests & workloads may make it even harder for decode to keep that KV around for long; c) handling prefill on a decode node may require reserving additional GPU memory (prefill usually consumes more than decode), and that extra reserved memory shrinks the KV pool, which shrinks the decode batch size, which lowers overall throughput.

In addition, hitting that KV on decode nodes also requires this patch, since today save_decode_cache is off when PD is turned on, which means decode won't save any KV anywhere.

Lastly, I want to highlight that this solution DOES NOT break anything we have; it just offers additional options. People can decide whether to turn it on depending on their platform. The remote backend is just my example; it is not the only option. If you look at the PR, people can select another backend that fits their platform/system. And let's not assume the remote backend is always slow: AFAIK, many CSPs have very fast distributed storage systems for this KV storage, and we pay a very small overhead (several ms) to buy big value (saving seconds).

@hlin99
Contributor Author

hlin99 commented Feb 14, 2026

In a coding scenario, for the Nth round of conversation, all KV up to round N-2 is hit in the local CPU backend (assuming the capacity is big enough) and the round N-1 KV is hit in the remote backend, so the majority of KV is retrieved from the faster backend. Hope this eases the concern.
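That locality claim can be modeled as a toy function. The behavior is inferred from the Remote-to-LocalCPU migration shown in the test logs earlier in this thread; it is purely illustrative, not LMCache code.

```python
# Toy model of where round-k KV lives at the start of round N, under the
# behavior observed in the logs: decode writes each round's KV to the
# remote backend, and a later prefill hit migrates it to local CPU.

def kv_location(round_k: int, current_round: int):
    if round_k == current_round - 1:
        # Written by decode last round, not yet touched by a prefill hit.
        return "RemoteBackend"
    if round_k < current_round - 1:
        # Migrated to local CPU on an earlier round's prefill hit.
        return "LocalCPUBackend"
    return None  # not generated yet

# At round 5, only round 4's KV still has to come from the remote backend.
assert [kv_location(k, 5) for k in range(1, 5)] == [
    "LocalCPUBackend", "LocalCPUBackend", "LocalCPUBackend", "RemoteBackend",
]
```

Under this model, remote-backend traffic per round stays bounded (one round's KV) regardless of how long the conversation history grows, which is the point being made about overhead.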

storage manager) or has been stored (handled by storage backend).
"""
# The dictionary from backend cname to objects and keys
if location is None:
Contributor

is location None if and only if pd?

Contributor Author

To minimize the scope of this PR, right now the retrieve and store locations are effective only when the PD backend is turned on.


# Search all backends for blocking get
for backend_name, backend in self.get_active_storage_backends(location):
for backend_name, backend in self.get_active_storage_backends(
Contributor

why are we passing in the pd_retrieve_locations in the general get?

Contributor Author

Just to leverage what we have in the current design and avoid bigger arch-level changes.
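The reuse being discussed might look roughly like this. The real `get_active_storage_backends` lives in storage_manager.py; the signature and iteration order below are assumptions for illustration only.

```python
# Sketch of an optional location filter layered onto the existing backend
# iteration, so the general get()/store() paths need no PD-specific
# branches. Names and signature are assumptions, not the actual code.

def get_active_storage_backends(storage_backends, location=None):
    """Yield (name, backend) pairs, optionally restricted to `location`."""
    for name, backend in storage_backends.items():
        if location is None or name in location:
            yield name, backend

backends = {"LocalCPUBackend": object(), "RemoteBackend": object(), "PDBackend": object()}

# No filter: iterate every registered backend (the pre-existing behavior).
assert [n for n, _ in get_active_storage_backends(backends)] == [
    "LocalCPUBackend", "RemoteBackend", "PDBackend",
]
# With pd_retrieve_locations passed in, only those backends are probed.
assert [n for n, _ in get_active_storage_backends(backends, ["RemoteBackend"])] == [
    "RemoteBackend",
]
```

Passing `None` preserves the old "search everything" behavior, which is why the change can stay small: non-PD callers never notice the new parameter.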

self.pd_store_location = None

if self.enable_pd:
# these para are only effective under PD
Contributor

what about generalizing retrieve_locations and store_locations since you're passing it into get_active_storage_backends anyways?

Contributor Author

Yes. My concern is that the current changes are only effective under PD, but I am happy to extend the scope of my PR if others do not share this concern.

  1. To minimize the change, to avoid potential regressions, and to avoid any arch-level changes, I leveraged the interface lmcache already has today to break the limitation of not allowing decode tokens to be saved, making sure we add features without breaking anything we have today.
  2. The retrieve and store locations come from the general get and store, so they need not be limited to PD only. But I am not quite sure whether there is any history behind these params, and I'm a little worried about configuring them in general cases.

If the suggestion is to open these retrieve and store locations up in general, I am happy to do it and will update the PR. Please advise. Thx. @sammshen

Contributor

it is unclean to have something PD-specific in the general-purpose get() and store()

would it be possible to pass in the location every time we call storage_manager.get() and .store() in the PD case?

Contributor Author

got it. Thx for the good suggestion. Will update it

Contributor Author

Hi @sammshen, I updated the PR with changes according to your good suggestions. Could you take a look again? Thx.

hlin99 added 2 commits March 9, 2026 19:00
Signed-off-by: Tony Lin <tony.lin@intel.com>
Signed-off-by: Tony Lin <tony.lin@intel.com>
@hlin99
Contributor Author

hlin99 commented Mar 9, 2026

hi @sammshen, the PR has 2 approvals, but it looks like we need a 2nd committer approval to meet the merge requirement? Can you help assign it to the right person? Many thanks.

@sammshen
Contributor

sammshen commented Mar 9, 2026

@DongDongJu would you like to take a quick look at this PR?

@DongDongJu
Collaborator

Sorry for the late notice. Let me check it now.

Collaborator

@DongDongJu DongDongJu left a comment

Hello @hlin99,
Thanks for these updates!
From my understanding, the synchronous retrieve path needs a follow-up patch for _process_tokens_internal(), but this PR makes small enough changes to utilize the PD backend and the remote backend at the same time.
Generally looks good to me. I left just one comment.

Comment thread lmcache/v1/config.py Outdated
hlin99 added 3 commits March 13, 2026 00:22
Signed-off-by: Tony Lin <tony.lin@intel.com>
Signed-off-by: Tony Lin <tony.lin@intel.com>
@hlin99
Contributor Author

hlin99 commented Mar 13, 2026

hi @sammshen, already rebased to the latest, and apart from the k3 failure which is expected & a unit-test failure which seems irrelevant to the PR, all the other CIs passed. Let me know if anything is still missing to merge this PR? thanks.

@deng451e deng451e enabled auto-merge (squash) March 16, 2026 05:36
@deng451e deng451e merged commit 9d41318 into LMCache:dev Mar 16, 2026
26 of 29 checks passed
hyunyul-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Mar 20, 2026
…nd (LMCache#2509)

* feat(kv_cache): enable asymmetric save/remote storage in PD backend

Remove the restriction that prevented using `save_decode_cache` and
`remote_backend` simultaneously in Prefill-Decode (PD) separation scenarios.

This change introduces `pd_retrieve_locations` and `pd_store_location`
parameters to decouple the KV cache retrieval and storage logic. This
enables an asymmetric cache flow:
1. Prefill nodes transmit KV cache to Decode nodes via the PDBackend.
2. Decode nodes write back their generated KV cache to a remote backend
   for subsequent prefill reuse.
3. In multi-turn dialogue scenarios, subsequent
   prefill requests retrieve historical KV cache from the remote backend,
   significantly increasing Prefix Cache hit rates and reducing TTFT

This decoupling provides greater flexibility for cross-instance cache
management and improves overall pipeline efficiency in distributed
inference.

[ Compute Layer ]
    +----------------------+                 +------------------+
    |  Prefill Node        | ===============>|  Decode Node     |
    | (Hit-Remote & GenKV) |  (1) PDBackend  | (Hit-PD & GenKV) |
    +-------^--------------+                 +-------+----------+
            |                                   |
            :                                   :
------------|-----------------------------------|------------
[ Storage Layer ]                               |
            |                                   | (2) pd_store_location
            | (3) pd_retrieve_locations         |     (Decode -> Pool)
            |     (Pool -> Prefill)             |
            |                                   v
    +-------+--------------------------------------------+
    |             Distributed Storage Pool               |
    |   [Node A]    [Node B]    [Node C]    [Node D]     |
    |   <=======  (Object Storage / NFS / DFS)  =======> |
    +----------------------------------------------------+

Workflow:
1. Prefill -> Decode (PDBackend): Initial KV transfer for the current turn.
2. Decode -> Remote (Store): Decode saves updated KV to NFS for persistence.
3. Remote -> Prefill (Retrieve): Next-turn prefill pulls from Remote,
   drastically increasing Prefix Cache hit rate for multi-turn dialogues.

Signed-off-by: Tony Lin <tony.lin@intel.com>

* small refactor

Signed-off-by: Tony Lin <tony.lin@intel.com>

* config examples for pd + remote backends

Signed-off-by: Tony Lin <tony.lin@intel.com>

* refactor: rename pd_retrieve_locations/pd_store_location to retrieve_locations/store_location

Remove the PD-specific prefix to make the retrieve/store locations
generic instead of being limited to PD only.

This breaks the PD-only feature restriction and allows the mechanism
to be reused by other roles/components.

Signed-off-by: Tony Lin <tony.lin@intel.com>

* move retrieve & store locations from storage manager to cache engine

Signed-off-by: Tony Lin <tony.lin@intel.com>

* add para validation check

Signed-off-by: Tony Lin <tony.lin@intel.com>

* config: replace hardcoded IP with placeholder in decoder remote configs

Signed-off-by: Tony Lin <tony.lin@intel.com>

* resolve conflicts and rebase to the latest

Signed-off-by: Tony Lin <tony.lin@intel.com>

* address review comments

Signed-off-by: Tony Lin <tony.lin@intel.com>

* add description in configurations.rst

Signed-off-by: Tony Lin <tony.lin@intel.com>

---------

Signed-off-by: Tony Lin <tony.lin@intel.com>
Co-authored-by: deng451e <57919305+deng451e@users.noreply.github.com>
realAaronWu pushed a commit to realAaronWu/LMCache that referenced this pull request Mar 20, 2026
…nd (LMCache#2509)

Signed-off-by: Aaron Wu <aaron.wu@dell.com>
deng451e pushed a commit to deng451e/LMCache that referenced this pull request Mar 21, 2026
…nd (LMCache#2509)


put installation compatibility table into csv

Signed-off-by: deng451e <838677410@qq.com>

docs: make compat table scrollable

Signed-off-by: deng451e <838677410@qq.com>
deng451e pushed a commit to deng451e/LMCache that referenced this pull request Mar 21, 2026
…nd (LMCache#2509)

deng451e added a commit to deng451e/LMCache that referenced this pull request Mar 25, 2026
…nd (LMCache#2509)

deng451e added a commit to deng451e/LMCache that referenced this pull request Mar 27, 2026
…nd (LMCache#2509)

jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026
…nd (LMCache#2509)

jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026
…nd (LMCache#2509)

@hlin99 hlin99 deleted the PD_save_decode branch April 25, 2026 05:30

Labels

full Run comprehensive tests on this PR

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants