
[Edge Case]: Preemption Loading #2007

Merged
maobaolong merged 14 commits into LMCache:dev from sammshen:localdev/preemption-loading
Nov 21, 2025

Conversation

@sammshen (Contributor) commented Nov 16, 2025

FIX #1969 #1361

This is the script I used to test the minimal preemption case.

"""
Intricacies of vLLM KV cache management:
- Usually 3 blocks cannot be used.
- You should generally add 1 to (# input tokens) + (# output tokens).

Server deployment: 
LMCACHE_CHUNK_SIZE=8 \
vllm serve meta-llama/Llama-3.1-8B-Instruct --num-gpu-blocks-override 10 --block-size 16 \
  --no-enable-prefix-caching \
  --kv-transfer-config \
    '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' \
    > debug.log 2>&1

vllm serve meta-llama/Llama-3.1-8B-Instruct --num-gpu-blocks-override 10 --block-size 16 \
  --no-enable-prefix-caching \
    > debug.log 2>&1

We want both requests to saturate all the available blocks. This will force request 0 to preempt request 1.
"""

import asyncio
from openai import AsyncOpenAI

client = AsyncOpenAI(
    base_url="http://localhost:8000/v1",
    api_key="none",
)

async def run_request(text: str):
    num_gpu_blocks = 10
    # n = num_gpu_blocks - 3
    # 5 + 16 * (n - 1)
    max_output_tokens = 5 + 16 * ((num_gpu_blocks - 3) - 1)
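    # e.g. num_gpu_blocks = 10 -> 7 usable blocks -> 5 + 16 * 6 = 101 output
    # tokens; with a 10-token prompt: 10 + 101 + 1 = 112 = 7 blocks of 16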
    resp = await client.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=[{"role": "user", "content": text}],
        max_tokens=max_output_tokens,
    )
    return resp.choices[0].message.content

async def main():
    """
    "1Tell me a long story that never ends" and 
        "1Tell me a long story that never ends" 
        "2Tell me a long story that never ends" 
        are 10 tokens
    """
    prompts = [
        "1Tell me a long story that never ends",
        "2Tell me a long story that never ends",
    ]

    # run both requests at the same time
    results = await asyncio.gather(*(run_request(p) for p in prompts))

    for i, r in enumerate(results, 1):
        print(f"Response {i}: {r}")

if __name__ == "__main__":
    asyncio.run(main())

Logs before the fix (we see the tokens but don't load them):

(EngineCore_DP0 pid=3322770) [2025-11-18 19:21:06,890] LMCache INFO: Reqid: chatcmpl-75ed50a6af244d9981b15189a425de65, Total tokens 65, LMCache hit tokens: 44, need to load: 44 (vllm_v1_adapter.py:1330:lmcache.integration.vllm.vllm_v1_adapter)
(APIServer pid=3322482) INFO:     127.0.0.1:51520 - "POST /v1/chat/completions HTTP/1.1" 200 OK
(APIServer pid=3322482) INFO:     127.0.0.1:51528 - "POST /v1/chat/completions HTTP/1.1" 200 OK

Logs after the fix:

(EngineCore_DP0 pid=1538843) INFO 11-17 02:55:14 [scheduler.py:428] Attempting to recover preempted request chatcmpl-0da54347fb1a4053b1b373a33ce88915
(EngineCore_DP0 pid=1538843) [2025-11-17 02:55:14,985] LMCache INFO: Reqid: chatcmpl-0da54347fb1a4053b1b373a33ce88915, Total tokens 65, LMCache hit tokens: 40, need to load: 40 (vllm_v1_adapter.py:1354:lmcache.integration.vllm.vllm_v1_adapter)
(EngineCore_DP0 pid=1708755) [2025-11-17 07:23:12,194] LMCache INFO: Request chatcmpl-b492c116d42a45e2bf5b8df13c501949 lmcache cached tokens: 40, vllm cached tokens: 0, can load: True (vllm_v1_adapter.py:1579:lmcache.integration.vllm.vllm_v1_adapter)
(EngineCore_DP0 pid=1708755) [2025-11-17 07:23:12,195] LMCache INFO: Retrieved 40 out of 40 required tokens (from 40 total tokens). size: 0.0049 gb, cost 0.4515 ms, throughput: 10.8152 GB/s; (cache_engine.py:560:lmcache.v1.cache_engine)

Signed-off-by: Samuel Shen <slshen@uchciago.edu>
@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @sammshen, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the KV cache loading mechanism by integrating robust support for preempted requests. It refactors how request states are managed upon resumption, ensuring that the entire token sequence is considered for cache lookups, and optimizes the lookup process with a new preliminary cache check. These changes are vital for maintaining correctness and efficiency when interacting with vLLM's preemption capabilities, allowing for smoother recovery and continued processing of interrupted requests.

Highlights

  • Preemption Handling: Introduced specific logic within RequestTracker.update to correctly manage requests that have been preempted, including resetting token IDs and re-assigning block IDs to ensure proper state recovery.
  • Cached Lookup Optimization: Added a new lookup_cache method to the LookupClientInterface and its implementations, enabling a quick check for cached lookup results before performing a more intensive full lookup, improving efficiency (see the sketch after this list).
  • Comprehensive Token ID Usage: Changed the source of token IDs for KV cache lookup from request.prompt_token_ids to request.all_token_ids to accurately account for the full sequence of tokens in preemption scenarios.
  • Refined Assertion Logic: Updated the assertion in update_state_after_alloc to correctly account for the last token when a full LMCache hit occurs, ensuring that logits are appropriately recalculated.
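
The lookup_cache fast path above can be pictured with a minimal sketch. Only the LookupClientInterface shell is quoted from this PR's diff; CachingLookupClient, _lookup_results, and _full_lookup are hypothetical stand-ins for the real LMCache lookup clients:

import abc
from typing import Optional

class LookupClientInterface(metaclass=abc.ABCMeta):
    """Abstract interface for lookup clients."""

    def lookup_cache(self, lookup_id: str) -> Optional[int]:
        # Default: no cached lookup result, so callers fall back to a full lookup.
        return None

class CachingLookupClient(LookupClientInterface):
    """Hypothetical client that memoizes full-lookup results per request."""

    def __init__(self) -> None:
        self._lookup_results: dict[str, int] = {}

    def lookup(self, lookup_id: str, token_ids: list[int]) -> int:
        # Expensive path: stand-in for scanning storage for cached KV chunks.
        num_hit = self._full_lookup(token_ids)
        self._lookup_results[lookup_id] = num_hit
        return num_hit

    def lookup_cache(self, lookup_id: str) -> Optional[int]:
        # Cheap path: reuse a prior result, so repeating the query for the
        # same request (e.g. after preemption) is idempotent.
        return self._lookup_results.get(lookup_id)

    def _full_lookup(self, token_ids: list[int]) -> int:
        # Placeholder for the real chunk scan.
        return len(token_ids)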
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist (Bot) left a comment:

Code Review

This pull request introduces significant changes to handle request preemption, primarily by making the logic in get_num_new_matched_tokens idempotent and updating the RequestTracker state for preempted requests. The changes are generally well-structured, especially the refactoring of the lookup client to support idempotency. However, I've identified a critical issue in how the RequestTracker's token history is updated for preempted requests, which could lead to incorrect caching behavior. Additionally, there's an opportunity to improve code maintainability by refactoring some duplicated logic. My detailed feedback is in the comments below.

Comment on lines +250 to +258
        if preempted:
            # the block ids will change after preemption
            self.allocated_block_ids = new_block_ids
            # reset the number of saved tokens
            self.num_saved_tokens = lmcache_cached_tokens

            self.token_ids = []
        else:
            self.allocated_block_ids.extend(new_block_ids)

critical

When a request is preempted, self.token_ids is reset to an empty list. It is then extended with new_token_ids, which, based on its calculation in build_connector_meta, is only a slice of newly scheduled tokens. This results in the tracker losing the history of previously processed tokens for the preempted request, which can lead to an inconsistent state and incorrect caching behavior.

If vLLM rewinds the request upon preemption, token_ids should be truncated to the correct length, not completely reset. If there is no rewind, simply appending new_token_ids should be sufficient. The current implementation appears to be a bug. Removing the reset of token_ids would be a safer approach, although handling potential rewinds might require more context.

        if preempted:
            # the block ids will change after preemption
            self.allocated_block_ids = new_block_ids
            # reset the number of saved tokens
            self.num_saved_tokens = lmcache_cached_tokens
        else:
            self.allocated_block_ids.extend(new_block_ids)

Comment on lines +1558 to +1561
        load_spec = self.load_specs.pop(req_id, None)
        lmcache_cached_tokens = 0
        if load_spec is not None:
            lmcache_cached_tokens = load_spec.lmcache_cached_tokens

medium

This block of code to get load_spec and lmcache_cached_tokens is duplicated from lines 1519-1522. To improve maintainability and reduce redundancy, consider extracting this logic into a private helper method. For example:

def _get_lmcache_tokens_from_load_spec(self, req_id: str) -> tuple[Optional[LoadSpec], int]:
    load_spec = self.load_specs.pop(req_id, None)
    lmcache_cached_tokens = 0
    if load_spec is not None:
        lmcache_cached_tokens = load_spec.lmcache_cached_tokens
    return load_spec, lmcache_cached_tokens

You could then call this helper in both places.
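
For instance, each of the two call sites (hypothetical, following the helper suggested above) would collapse to:

    load_spec, lmcache_cached_tokens = self._get_lmcache_tokens_from_load_spec(req_id)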

Samuel Shen added 4 commits November 16, 2025 03:36
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
class LookupClientInterface(metaclass=abc.ABCMeta):
    """Abstract interface for lookup clients."""

    def lookup_cache(self, lookup_id: str) -> Optional[int]:
@sammshen (author):

the default does nothing
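
Per the exchange with @maobaolong further down, "does nothing" concretely means the abstract default in lmcache/v1/lookup_client/abstract_client.py returns None:

    def lookup_cache(self, lookup_id: str) -> Optional[int]:
        return None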

f"{self.load_specs[request.request_id].lmcache_cached_tokens} - "
f"{self.load_specs[request.request_id].vllm_cached_tokens}"
f" for request {request.request_id}"
recalc_last = (
@sammshen (author):

the logic is the same. We also check for the prompt hit case with recalc_last.

self._request_trackers.pop(finished_req_id, None)
self._unfinished_requests.pop(finished_req_id, None)

# We should load KV for:
@sammshen (author):

updated comment

Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Comment thread lmcache/integration/vllm/vllm_v1_adapter.py
self.allocated_block_ids = new_block_ids
# reset the number of saved tokens
self.num_saved_tokens = lmcache_cached_tokens
# we don't need to extend the token ids in the preempted case
@sammshen (author):

because there is no previous step where we generated an extra token since the running request was preempted

the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
# to handle preempted requests, we want `get_num_new_matched_tokens` to be idempotent
@sammshen (author):

this is the main logic for preemption
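
A hedged sketch of that idempotent flow, assuming the lookup_client interface sketched earlier; the standalone function shape and the fallback call are illustrative, not the adapter's actual code:

def get_num_new_matched_tokens_sketch(lookup_client, req_id: str,
                                      token_ids: list[int]) -> int:
    # Consult the cache before any processing: a second call for a preempted
    # request returns the same hit count without re-scanning storage.
    cached_num_hit_toks = lookup_client.lookup_cache(lookup_id=req_id)
    if cached_num_hit_toks is not None:
        return cached_num_hit_toks
    # First call: do the full lookup; the client records the result.
    return lookup_client.lookup(req_id, token_ids)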

 # TODO: this is a dangerous reference to the request object inside vllm
 if request := self._unfinished_requests.get(req_id):
-    num_current_tokens = len(request_tracker.token_ids)
+    num_current_tokens = request.num_computed_tokens
@sammshen (author):

if we use len(request_tracker.token_ids), this is a serious correctness bug

@sammshen (author):

why

for i, req_id in enumerate(cached_reqs.req_ids):
    request_tracker = self._request_trackers[req_id]
    num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
    # TODO: this is a dangerous reference to the request object inside vllm
@sammshen (author):

we can think of a better way to handle this in the future

@sammshen sammshen changed the title from [WIP]: Preemption Loading to [Edge Case]: Preemption Loading on Nov 17, 2025

 is_last_prefill = False
-if input_token_len == tracker.prompt_len:
+if input_token_len >= tracker.prompt_len:
Contributor:

Is this a related bug or a separate one?

@sammshen (author):

this is the same bug. the input tokens may be larger than the prompt_len when a preempted request is recovered.

@yoo-kumaneko (Contributor) left a comment:

LGTM

req_id = request.request_id

# consult the cache before any processing
if cached_num_hit_toks := self.lookup_client.lookup_cache(lookup_id=req_id):
@maobaolong (Collaborator) commented Nov 18, 2025:

If we configure LMCache to initialize a HitLimitLookupClient instance, this would be wrong, because the two clients above do not implement the new API you added. @sammshen

@maobaolong (Collaborator) commented Nov 18, 2025:

Maybe #2021 can help to make sure we have implemented all of the methods without forgetting any.

@maobaolong (Collaborator):

@sammshen You can add the following code to the tail of lmcache/v1/lookup_client/hit_limit_lookup_client.py

    def lookup_cache(self, lookup_id: str) -> Optional[int]:
        return self.actual_lookup_client.lookup_cache(lookup_id)

@sammshen (author):

the default implementation will cover it, right? The default implementation is to return None in lmcache/v1/lookup_client/abstract_client.py

@sammshen (author):

thanks @maobaolong, I did not read the hit limit client carefully enough before, but it's updated now!

maobaolong added a commit to maobaolong/LMCache that referenced this pull request Nov 19, 2025
Signed-off-by: baoloongmao <baoloongmao@tencent.com>
maobaolong added a commit to maobaolong/LMCache that referenced this pull request Nov 19, 2025
Signed-off-by: baoloongmao <baoloongmao@tencent.com>
@sammshen (author):

@maobaolong @chunxiaozheng feel free to take a look again

Samuel Shen added 2 commits November 20, 2025 06:31
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
@maobaolong (Collaborator):

@sammshen Sorry, since ChunkStatisticsLookupClient was merged to dev just today, you have to implement the following method in it too.

def lookup_cache(self, lookup_id: str) -> Optional[int]:
    return self.actual_lookup_client.lookup_cache(lookup_id)

@maobaolong (Collaborator):

@sammshen I pressed the update branch button, so the UT check will fail.

Signed-off-by: Samuel Shen <slshen@uchciago.edu>
@sammshen (author):

The UT is very useful! Updated, @maobaolong.

@ApostaC (Contributor) left a comment:

LGTM as we discussed offline.

@maobaolong (Collaborator) left a comment:

LGTM

@maobaolong maobaolong enabled auto-merge (squash) November 21, 2025 01:04
@github-actions (Bot) added the full (Run comprehensive tests on this PR) label on Nov 21, 2025
@ApostaC (Contributor) commented Nov 21, 2025:

@sammshen Do we want to create a special comprehensive test case if we know how to reproduce the preemption case?

@maobaolong (Collaborator):

> @sammshen Do we want to create a special comprehensive test case if we know how to reproduce the preemption case?

This is a great idea!

@maobaolong maobaolong merged commit cf73d9e into LMCache:dev Nov 21, 2025
21 checks passed
penghb2025 pushed a commit to penghb2025/LMCache that referenced this pull request Dec 15, 2025
* initial commit

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* remove abstract default implementation

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* change naming of preempted cached requests

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* renaming

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* finally working

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* add lookup_cache to hit limit lookup client

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* add lookup_cache to ChunkStatisticsLookupClient

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

---------

Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Co-authored-by: Samuel Shen <slshen@uchciago.edu>
Co-authored-by: maobaolong <baoloongmao@tencent.com>
DongDongJu pushed a commit to DongDongJu/LMCache that referenced this pull request Feb 22, 2026
sammshen added a commit to sammshen/LMCache that referenced this pull request Mar 1, 2026

Labels

full (Run comprehensive tests on this PR)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] cache is not retrieved when rescheduling a preempted request in vLLM

5 participants