
Refactor KV cache shape/dtype extraction for robustness#2537

Merged
deng451e merged 16 commits into LMCache:dev from hlin99:kv_type_refactor
Mar 19, 2026

Conversation

@hlin99
Contributor

@hlin99 hlin99 commented Feb 4, 2026

What this PR does / why we need it:

  • Support kv_cache as a tuple or list of tensors (e.g., HPU K/V caches).
  • Skip None tensors when extracting shape and dtype.
  • Preserve original behavior for single tensor inputs.
  • Improve robustness and maintainability without changing existing logic.


Signed-off-by: Tony Lin <tony.lin@intel.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello @hlin99, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers get up to speed quickly.

This pull request refactors the KV cache shape and dtype extraction mechanism to enhance its robustness and flexibility. By accommodating kv_cache inputs as tuples or lists of tensors and gracefully handling None values, the changes ensure broader compatibility, particularly with systems like HPU K/V caches. The update improves maintainability without altering the existing functional logic for single tensor inputs.

Highlights

  • Enhanced KV Cache Flexibility: The system now supports kv_cache inputs as either a single torch.Tensor, a tuple of tensors, or a list of tensors, improving compatibility with diverse hardware configurations like HPU K/V caches.
  • Robust Shape/Dtype Extraction: The logic for extracting shape and dtype has been made more robust by skipping None tensors within iterable kv_cache inputs, ensuring that valid tensor information is always used.
  • Preserved Existing Behavior: The refactoring ensures that the original behavior for single torch.Tensor inputs remains unchanged, maintaining backward compatibility while introducing new capabilities.


Changelog
  • lmcache/v1/kv_layer_groups.py
    • Modified the build_kv_layer_groups method to normalize kv_cache inputs into an iterable (list or tuple).
    • Implemented a loop to find the first non-None tensor within the kv_cache iterable to correctly extract its shape and dtype.
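The two changelog items can be illustrated with a rough sketch. The helper names `normalize_kv_cache` and `first_non_none` are hypothetical, not the actual code inside `build_kv_layer_groups`:

```python
def normalize_kv_cache(kv_cache):
    # Wrap a single tensor-like object in a list so downstream code can
    # always iterate; tuples and lists pass through as lists.
    if isinstance(kv_cache, (tuple, list)):
        return list(kv_cache)
    return [kv_cache]

def first_non_none(tensors):
    # Scan for the first valid tensor to read shape/dtype from,
    # skipping None placeholders.
    for t in tensors:
        if t is not None:
            return t
    return None
```

Normalizing first keeps the shape/dtype extraction loop agnostic to whether the input arrived as one tensor or many.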

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request refactors the KV cache shape and dtype extraction to be more robust by supporting kv_cache as a tuple or list of tensors and skipping None values. The change is well-contained and improves maintainability. I've added one comment with a suggestion to simplify the implementation further and to update the related type hint for better code clarity.

Comment thread on lmcache/v1/kv_layer_groups.py (outdated)
Signed-off-by: Tony Lin <tony.lin@intel.com>
@hlin99
Contributor Author

hlin99 commented Feb 21, 2026

Hi @maobaolong, would you mind taking a look at this one? Thanks

Collaborator

@maobaolong maobaolong left a comment


@hlin99 Thanks for introducing the TPU case and supporting it in the existing kv_layer_group. It would be very helpful if you could add a comment explaining the TPU use case and what the layer tensor shapes look like for TPU. I guess there is a list of same-shape tensors for each layer on TPU. Feel free to correct me if I misunderstood.

@hlin99
Contributor Author

hlin99 commented Mar 3, 2026

@hlin99 Thanks for introducing the TPU case and supporting it in the existing kv_layer_group. It would be very helpful if you could add a comment explaining the TPU use case and what the layer tensor shapes look like for TPU. I guess there is a list of same-shape tensors for each layer on TPU. Feel free to correct me if I misunderstood.

Thanks @maobaolong for the comments. I will update the PR soon.

@hlin99 hlin99 force-pushed the kv_type_refactor branch from 9fc97eb to 21636f5 on March 4, 2026 04:37
@hlin99
Contributor Author

hlin99 commented Mar 4, 2026

Hi @maobaolong, I added comments for the possible kv_cache types on non-CUDA-like devices. Would you take a look again? Thanks for your time.

@maobaolong
Collaborator

@hlin99 This LGTM. @sammshen Would you like to take another look?

Contributor

@sammshen sammshen left a comment


LGTM

@deng451e deng451e enabled auto-merge (squash) March 18, 2026 04:42
@github-actions github-actions Bot added the "full" label (Run comprehensive tests on this PR) on Mar 18, 2026
@deng451e deng451e merged commit 82b0d8b into LMCache:dev Mar 19, 2026
26 of 28 checks passed
hyunyul-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Mar 20, 2026
* Refactor KV cache shape/dtype extraction for robustness

- Support `kv_cache` as a tuple or list of tensors (e.g., HPU K/V caches).
- Skip None tensors when extracting shape and dtype.
- Preserve original behavior for single tensor inputs.
- Improves robustness and maintainability without changing existing logic.

Signed-off-by: Tony Lin <tony.lin@intel.com>

* changes to address gemini's comments

Signed-off-by: Tony Lin <tony.lin@intel.com>

* add comments for possible kv_cache types on non cuda alike devices

Signed-off-by: Tony Lin <tony.lin@intel.com>

* streamline the logic and clear comments

Signed-off-by: Tony Lin <tony.lin@intel.com>

---------

Signed-off-by: Tony Lin <tony.lin@intel.com>
realAaronWu pushed a commit to realAaronWu/LMCache that referenced this pull request Mar 20, 2026
deng451e pushed a commit to deng451e/LMCache that referenced this pull request Mar 21, 2026
@hlin99 hlin99 deleted the kv_type_refactor branch March 24, 2026 02:58
deng451e pushed a commit to deng451e/LMCache that referenced this pull request Mar 25, 2026
deng451e pushed a commit to deng451e/LMCache that referenced this pull request Mar 27, 2026
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026

Labels

full Run comprehensive tests on this PR

4 participants