Skip to content

[Correctness]: Avoid overwriting APC overlap#2671

Merged
ApostaC merged 4 commits intoLMCache:devfrom
sammshen:mp-race
Mar 6, 2026
Merged

[Correctness]: Avoid overwriting APC overlap#2671
ApostaC merged 4 commits intoLMCache:devfrom
sammshen:mp-race

Conversation

@sammshen
Copy link
Copy Markdown
Contributor

@sammshen sammshen commented Mar 3, 2026

NEEDS: vllm-project/vllm#35831

Tested with the correctness test in the new experimental CI which caught a bug in MP mode: #2663

running lm_eval` gsm8k (300 samples, 50 concurrent, Qwen3-14B) samples identical across two runs. Previously failed consistently on doc_id 114 and sometimes on doc_id 57) which is now passing with this PR and the vllm side adapter PR

we want to fix CUDA stream race during MP retrieve that caused incorrect outputs under high concurrency. when GetRetrieveMetadata aligns the retrieve start down to the chunk boundary, APC-shared blocks get included in the write target. The LMCache server then writes to these blocks on its own CUDA stream while concurrent requests read them on the vLLM stream.

do this by adding a skip_prefix_n_tokens parameter to the multi-layer KV transfer kernel so the server skips writing to APC-overlapping positions entirely. The kernel handles the offset natively (no Python-side tensor slicing or .contiguous() copy needed which would be the case because T is not the last dimension in 2LTD).

Signed-off-by: Samuel Shen <slshen@uchciago.edu>
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical correctness issue in multi-process environments where a race condition between LMCache server writes and vLLM reads of APC-shared GPU blocks could lead to corrupted data. The solution involves introducing a mechanism to prevent the server from overwriting portions of memory that are actively being used by other processes, thereby ensuring data integrity and stable operation under high concurrency.

Highlights

  • Correctness Fix for CUDA Stream Race: Addressed a bug in multi-process (MP) mode where concurrent requests could lead to incorrect outputs due to a CUDA stream race. This occurred when the LMCache server wrote to APC-shared blocks on its own stream while vLLM concurrently read them.
  • Introduction of skip_prefix_n_tokens: Introduced a new parameter, skip_prefix_n_tokens, to the multi-layer KV transfer kernel. This parameter allows the server to explicitly skip writing to initial token positions that overlap with APC-cached blocks, preventing data corruption.
  • Kernel-level Offset Handling: Implemented the skip_prefix_n_tokens logic directly within the CUDA kernel (load_and_reshape_multi_layer_kernel) to handle the offset natively. This avoids the need for Python-side tensor slicing or .contiguous() copies, which would be inefficient for 2LTD tensors.
  • Protocol and Adapter Updates: Updated the multi-process communication protocol and the vLLM adapter to propagate the skip_first_n_tokens (Python-side equivalent) parameter from the client to the server, ensuring the new functionality is correctly utilized during retrieve operations.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • csrc/mem_kernels.cu
    • Added skip_prefix_n_tokens parameter to load_and_reshape_multi_layer_kernel.
    • Modified slot_idx and lmcache_offset calculations within the kernel to incorporate kv_token_id which accounts for skip_prefix_n_tokens.
    • Updated the multi_layer_kv_transfer_templated function to accept skip_prefix_n_tokens and pass it to the kernel launch.
    • Adjusted num_tokens calculation in multi_layer_kv_transfer_templated to use key_value.size(2) and introduced num_transfer_tokens for grid dimensioning.
  • csrc/mem_kernels.cuh
    • Updated the declaration of multi_layer_kv_transfer to include the new skip_prefix_n_tokens parameter with a default value of 0.
  • csrc/pybind.cpp
    • Modified the Python binding for multi_layer_kv_transfer to expose skip_prefix_n_tokens as an argument with a default value of 0.
  • lmcache/integration/vllm/vllm_multi_process_adapter.py
    • Added skip_first_n_tokens field to the LoadStoreOp dataclass.
    • Included op.skip_first_n_tokens in the payload when submitting a RETRIEVE request.
  • lmcache/v1/multiprocess/protocols/engine.py
    • Updated the RETRIEVE protocol definition to include an integer type for skip_first_n_tokens in its payload_classes.
  • lmcache/v1/multiprocess/server.py
    • Added skip_first_n_tokens parameter with a default value of 0 to the retrieve function signature.
    • Updated the docstring for the retrieve function to describe skip_first_n_tokens.
    • Implemented logic within _retrieve_loop to calculate effective_start and skip_in_chunk based on skip_first_n_tokens.
    • Modified the _retrieve_loop to pass skip_in_chunk to the multi_layer_kv_transfer call.
  • tests/v1/multiprocess/test_blend_server.py
    • Updated the RETRIEVE request call in test_cb_store_final_then_normal_lookup_retrieve to include the new skip_first_n_tokens parameter with a value of 0.
  • tests/v1/multiprocess/test_cache_server.py
    • Updated the RETRIEVE request call in retrieve_keys to include the new skip_first_n_tokens parameter with a value of 0.
  • tests/v1/multiprocess/test_mq.py
    • Updated the RETRIEVE request call in test_mq_retrieve to include the new skip_first_n_tokens parameter with a value of 0.
  • tests/v1/multiprocess/test_mq_handler_helpers.py
    • Added skip_first_n_tokens parameter with a default value of 0 to the retrieve_handler function.
    • Updated the docstring for retrieve_handler to describe skip_first_n_tokens.
    • Added an assertion to retrieve_handler to validate the type of skip_first_n_tokens.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Comment thread csrc/mem_kernels.cu
const int num_threads = blockDim.x;

const int64_t slot_idx = slot_mapping[token_id];
const int kv_token_id = token_id + skip_prefix_n_tokens;
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we add here to avoid having to slice the slot_mapping

Comment thread csrc/mem_kernels.cu
for (int i = tid; i < scalars_per_token; i += num_threads) {
const int64_t lmcache_offset =
key_value_offset(k_or_v, layer_id, token_id, i, scalars_per_token,
key_value_offset(k_or_v, layer_id, kv_token_id, i, scalars_per_token,
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we do this to avoid ahving to slice the LMCache side memory object (2LTD cannot be contiguously sliced along T dimension)

Comment thread csrc/mem_kernels.cu

int num_layers = key_value.size(1);
int num_tokens = slot_mapping.size(0);
int num_tokens = key_value.size(2);
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's a potentially interesting semantic discussion about whether key_value.size(2) or slot_mapping.size(0) should be used in the future as the transfer size (in tokens) but for now they are the same

Comment thread csrc/mem_kernels.cu
int k_or_v_size = lmc::is_mla(gpu_kv_format) ? 1 : 2;

dim3 grid(key_value.size(2), num_layers, k_or_v_size);
dim3 grid(num_transfer_tokens, num_layers, k_or_v_size);
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we launch fewer blocks in the grid :)

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a skip_prefix_n_tokens parameter to the multi-layer KV transfer kernel and related functions. This change effectively addresses a CUDA stream race condition by allowing the server to skip writing to APC-overlapping positions, which is crucial for correctness in multi-process retrieve operations. The parameter is consistently propagated through the C++ kernels, Python bindings, and Python server/adapter logic, including necessary updates to test calls. A notable improvement was also made in csrc/mem_kernels.cu to correctly derive num_tokens from the key_value tensor's dimensions.

Note: Security Review is unavailable for this PR.

Comment thread csrc/mem_kernels.cu

int num_layers = key_value.size(1);
int num_tokens = slot_mapping.size(0);
int num_tokens = key_value.size(2);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The change to derive num_tokens from key_value.size(2) instead of slot_mapping.size(0) is a significant correctness improvement. The key_value tensor's third dimension (key_value.size(2)) accurately reflects the number of tokens it contains, ensuring that the kernel operates on the correct data range. This is especially important with the introduction of skip_prefix_n_tokens to prevent potential out-of-bounds memory access.

@sammshen sammshen requested a review from ApostaC March 5, 2026 22:04
Copy link
Copy Markdown
Contributor

@ApostaC ApostaC left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably need to assert skip_first_n_tokens < self.chunk_size (we can put it into adapter as well)

Samuel Shen and others added 2 commits March 5, 2026 17:47
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
@sammshen sammshen added the full Run comprehensive tests on this PR label Mar 6, 2026
@ApostaC ApostaC enabled auto-merge (squash) March 6, 2026 02:52
@ApostaC ApostaC merged commit 38dddf9 into LMCache:dev Mar 6, 2026
27 of 29 checks passed
mauryaavinash95 pushed a commit to mauryaavinash95/LMCache that referenced this pull request Mar 7, 2026
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
shaoxiawjc pushed a commit to shaoxiawjc/LMCache that referenced this pull request Mar 11, 2026
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: shaoxiawjc <wjc2800@163.com>
realAaronWu pushed a commit to realAaronWu/LMCache that referenced this pull request Mar 20, 2026
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Aaron Wu <aaron.wu@dell.com>
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

full Run comprehensive tests on this PR

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants