
[Refactor]: Introduce GPUKVFormat#2567

Merged
KuntaiDu merged 39 commits into LMCache:dev from sammshen:gpu-kv-format
Feb 17, 2026

Conversation

@sammshen
Contributor

@sammshen sammshen commented Feb 8, 2026

Testing that needs to be done:

  • vllm flash infer
  • vllm flash attention (in correctness tests)
  • vllm MLA (flash MLA and infer MLA)
  • vllm layerwise (broken status: [BUG] Layerwise mode incompatible with remote offloading #2137) <- not testing
  • cacheblend <- not testing (new MP cacheblend coming soon)
  • sglang MHA
  • sglang MHA layerwise
  • sglang MLA (layerwise not supported yet)
  • multiprocess mode (protected by CI)

Sensibility Tests (query twice and look at answers):

These tests are NOT reliable on their own, but they serve as a first line of defense for sanity-checking the results.

MODEL=deepseek-ai/DeepSeek-V2-Lite-Chat
MODEL=meta-llama/Llama-3.1-8B-Instruct 
curl -X POST http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d "{
    \"model\": \"$MODEL\",
    \"prompt\": \"$(printf 'Please Elaborate the significance of KV cache in language models. %.0s' {1..1000})\",
    \"max_tokens\": 100
  }" | jq

Start up:

  1. vllm flash attention
LMCACHE_CHUNK_SIZE=256 \
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --kv-transfer-config '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' \
    --disable-log-requests --no-enable-prefix-caching \
    --attention-backend FLASH_ATTN

Correctly discovered GPU Format:

(EngineCore_DP0 pid=1740335)   - vLLM non-MLA flash attention (utils.py:78:lmcache.v1.gpu_connector.utils)

Outputs look the same after two queries (store and retrieve)

  2. vllm flash infer
LMCACHE_CHUNK_SIZE=256 \
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --kv-transfer-config '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' \
    --disable-log-requests --no-enable-prefix-caching \
    --attention-backend FLASHINFER

Correctly discovered GPU Format:

(EngineCore_DP0 pid=1751212)   - vLLM non-MLA flash infer (utils.py:86:lmcache.v1.gpu_connector.utils)

Outputs look the same after two queries (store and retrieve)

  3. vllm flash attention MLA
LMCACHE_CHUNK_SIZE=256 \
vllm serve deepseek-ai/DeepSeek-V2-Lite-Chat \
    --kv-transfer-config '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' \
    --disable-log-requests --no-enable-prefix-caching \
    --attention-backend FLASH_ATTN_MLA

Correctly discovered GPU Format:

(EngineCore_DP0 pid=1759370)   - vLLM MLA (utils.py:92:lmcache.v1.gpu_connector.utils)

Outputs look the same after two queries (store and retrieve)

  4. vllm flash infer MLA
LMCACHE_CHUNK_SIZE=256 \
vllm serve deepseek-ai/DeepSeek-V2-Lite-Chat \
    --kv-transfer-config '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' \
    --disable-log-requests --no-enable-prefix-caching \
    --attention-backend FLASHINFER_MLA

Correctly discovered GPU Format:

(EngineCore_DP0 pid=1773093)   - vLLM MLA (utils.py:92:lmcache.v1.gpu_connector.utils)

  5. vllm layerwise
LMCACHE_CHUNK_SIZE=256 \
LMCACHE_USE_LAYERWISE=True \
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --kv-transfer-config '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}' \
    --disable-log-requests --no-enable-prefix-caching \
    --attention-backend FLASH_ATTN

Correctly discovered GPU Format:

(EngineCore_DP0 pid=1778516)   - vLLM non-MLA flash attention (utils.py:78:lmcache.v1.gpu_connector.utils)

StopIteration error encountered, as in #2268 and #2453:

(EngineCore_DP0 pid=1778516)     self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
(EngineCore_DP0 pid=1778516)     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=1778516)   File "/home/tensormesh/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/contextlib.py", line 144, in __exit__
(EngineCore_DP0 pid=1778516)     next(self.gen)
(EngineCore_DP0 pid=1778516) RuntimeError: generator raised StopIteration
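This traceback is the PEP 479 failure mode: a StopIteration that escapes a `@contextlib.contextmanager` generator body is converted to `RuntimeError: generator raised StopIteration` when `__exit__` resumes the generator. A minimal standalone repro (illustrative, not LMCache code):

```python
# Minimal repro of "RuntimeError: generator raised StopIteration".
# Since PEP 479, StopIteration propagating out of a generator frame is
# replaced with RuntimeError by the generator machinery.
import contextlib

@contextlib.contextmanager
def faulty_ctx():
    yield
    # Runs during __exit__ (contextlib calls next(self.gen), matching the
    # traceback above); the unguarded next() raises StopIteration inside
    # the generator body, which is converted to RuntimeError.
    next(iter([]))
```
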
  6. SGLang MHA

Tested with:
https://docs.lmcache.ai/getting_started/quickstart.html

Correctly discovered GPU Format:

[2026-02-13 01:46:43,907] LMCache INFO: GPU KV Format: List[2] -> List[num_layers] of [page_buffer_size, num_heads, head_size] (utils.py:95:lmcache.v1.gpu_connector.utils)
[2026-02-13 01:46:43,907] LMCache INFO: Currently used by:
  - SGLang MHA (flash attention and flash infer) (utils.py:100:lmcache.v1.gpu_connector.utils)
  7. SGLang MLA

Tested with:
https://docs.lmcache.ai/getting_started/quickstart.html but with deepseek-ai/DeepSeek-V2-Lite-Chat

Broken (with or without this PR):

  File "/home/tensormesh/jiayis-dad/.venv/lib/python3.12/site-packages/sglang/srt/managers/scheduler.py", line 2905, in run_scheduler_process
    scheduler = Scheduler(
                ^^^^^^^^^^
  File "/home/tensormesh/jiayis-dad/.venv/lib/python3.12/site-packages/sglang/srt/managers/scheduler.py", line 336, in __init__
    self.init_cache_with_memory_pool()
  File "/home/tensormesh/jiayis-dad/.venv/lib/python3.12/site-packages/sglang/srt/managers/scheduler.py", line 685, in init_cache_with_memory_pool
    self.tree_cache = LMCRadixCache(
                      ^^^^^^^^^^^^^^
  File "/home/tensormesh/jiayis-dad/.venv/lib/python3.12/site-packages/sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py", line 91, in __init__
    getattr(self.token_to_kv_pool_allocator._kvcache, "k_buffer"),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'MLATokenToKVPool' object has no attribute 'k_buffer'. Did you mean: 'kv_buffer'?

[2026-02-13 01:49:37] Received sigquit from a child process. It usually means the child failed.
Killed
  8. Multiprocess mode:

Additions to CI

Built off of @ziruiliu's PR: #2308

Samuel Shen added 3 commits February 8, 2026 07:10
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
@gemini-code-assist
Contributor

Summary of Changes

Hello @sammshen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the internal handling of GPU key-value (KV) cache memory layouts. By introducing a GPUKVFormat enumeration, the system can now explicitly differentiate and manage the various KV cache structures used by vLLM (flash attention, flash infer, MLA) and SGLang (MHA, MLA). This change improves code clarity, removes reliance on implicit boolean flags, and provides a more robust framework for future extensions by centralizing format-specific logic in dedicated utility functions and updating the core C++ kernels and Python connectors to use the new standardized approach.

Highlights

  • Standardized GPU KV Cache Formats: Introduced a new GPUKVFormat enum in C++ to explicitly define and manage various Key-Value (KV) cache memory layouts, including different vLLM and SGLang formats.
  • Refactored Kernel Interfaces: Updated C++ CUDA kernels (single_layer_kv_transfer_kernel, multi_layer_kv_transfer_templated, etc.) to accept the GPUKVFormat enum and a TransferDirection enum, replacing previous boolean flags for clearer and more robust parameter passing.
  • Dynamic Format Discovery Utilities: Implemented Python utility functions (discover_gpu_kv_format, get_num_blocks, get_block_size, etc.) to automatically detect the specific GPUKVFormat from KV cache tensors and extract relevant properties, reducing manual configuration and potential errors.
  • Unified KV Cache Handling: Integrated the new GPUKVFormat across the Python GPU connector classes and the multiprocessing server, streamlining the logic for managing and transferring KV caches regardless of their underlying memory layout.
  • Comprehensive Test Updates: Modified numerous unit tests to incorporate the new GPUKVFormat enum, ensuring compatibility and correctness of the refactored codebase.
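To make the "dynamic format discovery" idea concrete, here is a simplified sketch of shape-based dispatch. The enum member names mirror those visible in the PR's kernel code, but the dispatch rules below are illustrative assumptions, not LMCache's actual `discover_gpu_kv_format()` implementation:

```python
# Illustrative sketch: infer the KV cache layout from a per-layer tensor
# shape. Enum names follow the kernel comments in this PR; the branching
# logic here is a simplified assumption.
from enum import Enum, auto

class GPUKVFormat(Enum):
    NL_X_2_NB_BS_NH_HS = auto()  # vLLM flash attention: [2, num_blocks, block_size, num_heads, head_size]
    NL_X_NB_2_BS_NH_HS = auto()  # vLLM flash infer:     [num_blocks, 2, block_size, num_heads, head_size]
    NL_X_NB_BS_HS = auto()       # vLLM MLA:             [num_blocks, block_size, head_size]

def discover_format(layer_shape: tuple) -> GPUKVFormat:
    # The position of the K/V dimension (size 2) distinguishes the
    # non-MLA layouts; MLA caches have no separate K/V dimension.
    if len(layer_shape) == 5 and layer_shape[0] == 2:
        return GPUKVFormat.NL_X_2_NB_BS_NH_HS
    if len(layer_shape) == 5 and layer_shape[1] == 2:
        return GPUKVFormat.NL_X_NB_2_BS_NH_HS
    if len(layer_shape) == 3:
        return GPUKVFormat.NL_X_NB_BS_HS
    raise ValueError(f"unrecognized KV cache layout: {layer_shape}")
```
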


Changelog
  • csrc/mem_kernels.cu
    • Replaced bool direction parameters with TransferDirection enum in single_layer_kv_transfer_kernel and single_layer_kv_transfer_sgl_kernel.
    • Modified page_buffer_offset to use GPUKVFormat and block_size for format-specific offset calculations.
    • Added is_mla inline device function for checking MLA formats.
    • Updated multi_layer_kv_transfer_templated and multi_layer_kv_transfer to accept TransferDirection, GPUKVFormat, and block_size.
    • Refactored single_layer_kv_transfer to use GPUKVFormat instead of vllm_two_major and use_mla booleans.
  • csrc/mem_kernels.cuh
    • Introduced enum class GPUKVFormat with detailed descriptions for various KV cache layouts.
    • Updated function signatures for multi_layer_kv_transfer, multi_layer_kv_transfer_unilateral, single_layer_kv_transfer, and single_layer_kv_transfer_sgl to incorporate TransferDirection and GPUKVFormat.
  • csrc/pybind.cpp
    • Exposed the new GPUKVFormat enum to the Python binding.
  • lmcache/v1/gpu_connector/gpu_connectors.py
    • Integrated new utility functions (discover_gpu_kv_format, get_num_blocks, get_block_size, get_page_buffer_size) for dynamic KV cache format detection.
    • Updated calls to C++ kernels (lmc_ops.multi_layer_kv_transfer, lmc_ops.single_layer_kv_transfer) to pass TransferDirection, GPUKVFormat, and block_size.
    • Removed vllm_two_major attribute.
  • lmcache/v1/gpu_connector/utils.py
    • Added discover_gpu_kv_format to identify the specific GPUKVFormat from KV cache tensors.
    • Implemented helper functions (get_num_layers, get_num_blocks, get_block_size, get_page_buffer_size, get_num_heads, get_hidden_dim_size, get_head_size, get_tokens_per_layer, get_elements_per_layer) for extracting format-specific properties.
    • Introduced assert_is_vllm_flash_attn_or_flash_infer and is_mla utility functions.
  • lmcache/v1/multiprocess/server.py
    • Refactored GPUCacheContext.__init__ to use discover_gpu_kv_format and related utility functions for initializing KV cache properties.
    • Updated lmc_ops.multi_layer_kv_transfer calls in store and _retrieve_loop with the new enum arguments.
  • tests/v1/test_gpu_connector.py
    • Updated test parametrizations to use GPUKVFormat enum instead of boolean flags.
    • Modified KV cache generation calls to pass gpu_kv_format.
  • tests/v1/test_mem_kernels.py
    • Removed generate_mla_kv_cache_paged_list_tensors.
    • Updated num_layers usage and lmc_ops kernel calls to align with GPUKVFormat.
  • tests/v1/utils.py
    • Modified generate_kv_cache_paged_list_tensors to dynamically create KV cache tensors based on the provided GPUKVFormat.
    • Removed generate_mla_kv_cache_paged_list_tensors.
Activity
  • The author, sammshen, has provided a checklist of testing tasks in the PR description, indicating ongoing work or planned verification for various vLLM and SGLang configurations.

@sammshen sammshen marked this pull request as draft February 8, 2026 09:01
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a significant and valuable refactoring by introducing the GPUKVFormat enum, replacing boolean flags for KV cache layouts to improve readability, maintainability, and extensibility for various serving engines like vLLM and SGLang. However, it also introduces or propagates several security vulnerabilities. The multiprocess cache server lacks validation for GPU block indices in ZMQ requests, which can lead to out-of-bounds memory access. Furthermore, CUDA kernels are susceptible to integer overflows when calculating offsets for very large KV caches. Additionally, a critical issue was found in csrc/mem_kernels.cu regarding missing semicolons and potential undefined behavior in the page_buffer_offset function, which will cause compilation errors. Addressing these security vulnerabilities and the critical compilation issue is essential for the robustness and security of the system.

Comment thread lmcache/v1/multiprocess/server.py
Comment thread lmcache/v1/multiprocess/server.py
Comment thread csrc/mem_kernels.cu Outdated
Comment on lines +192 to +209
  if (gpu_kv_format == GPUKVFormat::NL_X_2_NB_BS_NH_HS) {
    return k_or_v * page_buffer_size * scalars_per_token +
           token_idx * scalars_per_token + scalar_offset;
  }
  // vllm flash infer
  if (gpu_kv_format == GPUKVFormat::NL_X_NB_2_BS_NH_HS) {
    const int block_idx = token_idx / block_size;
    const int block_offset = token_idx % block_size;
    return block_idx * 2 * block_size * scalars_per_token +
           k_or_v * block_size * scalars_per_token +
           block_offset * scalars_per_token + scalar_offset
  }
  // MLA
  // vLLM: NL_X_NB_BS_HS
  // SGLang: NL_X_NBBS_1_HS
  if (is_mla(gpu_kv_format)) {
    return token_idx * scalars_per_token + scalar_offset
  }
Contributor


Severity: medium (security)

The page_buffer_offset function is susceptible to integer overflows. The calculation of memory offsets uses 32-bit integer multiplication, which for large KV caches, can exceed the maximum value of a 32-bit signed integer. This leads to an incorrect int64_t offset, causing out-of-bounds memory access or data corruption on the GPU. Cast operands to int64_t before multiplication to ensure 64-bit precision. Additionally, there are missing semicolons at the end of two return statements (lines 202 and 208 in the original file), which will cause a compilation error. The function also has undefined behavior if none of the if conditions are met, as there's no final return statement. Consider refactoring to ensure all control paths return a value.

  if (gpu_kv_format == GPUKVFormat::NL_X_2_NB_BS_NH_HS) {
    return (int64_t)k_or_v * page_buffer_size * scalars_per_token +
           (int64_t)token_idx * scalars_per_token + scalar_offset;
  }
  // vllm flash infer
  if (gpu_kv_format == GPUKVFormat::NL_X_NB_2_BS_NH_HS) {
    const int block_idx = token_idx / block_size;
    const int block_offset = token_idx % block_size;
    return (int64_t)block_idx * 2 * block_size * scalars_per_token +
           (int64_t)k_or_v * block_size * scalars_per_token +
           (int64_t)block_offset * scalars_per_token + scalar_offset;
  }
  // MLA
  // vLLM: NL_X_NB_BS_HS
  // SGLang: NL_X_NBBS_1_HS
  if (is_mla(gpu_kv_format)) {
    return (int64_t)token_idx * scalars_per_token + scalar_offset;
  }



# TODO: support MLA
class SGLangLayerwiseGPUConnector(GPUConnectorInterface):
Contributor


Severity: medium

This class is marked with a TODO to support MLA, but it doesn't currently prevent usage with MLA configurations, which could lead to incorrect behavior (e.g., get_shape() would be wrong). To prevent accidental misuse, consider adding a check in _lazy_initialize_buffer to explicitly raise a NotImplementedError if an MLA format is detected.
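One way to act on this suggestion is a fail-fast guard. This is a hypothetical sketch (standalone helper with format names taken from the kernel comments in this PR); the real check would live inside `SGLangLayerwiseGPUConnector._lazy_initialize_buffer` and use the PR's `is_mla()` utility:

```python
# Sketch of the suggested guard: reject MLA layouts before any buffers
# are sized, instead of silently computing a wrong get_shape().
# Format names below come from the kernel comments (vLLM MLA, SGLang MLA);
# the string-based check is a stand-in for the PR's enum-based is_mla().
MLA_FORMATS = {"NL_X_NB_BS_HS", "NL_X_NBBS_1_HS"}

def assert_not_mla(gpu_kv_format: str) -> None:
    if gpu_kv_format in MLA_FORMATS:
        raise NotImplementedError(
            "SGLangLayerwiseGPUConnector does not support MLA KV formats yet"
        )
```
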

@sammshen sammshen mentioned this pull request Feb 8, 2026
Samuel Shen added 21 commits February 8, 2026 20:58
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
@sammshen sammshen marked this pull request as ready for review February 13, 2026 02:40
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Samuel Shen added 8 commits February 13, 2026 03:28
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
@YaoJiayi YaoJiayi self-requested a review February 14, 2026 01:02
Samuel Shen added 2 commits February 14, 2026 06:54
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
@sammshen sammshen added the full Run comprehensive tests on this PR label Feb 14, 2026
Samuel Shen added 3 commits February 14, 2026 19:29
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@uchciago.edu>
@sammshen
Contributor Author

The comprehensive and multiprocessing tests keep failing, which makes me worried this PR introduces a real performance regression. I will do more thorough testing tomorrow.

Signed-off-by: Samuel Shen <slshen@tensormesh.ai>
@KuntaiDu KuntaiDu enabled auto-merge (squash) February 17, 2026 19:42
Contributor

@KuntaiDu KuntaiDu left a comment


LGTM

@KuntaiDu KuntaiDu merged commit f7b040e into LMCache:dev Feb 17, 2026
25 checks passed
@sammshen sammshen mentioned this pull request Feb 18, 2026
DongDongJu pushed a commit to DongDongJu/LMCache that referenced this pull request Feb 22, 2026
* initial commit

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* add changes to multiprocess mode

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* unit tests for new format

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* fix circular import in gpu_connector/utils.py

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* fix circular import in gpu_connector/utils.py

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* add forward declaration

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* missing semicolons

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* missing semicolons

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* remove trailing comma

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* fix check_paged_kv_cache_equal

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* add flash infer correctness tests

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* add keyword bindings and remove flashinfer from correctness tests

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* add more backends to single request correctness tests

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* ifx layerwise gpu connector tests

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* fix non cuda UT

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* fix UT x2

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* fix GPUKVFormat enum naming

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* template the gpu_kv_format

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* fix the correctness tests

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* fix flipped TransferDirection enum in layerwiseGPU connector

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* call get gpu kv format on unwrapped tensors

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* fix sglang unit tests

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* fix sglang flattening order

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* fix test_multi_layer_kernel

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* fix test_multi_layer_kernel

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* add logging to correctness tests

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* fix curl in correctness tests

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* actually don't need clone

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* add contiguous

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* fixed unit tests

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* revert correctness tests since they all need BATCH INVARIANCE

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* fix gpu picking in ci running on ie-users

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* loosen comprehensive from 1.1 to 1.2

Signed-off-by: Samuel Shen <slshen@uchciago.edu>

* add non-contiguous kv cache registration

Signed-off-by: Samuel Shen <slshen@tensormesh.ai>

---------

Signed-off-by: Samuel Shen <slshen@uchciago.edu>
Signed-off-by: Samuel Shen <slshen@tensormesh.ai>
Co-authored-by: Samuel Shen <slshen@uchciago.edu>
sammshen added a commit to sammshen/LMCache that referenced this pull request Mar 1, 2026
hlin99 pushed a commit to hlin99/LMCache that referenced this pull request Mar 2, 2026
mauryaavinash95 pushed a commit to mauryaavinash95/LMCache that referenced this pull request Mar 7, 2026
shaoxiawjc pushed a commit to shaoxiawjc/LMCache that referenced this pull request Mar 11, 2026
Signed-off-by: shaoxiawjc <wjc2800@163.com>

Labels: `full` (Run comprehensive tests on this PR)

3 participants