[Correctness]: Avoid overwriting APC overlap #2671
Conversation
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Summary of Changes (Gemini Code Assist): This pull request resolves a critical correctness issue in multi-process environments where a race condition between LMCache server writes and vLLM reads of APC-shared GPU blocks could lead to corrupted data. The solution introduces a mechanism to prevent the server from overwriting portions of memory that are actively being used by other processes, ensuring data integrity and stable operation under high concurrency.
```diff
  const int num_threads = blockDim.x;

  const int64_t slot_idx = slot_mapping[token_id];
+ const int kv_token_id = token_id + skip_prefix_n_tokens;
```
we add the offset here to avoid having to slice the `slot_mapping`
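The indexing relationship above can be sketched as follows (a simulation, not the actual kernel; names follow the diff): `slot_mapping` is still indexed by the unshifted `token_id`, while the LMCache-side index is shifted by `skip_prefix_n_tokens`, so only a suffix of the cached chunk is read without slicing anything on the Python side.

```python
# Sketch: how skip_prefix_n_tokens shifts the LMCache-side token index
# while slot_mapping stays unsliced. The concrete numbers are assumptions
# for illustration only.

def kv_token_ids(num_transfer_tokens: int, skip_prefix_n_tokens: int):
    """For each GPU-side token_id, the corresponding LMCache-side index."""
    return [token_id + skip_prefix_n_tokens
            for token_id in range(num_transfer_tokens)]

# A chunk of 8 tokens where the first 3 overlap APC-shared blocks:
# only 5 tokens are transferred, drawn from positions 3..7 of the chunk.
print(kv_token_ids(num_transfer_tokens=5, skip_prefix_n_tokens=3))
# → [3, 4, 5, 6, 7]
```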
```diff
  for (int i = tid; i < scalars_per_token; i += num_threads) {
    const int64_t lmcache_offset =
-       key_value_offset(k_or_v, layer_id, token_id, i, scalars_per_token,
+       key_value_offset(k_or_v, layer_id, kv_token_id, i, scalars_per_token,
```
we do this to avoid having to slice the LMCache-side memory object (2LTD cannot be contiguously sliced along the T dimension)
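A quick sketch of why a T-slice of a 2LTD buffer is non-contiguous (the flat-offset formula and the small sizes below are assumptions for illustration, not the actual `key_value_offset` implementation): in a row-major `[2, L, T, D]` layout, consecutive `t` values are contiguous only within one `(k, l)` pair, so the elements with `t >= skip` form many separated runs.

```python
# Flat offset of element (k, l, t, d) in a row-major [2, L, T, D] buffer.
def offset_2ltd(k, l, t, d, L, T, D):
    return ((k * L + l) * T + t) * D + d

L, T, D = 2, 8, 4  # tiny illustrative sizes

# Taking the slice t >= 3: between the last element of (k=0, l=0) in the
# slice and the first element of (k=0, l=1) in the slice there is a gap
# of 3*D scalars, so a Python-side slice would need a .contiguous() copy.
# The kernel instead applies the t-offset (kv_token_id) in its index math.
gap = offset_2ltd(0, 1, 3, 0, L, T, D) - offset_2ltd(0, 0, 7, D - 1, L, T, D) - 1
print(gap)  # → 12 (= 3 * D)
```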
```diff
  int num_layers = key_value.size(1);
- int num_tokens = slot_mapping.size(0);
+ int num_tokens = key_value.size(2);
```
there's a potentially interesting semantic discussion about whether key_value.size(2) or slot_mapping.size(0) should be used as the transfer size (in tokens) in the future, but for now they are the same
```diff
  int k_or_v_size = lmc::is_mla(gpu_kv_format) ? 1 : 2;

- dim3 grid(key_value.size(2), num_layers, k_or_v_size);
+ dim3 grid(num_transfer_tokens, num_layers, k_or_v_size);
```
we launch fewer blocks in the grid :)
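Roughly how the block count shrinks (all concrete sizes below are assumptions for illustration): the grid's x-dimension now counts only the tokens actually transferred rather than the whole chunk, so the total number of thread blocks drops proportionally.

```python
# Sketch: total thread blocks for a dim3 grid(x, num_layers, k_or_v_size).
def grid_blocks(x: int, num_layers: int, k_or_v_size: int) -> int:
    return x * num_layers * k_or_v_size

# Assumed example: a 256-token chunk whose first 64 tokens overlap APC.
chunk_tokens, skip, num_layers, k_or_v = 256, 64, 32, 2
old = grid_blocks(chunk_tokens, num_layers, k_or_v)         # whole chunk
new = grid_blocks(chunk_tokens - skip, num_layers, k_or_v)  # transfer only
print(old, new)  # → 16384 12288
```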
Code Review
This pull request introduces a skip_prefix_n_tokens parameter to the multi-layer KV transfer kernel and related functions. This change effectively addresses a CUDA stream race condition by allowing the server to skip writing to APC-overlapping positions, which is crucial for correctness in multi-process retrieve operations. The parameter is consistently propagated through the C++ kernels, Python bindings, and Python server/adapter logic, including necessary updates to test calls. A notable improvement was also made in csrc/mem_kernels.cu to correctly derive num_tokens from the key_value tensor's dimensions.
```diff
  int num_layers = key_value.size(1);
- int num_tokens = slot_mapping.size(0);
+ int num_tokens = key_value.size(2);
```
The change to derive num_tokens from key_value.size(2) instead of slot_mapping.size(0) is a significant correctness improvement. The key_value tensor's third dimension (key_value.size(2)) accurately reflects the number of tokens it contains, ensuring that the kernel operates on the correct data range. This is especially important with the introduction of skip_prefix_n_tokens to prevent potential out-of-bounds memory access.
Probably need to assert `skip_first_n_tokens < self.chunk_size` (we can put it into the adapter as well)
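The suggested guard might look like this (a sketch only; `check_skip` is a hypothetical helper, and where it lives — adapter or server — is still the open question above):

```python
# Sketch of the reviewer's suggestion: reject a skip that reaches or
# exceeds the chunk size before launching the transfer.
def check_skip(skip_first_n_tokens: int, chunk_size: int) -> None:
    assert 0 <= skip_first_n_tokens < chunk_size, (
        f"skip_first_n_tokens={skip_first_n_tokens} must be in "
        f"[0, chunk_size={chunk_size})"
    )

check_skip(64, 256)       # fine: skip is within the chunk
try:
    check_skip(256, 256)  # invalid: would skip the entire chunk
except AssertionError as e:
    print("rejected:", e)
```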
Signed-off-by: shaoxiawjc <wjc2800@163.com>
Signed-off-by: Aaron Wu <aaron.wu@dell.com>
NEEDS: vllm-project/vllm#35831
Tested with the correctness test in the new experimental CI, which caught a bug in MP mode: #2663
Running `lm_eval` gsm8k (300 samples, 50 concurrent, Qwen3-14B): samples are identical across two runs. Previously it failed consistently on doc_id 114 (and sometimes on doc_id 57); both now pass with this PR and the vLLM-side adapter PR.
We want to fix a CUDA stream race during MP retrieve that caused incorrect outputs under high concurrency. When `GetRetrieveMetadata` aligns the retrieve start down to the chunk boundary, APC-shared blocks get included in the write target. The LMCache server then writes to these blocks on its own CUDA stream while concurrent requests read them on the vLLM stream. We fix this by adding a `skip_prefix_n_tokens` parameter to the multi-layer KV transfer kernel so the server skips writing to APC-overlapping positions entirely. The kernel handles the offset natively (no Python-side tensor slicing or `.contiguous()` copy needed, which would otherwise be required because T is not the last dimension in 2LTD).
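The chunk-boundary alignment described above can be sketched as follows (`compute_skip_prefix_n_tokens` is a hypothetical helper for illustration, not the actual `GetRetrieveMetadata` code):

```python
# Sketch: when the retrieve start is aligned down to a chunk boundary,
# the tokens between the aligned start and the real start are the
# APC-shared positions that the server must not overwrite.
def compute_skip_prefix_n_tokens(retrieve_start: int, chunk_size: int) -> int:
    aligned_start = (retrieve_start // chunk_size) * chunk_size
    return retrieve_start - aligned_start

# Assumed example: a request resuming at token 300 with 256-token chunks.
# The aligned chunk starts at token 256, so the first 44 positions of the
# chunk overlap APC-shared blocks and are skipped by the kernel.
print(compute_skip_prefix_n_tokens(300, 256))  # → 44
```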