
[Bug] cache is not retrieved when rescheduling a preempted request in vLLM #1969

@ziruiliu

Description


To Reproduce
A test script to reproduce this issue is attached at the bottom.
To run this script, activate the venv with vLLM + LMCache installed and run (CPU backend by default):
PYTHONHASHSEED=0 VLLM_LOGGING_LEVEL=DEBUG LMCACHE_LOG_LEVEL=DEBUG pytest test_engine_lmcache_preemption.py -k preemption -s
You can also run this script with other backends by setting LMCACHE_CONFIG_FILE.

How it works
The script simply uses a small model, Qwen/Qwen3-0.6B-FP8, to generate responses. BLOCK_SIZE is set to 16, aligning vLLM's block size with LMCache's chunk size. PROMPT_TOKEN_COUNT is set to 2 blocks plus 2 = 34 tokens, which needs 3 blocks. MAX_NEW_TOKENS is set to 2 blocks plus 3 = 35 tokens. All of these parameters can be tuned for your environment.

There will be 2 requests running at the same time, so NUM_BLOCKS is set to 7, which limits vLLM's prefix cache size and guarantees preemption. With NUM_BLOCKS greater than 10, no preemption happens.
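The parameter arithmetic above can be sanity-checked with a quick sketch (the `blocks_needed` helper is mine for illustration, not a vLLM API):

```python
import math

BLOCK_SIZE = 16                           # vLLM block size == LMCache chunk size
PROMPT_TOKEN_COUNT = BLOCK_SIZE * 2 + 2   # 34 tokens
MAX_NEW_TOKENS = BLOCK_SIZE * 2 + 3       # 35 tokens

def blocks_needed(num_tokens: int, block_size: int = BLOCK_SIZE) -> int:
    """Fixed-size KV blocks required to hold num_tokens."""
    return math.ceil(num_tokens / block_size)

print(blocks_needed(PROMPT_TOKEN_COUNT))                   # 3 blocks per prompt
print(blocks_needed(PROMPT_TOKEN_COUNT + MAX_NEW_TOKENS))  # 5 blocks per finished request
# Two concurrent requests need 2 * 3 = 6 blocks for their prompts alone, so a
# pool of 7 forces preemption during decode, while 2 * 5 = 10 <= 11 avoids it.
```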

I ran this script with additional tracing in vLLM and LMCache, and found that the bug happens as follows:

  • request 0: blocks [1,2,3] are allocated to store the prompt token cache, around line 510 in vllm/scheduler.py:
new_blocks = self.kv_cache_manager.allocate_slots()
  • request 1: blocks [4,5,6] are allocated to store its prompt token cache too
  • data is computed and stored to [1,2,3] and [4,5,6], respectively
  • request 0 generates more tokens and related cache data, filling up the block pool
  • because the total cache size is only 7 blocks, preemption starts
  • request 1 is preempted; the log shows that blocks [4,5,6] are freed, ref_cnt drops to 0, but the data remains
  • request 0 keeps being scheduled; block [6] is allocated for request 0 to store cache, around line 263 in vllm/scheduler.py
  • request 0's prompt plus new tokens reach 64, then block [5] is allocated for request 0
  • request 0 finishes and frees blocks [1,2,3,6,5]
  • now blocks [6] and [5] are filled with request 0's data
  • request 1 is rescheduled, and at line 411 in scheduler.py, block [4] is found to be reusable because its block hash matches:
self.kv_cache_manager.get_computed_blocks(request)

At this point, LMCache is able to calculate the number of tokens to be loaded:

LMCache INFO: Reqid: 1, Total tokens 49, LMCache hit tokens: 34, need to load: 18

The message shows that vLLM has req 1's first 16 tokens in block [4], although the block was freed.
LMCache has the entire cache of the 34 computed tokens; the first 16 are the same as block [4] in vLLM,
and the remaining 18 tokens have to be loaded from LMCache to the GPU.
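The numbers in that log line are consistent with the block layout above; a minimal sketch of the arithmetic (variable names are mine, not LMCache's):

```python
BLOCK_SIZE = 16

total_tokens = 49              # req 1's prompt (34) plus tokens decoded before preemption
lmcache_hit_tokens = 34        # LMCache still holds the full prompt cache
vllm_matched_tokens = 1 * BLOCK_SIZE  # only block [4] still matches by hash in vLLM

# Tokens that must come from LMCache rather than the GPU prefix cache
need_to_load = lmcache_hit_tokens - vllm_matched_tokens
print(need_to_load)  # 18, matching "need to load: 18" in the log
```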

  • request 1 allocates blocks [5,6,3] for future token generation
    At this point, the cache data of the 18 tokens in LMCache is expected to be loaded into blocks [5] and [6], but that never happens
  • at line 650 in vllm/scheduler.py, build_connector_meta() is called:
meta = self.connector.build_connector_meta(scheduler_output)
  • however, in the LMCache code, line 1509 is hit in build_connector_meta()'s cached_reqs branch, with load_spec=None:
            req_meta = ReqMeta.from_request_tracker(
                request_tracker,
                self._block_size,
                self._lmcache_chunk_size,
                load_spec=None,
                discard_partial_chunks=self._discard_partial_chunks,
                save_decode_cache=self._save_decode_cache,
            )
  • retrieval is NOT scheduled because load_spec=None here. Blocks [5] and [6] are NOT filled with the expected data from LMCache and still contain stale data from request 0
  • vLLM then generates a faulty response for request 1.
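The failure mode can be illustrated with a toy model of the block pool (a deliberate simplification for illustration, not vLLM's actual implementation): freeing a block only drops its ref count, and the KV data survives until another request overwrites it.

```python
class Block:
    def __init__(self, idx: int):
        self.idx = idx
        self.ref_cnt = 0
        self.data = None  # stands in for the KV cache contents

pool = {i: Block(i) for i in range(1, 8)}  # 7-block pool, as in the test

def alloc(idx: int, data: str) -> None:
    pool[idx].ref_cnt += 1
    pool[idx].data = data   # overwrites whatever was there

def free(idx: int) -> None:
    pool[idx].ref_cnt -= 1  # data is NOT cleared on free

# request 1's prompt cache lives in blocks [4, 5, 6]
for i in (4, 5, 6):
    alloc(i, "req1")
# preemption: the blocks are freed, but their data remains
for i in (4, 5, 6):
    free(i)
# request 0 later reuses blocks [6] and [5] for its own decode cache
alloc(6, "req0"); alloc(5, "req0")
free(6); free(5)

# request 1 is rescheduled: block [4] still hashes to its prefix, but blocks
# [5, 6] now hold request 0's data. With load_spec=None the retrieval from
# LMCache is skipped, so the stale data is consumed as-is.
print(pool[4].data, pool[5].data, pool[6].data)  # req1 req0 req0
```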

Example of prompt and expected response
The example script is also attached at the bottom. Running it shows the correct response for request 1:

$ PYTHONHASHSEED=0 VLLM_LOGGING_LEVEL=DEBUG python3 ./example.py --gpu-memory-utilization=0.1 --enforce-eager --no-enable-prefix-caching --max_model_len=400
...
Request 1
  Prompt token IDs: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
  Candidate 0 token IDs: [14582, 25, 16246, 264, 25, 715, 2, 9645, 264, 25, 715, 2, 256, 264, 729, 429, 4990, 264, 264, 1140, 315, 25780, 323, 4675, 279, 3082, 315, 279, 264, 21495, 448, 279, 2701, 4682, 510]

Run the test script with NUM_BLOCKS = 7 to force preemption; the WRONG response looks like:

$ PYTHONHASHSEED=0 VLLM_LOGGING_LEVEL=DEBUG LMCACHE_LOG_LEVEL=DEBUG LMCACHE_TRACK_USAGE=false pytest test_engine_lmcache_preemption.py -k preemption -s
...
FINISHED REQ 1: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
FINISHED REQ 1: [14582, 25, 16246, 264, 25, 715, 2, 9645, 264, 25, 715, 2, 256, 264, 729, 311, 1477, 279, 279, 4226, 311, 279, 3405, 3118, 389, 279, 2661, 1946, 198, 2, 323, 279, 3405, 30, 715]

Then run the test script with NUM_BLOCKS = 11 to AVOID preemption; the response looks like:

FINISHED REQ 1: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
FINISHED REQ 1: [14582, 25, 16246, 264, 25, 715, 2, 9645, 264, 25, 715, 2, 256, 264, 729, 429, 4990, 264, 264, 1140, 315, 25780, 323, 4675, 279, 3082, 315, 279, 264, 21495, 448, 279, 2701, 4682, 510]

which is the same as the result of the example.

I am not sure whether this is exactly the issue mentioned in PR #1361, but I found that the changes in PR #1361 fix this issue by setting load_spec correctly.

The Script
Test script that forces preemption:

# test_engine_lmcache_preemption.py 
import time
from typing import Iterable

import pytest

from vllm.config import KVTransferConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreEventType, EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor

MODEL_NAME = "Qwen/Qwen3-0.6B-FP8"
BLOCK_SIZE = 16
NUM_BLOCKS = 7
PROMPT_TOKEN_COUNT = BLOCK_SIZE * 2 + 2
MAX_NEW_TOKENS = BLOCK_SIZE * 2 + 3

pytestmark = pytest.mark.skipif(  # type: ignore[var-annotated]
    not current_platform.is_cuda(),
    reason="LMCache preemption test requires CUDA support.",
)

def _iter_events(outputs: Iterable):
    for output in outputs:
        if not output.events:
            continue
        for event in output.events:
            yield event


def test_scheduler_handles_preemption_with_lmcache(monkeypatch: pytest.MonkeyPatch):
    """Launch a minimal vLLM engine with LMCache and trigger preemption.

    The test configures the engine to use the LMCache connector via the
    ``--kv-transfer-config`` equivalent API and constrains the number of GPU
    KV blocks. When multiple requests are executed concurrently the scheduler
    must preempt one of them in order to allocate a fresh block. We assert that
    a PREEMPTED event is observed in the engine outputs, ensuring that the
    LMCache-backed scheduler path handles preemption correctly.
    """

    pytest.importorskip("lmcache")

    from lmcache.integration.vllm.utils import ENGINE_NAME
    from lmcache.v1.cache_engine import LMCacheEngineBuilder

    env_vars = {
        "LMCACHE_USE_EXPERIMENTAL": "True",
        "LMCACHE_CHUNK_SIZE": str(BLOCK_SIZE),
        "LMCACHE_LOCAL_CPU": "True",
        "LMCACHE_MAX_LOCAL_CPU_SIZE": "1.0",
    }
    for key, value in env_vars.items():
        monkeypatch.setenv(key, value)

    kv_transfer_config = KVTransferConfig(
        kv_connector="LMCacheConnectorV1",
        kv_role="kv_both",
    )

    engine_args = EngineArgs(
        model=MODEL_NAME,
        max_model_len=BLOCK_SIZE * NUM_BLOCKS,
        block_size=BLOCK_SIZE,
        num_gpu_blocks_override=NUM_BLOCKS,
        kv_transfer_config=kv_transfer_config,
        enforce_eager=True,
        gpu_memory_utilization=0.1,
    )

    vllm_config = engine_args.create_engine_config()
    executor_cls = Executor.get_class(vllm_config)

    engine_core: EngineCore | None = None
    with set_default_torch_num_threads(1):
        engine_core = EngineCore(
            vllm_config=vllm_config,
            executor_class=executor_cls,
            log_stats=True,
        )

    def make_request(idx: int) -> EngineCoreRequest:
        prompt_tokens = [idx + 1] * PROMPT_TOKEN_COUNT
        sampling_params = SamplingParams(
            temperature=0.0,
            max_tokens=MAX_NEW_TOKENS,
        )
        return EngineCoreRequest(
            request_id=str(idx),
            prompt_token_ids=prompt_tokens,
            mm_features=None,
            sampling_params=sampling_params,
            pooling_params=None,
            eos_token_id=None,
            arrival_time=time.time(),
            lora_request=None,
            cache_salt=None,
            data_parallel_rank=None,
        )

    requests = [make_request(0), make_request(1)]

    prompt_token_ids = [requests[0].prompt_token_ids, requests[1].prompt_token_ids]
    output_token_ids = [[], []]

    try:
        for req in requests:
            engine_core.add_request(*engine_core.preprocess_add_request(req))

        seen_preemption = False
        finished = set()

        for _ in range(128):
            outputs_by_engine, model_executed = engine_core.step()

            if not outputs_by_engine:
                print("No outputs yet.")
                break

            engine_outputs = outputs_by_engine[0]

            for out in engine_outputs.outputs:
                output_token_ids[int(out.request_id)].extend(out.new_token_ids)
                if out.finished:
                    print(f"FINISHED REQ {out.request_id}: {prompt_token_ids[int(out.request_id)]}")
                    print(f"FINISHED REQ {out.request_id}: {output_token_ids[int(out.request_id)]}")
                    finished.add(out.request_id)

            for event in _iter_events(engine_outputs.outputs):
                print(f"Event: {event}")
                if event.type is EngineCoreEventType.PREEMPTED:
                    seen_preemption = True

            engine_core.post_step(model_executed)

            if len(finished) == len(requests):
                print("All requests finished.")
                break

        assert finished == {req.request_id for req in requests}
        assert seen_preemption, "Expected at least one preemption event."\
            " Check KV block configuration if this fails."

    finally:
        if engine_core is not None:
            engine_core.shutdown()
        LMCacheEngineBuilder.destroy(ENGINE_NAME)

And here is the example of the prompt and expected response.
Run it as:

PYTHONHASHSEED=0 VLLM_LOGGING_LEVEL=DEBUG python3 ./example.py --gpu-memory-utilization=0.1 --enforce-eager --no-enable-prefix-caching --max_model_len=400

Script:

# example.py

from __future__ import annotations
from typing import Any
from vllm import EngineArgs, LLM
from vllm.inputs import TokensPrompt
from vllm.sampling_params import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser

MODEL_NAME = "Qwen/Qwen3-0.6B-FP8"
BLOCK_SIZE = 16
PROMPT_TOKEN_COUNT = BLOCK_SIZE * 2 + 2
MAX_NEW_TOKENS = BLOCK_SIZE * 2 + 3


def create_parser() -> FlexibleArgumentParser:
    parser = FlexibleArgumentParser()
    EngineArgs.add_cli_args(parser)
    parser.set_defaults(model=MODEL_NAME)
    return parser


def build_prompts() -> list[TokensPrompt]:
    return [
        TokensPrompt(prompt_token_ids=[idx + 1] * PROMPT_TOKEN_COUNT)
        for idx in range(2)
    ]


def main(args: dict[str, Any]) -> None:
    llm = LLM(**args)

    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=MAX_NEW_TOKENS,
    )

    prompts = build_prompts()
    outputs = llm.generate(prompts, sampling_params)

    print("-" * 80)
    for request_idx, output in enumerate(outputs):
        print(f"Request {request_idx}")
        print(f"  Prompt token IDs: {output.prompt_token_ids}")

        for candidate_idx, candidate in enumerate(output.outputs):
            print(
                f"  Candidate {candidate_idx} token IDs: {candidate.token_ids}"
            )
            print(f"  Candidate {candidate_idx} text: {candidate.text!r}")
        print("-" * 80)


if __name__ == "__main__":
    parser = create_parser()
    parsed_args: dict[str, Any] = vars(parser.parse_args())
    main(parsed_args)
