
[Core]: Support HND KV Format #2826

Merged
sammshen merged 28 commits into LMCache:dev from sammshen:permute-contiguous-registration
Mar 26, 2026

Conversation


@sammshen sammshen commented Mar 19, 2026

Test plan

  • E2E test with VLLM_KV_CACHE_LAYOUT=HND (Qwen2.5-3B): basic completion, KV cache
    reuse, and deterministic output after a cache reset (sketched below)
  • NHD regression test: deterministic output matches the baseline
  • E2E test with the FlashInfer backend (NL_X_NB_TWO_NH_BS_HS)
  • Multi-process (tensor-parallel) HND test
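
A minimal sketch of the HND determinism check, assuming greedy sampling and the public vLLM API (the model path and prompt are illustrative, and the real test exercises the LMCache connector rather than bare vLLM):

```python
# Sketch only: exercise VLLM_KV_CACHE_LAYOUT=HND end to end and check
# that a repeated (cache-hit) decode matches the first decode.
import os

os.environ["VLLM_KV_CACHE_LAYOUT"] = "HND"  # must be set before importing vLLM

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-3B-Instruct")  # hypothetical model path
params = SamplingParams(temperature=0.0, max_tokens=64)  # greedy => deterministic

prompt = "Explain KV cache reuse in one paragraph."
first = llm.generate([prompt], params)[0].outputs[0].text
second = llm.generate([prompt], params)[0].outputs[0].text  # should hit the cache
assert first == second, "HND cached decode diverged from the baseline"
```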

Note

Medium Risk
Touches CUDA KV-cache transfer kernels and the multi-process registration protocol; incorrect layout detection or offset math could corrupt KV data or crash, but changes are localized and guarded by new checks/tests.

Overview
Adds end-to-end support for vLLM’s HND KV cache layout (heads before block size) alongside the existing NHD/MLA formats.

Updates the CUDA transfer path (multi_layer_kv_transfer / single_layer_kv_transfer) to compute correct offsets for the new HND GPUKVFormat values, adds the required head_size plumbing and validation, and exposes the new formats and head_size argument via pybind.

Extends the vLLM integration and multi-process server protocol to pass layout_hints (e.g. {"kv_layout": "HND"}), auto-detect layout at runtime, and permute non-contiguous vLLM HND tensors back to a contiguous physical view before pointer capture/format discovery. Adds unit tests covering HND round-trips and updates existing registration tests for the new payload shape.
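
A minimal sketch of the permute-to-contiguous step, assuming the KV caches arrive as a dict of per-layer tensors; the PR's actual helper is ensure_contiguous_kv_caches in lmcache/v1/gpu_connector/utils.py and may differ in detail:

```python
import torch

def ensure_contiguous_kv_caches(
    kv_caches: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Sketch: restore the physical (contiguous) view of permuted KV tensors."""
    out = {}
    for name, t in kv_caches.items():
        if not t.is_contiguous():
            # A permuted view of a contiguous allocation becomes contiguous
            # again once its dims are reordered by descending stride.
            order = sorted(range(t.dim()), key=t.stride, reverse=True)
            t = t.permute(order)
            assert t.is_contiguous() and t.storage_offset() == 0
        out[name] = t
    return out
```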

Written by Cursor Bugbot for commit 9325b08.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request adds comprehensive support for the Head-Number-Dimension (HND) Key-Value (KV) cache format alongside the existing Number-Head-Dimension (NHD) format. This broadens compatibility with vLLM configurations and ensures KV cache data is correctly structured and efficiently transferred, particularly for inter-process communication, by dynamically adapting to the detected memory layout and enforcing data contiguity.

Highlights

  • New KV cache formats: introduced two new GPUKVFormat enum values (NL_X_TWO_NB_NH_BS_HS and NL_X_NB_TWO_NH_BS_HS) to explicitly support HND (Head-Number-Dimension) layouts for vLLM flash attention and FlashInfer.
  • HND format handling in CUDA kernels: modified the page_buffer_offset CUDA kernel to correctly calculate memory offsets for the new HND formats, adding a head_size parameter to the relevant kernel functions and their call sites (see the offset sketch after this list).
  • Dynamic KV layout detection and permutation: implemented logic to detect the vLLM KV cache layout ("NHD" or "HND") at runtime and, if HND is detected, permute the KV cache tensors to a contiguous physical shape before inter-process communication (IPC) wrapping.
  • Updated GPU connector and utility functions: extended Python utility functions and GPU connector classes to correctly interpret, manage, and transfer KV caches in the new HND formats.
  • IPC contiguity enforcement: CudaIPCWrapper now strictly asserts that tensors are contiguous and have a zero storage offset, relying on the new permutation logic to meet this condition for HND formats before IPC.
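
To make the offset change concrete, here is an illustrative Python rendition of the per-block offset math (element offsets, not bytes; the argument names are assumptions, and the real device code lives in csrc/mem_kernels.cu):

```python
def page_buffer_offset(kv_layout: str, block_idx: int, token_in_block: int,
                       head_idx: int, num_heads: int, block_size: int,
                       head_size: int) -> int:
    if kv_layout == "NHD":
        # Per block: [block_size, num_heads, head_size]
        return ((block_idx * block_size + token_in_block) * num_heads
                + head_idx) * head_size
    # "HND" per block: [num_heads, block_size, head_size] -- head_size
    # appears explicitly here, matching the new kernel parameter.
    return ((block_idx * num_heads + head_idx) * block_size
            + token_in_block) * head_size
```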

Comment thread lmcache/v1/multiprocess/custom_types.py Outdated
# blocks to have coalesced memory accesses
# do NOT blindly call .contiguous() nor .permute()
# we WANT to fail here when our assumptions fail
assert tensor.storage_offset() == 0
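
For context, a minimal sketch (the class shape is an assumption) of the guard CudaIPCWrapper enforces before capturing an IPC handle:

```python
import torch

class CudaIPCWrapper:
    """Sketch: wraps a CUDA tensor for IPC pointer capture, which assumes
    the tensor owns its storage from offset 0 and is physically contiguous."""

    def __init__(self, tensor: torch.Tensor):
        # Fail loudly rather than silently copying: a hidden .contiguous()
        # would break the coalesced-access assumptions noted above.
        assert tensor.is_contiguous()
        assert tensor.storage_offset() == 0
        self.tensor = tensor
```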
Contributor Author

@ApostaC changed this back :)

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request adds support for the HND KV-cache layout from vLLM. The changes are comprehensive, touching CUDA kernels, Python GPU connectors, and multiprocessing components to handle the new memory layout. Key changes include introducing new GPUKVFormat variants, updating offset calculation logic in CUDA kernels to handle HND, and adding logic to permute HND tensors to a contiguous format for IPC. The CudaIPCWrapper has also been simplified to enforce contiguity. Overall, the changes are well-implemented and consistent. I've provided a couple of suggestions to improve code maintainability by reducing duplication.

Comment thread csrc/mem_kernels.cu
Comment thread lmcache/v1/gpu_connector/utils.py
Comment thread csrc/mem_kernels.cu Outdated
Comment thread lmcache/v1/multiprocess/gpu_context.py
Comment thread lmcache/v1/gpu_connector/gpu_connectors.py Outdated
Comment thread lmcache/v1/gpu_connector/utils.py Outdated
Comment thread lmcache/v1/gpu_connector/gpu_connectors.py Outdated
Comment thread lmcache/v1/gpu_connector/utils.py Outdated
Comment thread lmcache/v1/gpu_connector/gpu_connectors.py Outdated
Comment thread lmcache/v1/gpu_connector/gpu_connectors.py
detected_format = None

if serving_engine == EngineType.VLLM:
kv_layout = layout_hints.get("kv_layout")
Contributor Author

need some really good documentation to ensure that whoever passes layout_hints knows exactly what needs to be passed

Contributor Author

@ApostaC introducing a new TypedDict! :)
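
For reference, a minimal sketch of what that TypedDict might look like, based on the kv_layout field quoted later in this thread (the exact field set and total=False are assumptions):

```python
from typing import Literal, TypedDict

class LayoutHints(TypedDict, total=False):
    """Hints the serving-engine integration passes to KV format discovery.

    kv_layout: "NHD" (vLLM default) or "HND" (VLLM_KV_CACHE_LAYOUT=HND).
    """

    kv_layout: Literal["NHD", "HND"]
```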

Comment thread csrc/pybind.cpp
Comment thread lmcache/v1/gpu_connector/utils.py Outdated

# Permute HND tensors to contiguous physical shape before IPC
# wrapping — CudaIPCWrapper asserts contiguity.
if kv_layout == "HND":
Contributor Author

blanket permute_kv_caches_to_contiguous

add a warning for now when we detect a non-contiguous tensor in a case we haven't accounted for

Contributor Author

i.e., is it non-contiguous because it's HND or for some unknown reason?

Contributor Author

done

self.gpu_kv_format = discover_gpu_kv_format(
kv_caches, EngineType.VLLM, layout_hints=layout_hints
)
if is_hnd(self.gpu_kv_format):
Contributor Author

also do a blanket permutation here.

also: double-check whether we should discover first or permute first

Contributor Author

permute-first semantics enforced
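
A sketch of the enforced ordering, reusing helper names that appear elsewhere in this PR (the wrapper name and import locations are assumptions):

```python
from lmcache.v1.gpu_connector.utils import (  # assumed import locations
    discover_gpu_kv_format,
    ensure_contiguous_kv_caches,
)

def permute_then_discover(kv_caches, layout_hints):
    """Hypothetical wrapper showing the enforced order of operations."""
    kv_caches = ensure_contiguous_kv_caches(kv_caches)  # permute first...
    gpu_kv_format = discover_gpu_kv_format(             # ...then discover
        kv_caches, EngineType.VLLM, layout_hints=layout_hints
    )
    return kv_caches, gpu_kv_format
```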

kv_caches, EngineType.VLLM, layout_hints=layout_hints
)
assert_is_vllm_flash_attn_or_flash_infer(self.gpu_kv_format)
if is_hnd(self.gpu_kv_format):
Contributor Author

all just permute

Contributor Author

done

Comment thread lmcache/v1/gpu_connector/utils.py
Comment thread lmcache/v1/gpu_connector/utils.py Outdated
Comment thread lmcache/v1/gpu_connector/gpu_connectors.py
@sammshen sammshen requested a review from ApostaC March 21, 2026 06:36
Comment thread lmcache/v1/multiprocess/gpu_context.py
@sammshen sammshen added the "full" label (Run comprehensive tests on this PR) on Mar 23, 2026
@sammshen sammshen force-pushed the permute-contiguous-registration branch from 8171574 to dbcdd55 on March 24, 2026 19:11
Comment thread lmcache/v1/gpu_connector/utils.py
Comment thread lmcache/v1/gpu_connector/utils.py
Comment thread lmcache/integration/vllm/vllm_multi_process_adapter.py Outdated
Comment thread lmcache/v1/gpu_connector/gpu_connectors.py Outdated
@sammshen sammshen mentioned this pull request Mar 25, 2026
@ApostaC ApostaC (Contributor) left a comment

Some comments regarding the high-level placement of the modules.

Comment thread lmcache/integration/vllm/vllm_multi_process_adapter.py Outdated
# First Party
from lmcache.v1.gpu_connector.utils import (
ensure_contiguous_kv_caches,
try_get_vllm_kv_cache_layout,
Contributor

Since try_get_vllm_kv_cache_layout is related to vLLM, can we put it under lmcache/integration/vllm instead of in the gpu_connector module?

In this case, layout_hints itself becomes an LMCache-standard interface, and setting the layout hints becomes the responsibility of the serving-engine integration.

Contributor Author

great suggestion! will do

Comment on lines +38 to +44
def _vllm_layout_hints() -> LayoutHints:
"""Build layout_hints dict by querying vLLM at runtime."""
hints: LayoutHints = {}
kv_layout = try_get_vllm_kv_cache_layout()
if kv_layout is not None:
hints["kv_layout"] = kv_layout
return hints
Contributor

Following up on the above comment: this can be moved to the vLLM integration folder.

Comment thread lmcache/v1/gpu_connector/gpu_connectors.py Outdated
from lmcache.v1.memory_management import GPUMemoryAllocator # noqa: E501
from lmcache.v1.memory_management import MemoryFormat, MemoryObj
from lmcache.v1.metadata import LMCacheMetadata
from lmcache.v1.multiprocess.custom_types import LayoutHints
Contributor

It becomes a bit weird to have gpu_connector importing things from the multiprocess module. Do you have a better idea of where to place the type definition?

Contributor

Probably put it into gpu_connector/utils.py?

Contributor Author

great catch, let me double check all of the import locations again!

Comment thread lmcache/v1/gpu_connector/gpu_connectors.py Outdated
Comment thread lmcache/v1/gpu_connector/gpu_connectors.py
Comment thread lmcache/v1/gpu_connector/gpu_connectors.py
@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes and found 2 potential issues.


``"HND"`` — heads before block-size (``VLLM_KV_CACHE_LAYOUT=HND``).
"""

kv_layout: Literal["NHD", "HND"]

Missing SPDX header comment in test handler helpers file

Low Severity

The new public class LayoutHints has a docstring but omits the Args / Returns / Exceptions sections required by the style guide for all new public types. The docstring describes the class purpose but only documents the kv_layout key informally in a Keys: section rather than as a standard docstring format. This is a minor documentation gap per the project's convention rules for public API documentation.


Triggered by project rule: LMCache Code Review Style Guide

Comment thread lmcache/integration/vllm/vllm_service_factory.py Outdated
@ApostaC ApostaC (Contributor) left a comment

LGTM!

@deng451e deng451e self-requested a review March 26, 2026 19:24
@sammshen sammshen merged commit 8769ef4 into LMCache:dev Mar 26, 2026
44 of 46 checks passed
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026
* Support the HND format from vLLM

Signed-off-by: Samuel Shen <slshen@tensormesh.ai>
