
[HW: XPU] Enable Layerwise XPU Connector#2611

Merged
DongDongJu merged 11 commits into LMCache:dev from slokesha:slokesha/xpu_layerwise
Mar 25, 2026
Conversation

@slokesha
Contributor

@slokesha slokesha commented Feb 17, 2026

What this PR does / why we need it:
This PR adds full Layerwise KV cache support for XPU connectors, bringing feature parity with the CUDA connector implementation.

It enables LMCache’s use_layerwise=True workflow for XPU devices, supporting:

  • Layerwise KV retrieve (batched_from_gpu)
  • Layerwise KV restore (_batched_to_gpu_gen)
  • Optional GPU staging (use_gpu=True)
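The generator-based retrieve flow described above can be sketched as follows. This is a minimal CPU-only illustration, assuming PyTorch is available: the function name, tensor shapes, and staging layout here are hypothetical stand-ins for the connector's real layerwise path, with `index_copy_` playing the per-layer scatter role the PR describes.

```python
# Illustrative sketch only: CPU tensors stand in for XPU ones, and the
# names below (batched_to_gpu_gen, staging, slot_mapping) are assumptions
# mirroring the PR description, not the connector's actual API.
import torch

def batched_to_gpu_gen(paged_kv, staging, slot_mapping):
    """Yield once per layer; each step scatters one layer's staged KV
    rows into the paged cache via index_copy_. In the real connector an
    async copy for the next layer would be enqueued on a side stream
    before each yield."""
    for layer, cache in enumerate(paged_kv):
        cache.index_copy_(0, slot_mapping, staging[layer])
        yield layer

num_layers, num_slots, hidden = 2, 8, 4
paged = [torch.zeros(num_slots, hidden) for _ in range(num_layers)]
staged = [torch.full((3, hidden), float(l + 1)) for l in range(num_layers)]
slots = torch.tensor([1, 4, 6])

for _ in batched_to_gpu_gen(paged, staged, slots):
    pass  # the caller drives the generator layer by layer
```

The generator shape lets the caller overlap per-layer transfers with model execution, which is the point of the layerwise workflow.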

Special notes for your reviewers:

If applicable:

  • this PR contains user facing changes - docs added
  • this PR contains unit tests

@gemini-code-assist
Contributor

Summary of Changes

Hello @slokesha, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request integrates a new layerwise XPU connector into the LMCache system, significantly enhancing memory management capabilities for XPU devices. This change enables a more granular and optimized approach to handling KV caches, facilitating more efficient data transfer and resource utilization, which is crucial for improving the performance of large language models on Intel GPUs.

Highlights

  • New Layerwise XPU Connector: Introduced a new VLLMPagedMemLayerwiseXPUConnector class designed for efficient layerwise memory management on XPU devices.
  • Conditional Connector Selection: Implemented logic to conditionally instantiate the new layerwise XPU connector based on a config.use_layerwise flag, allowing dynamic selection of the appropriate connector.
  • Generator-based Data Transfer: Developed generator-based batched_to_gpu and batched_from_gpu methods within the layerwise connector, leveraging PyTorch XPU streams and index_copy_/index_select for optimized asynchronous data movement.
  • Comprehensive Unit Tests: Added new unit tests to thoroughly validate the functionality of both the existing non-layerwise and the newly implemented layerwise XPU connectors, ensuring correct data roundtrip operations.
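The store direction highlighted above (gathering with `index_select`) follows the same generator pattern. Again a CPU-only, illustrative sketch with assumed names and shapes, not the connector's actual code:

```python
# Illustrative sketch: gather one layer's rows out of the paged cache.
# CPU tensors stand in for XPU ones; names and shapes are assumptions.
import torch

def batched_from_gpu_gen(paged_kv, slot_mapping):
    for layer_cache in paged_kv:
        # index_select returns a new contiguous tensor holding only the
        # requested slots; the real connector would then copy it into a
        # (pinned) CPU staging buffer asynchronously.
        yield layer_cache.index_select(0, slot_mapping)

paged = [torch.arange(32.0).reshape(8, 4) + 100 * l for l in range(2)]
slots = torch.tensor([0, 3])
gathered = list(batched_from_gpu_gen(paged, slots))
```

Because `index_select` materializes a dense tensor per layer, each yielded chunk is ready for a single contiguous device-to-host copy.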


Changelog
  • lmcache/v1/gpu_connector/__init__.py
    • Imported VLLMPagedMemLayerwiseXPUConnector.
    • Added conditional logic to instantiate VLLMPagedMemLayerwiseXPUConnector if config.use_layerwise is true.
  • lmcache/v1/gpu_connector/xpu_connectors.py
    • Added new type hints for Generator, Iterable, Sequence, Tuple, and Union.
    • Imported GPUConnectorInterface.
    • Implemented _split_token2d_kv helper function to parse 2D token tensors into K and V components.
    • Implemented _get_paged_kv_views helper function to provide flattened views of paged KV caches for indexing operations.
    • Introduced the VLLMPagedMemLayerwiseXPUConnector class, inheriting from GPUConnectorInterface.
    • Defined __init__ and from_metadata methods for the new layerwise connector.
    • Implemented _lazy_initialize_buffer for optional GPU buffer allocation.
    • Overrode to_gpu and from_gpu to raise NotImplementedError, directing usage to batched methods.
    • Implemented _batched_to_gpu_gen for CPU to XPU data transfer using a generator pattern, including staging buffers and layerwise scattering.
    • Implemented batched_from_gpu for XPU to CPU data transfer using a generator pattern, including staging buffers and layerwise gathering.
    • Implemented get_shape to determine the buffer shape based on MLA usage.
  • tests/v1/test_xpu_connector.py
    • Added a new test file dedicated to XPU connector functionality.
    • Included _skip_if_no_xpu utility to skip tests if XPU is unavailable.
    • Implemented _make_unique_slot_mapping helper for generating unique slot mappings in tests.
    • Added test_xpu_connector_roundtrip_non_layerwise to verify the existing VLLMPagedMemXPUConnectorV2.
    • Added test_xpu_connector_roundtrip_layerwise to verify the new VLLMPagedMemLayerwiseXPUConnector using generator-based data transfer.
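The changelog above mentions a `_make_unique_slot_mapping` test helper. In vLLM's paged layout, a flat slot index is `block_id * block_size + offset`, so drawing distinct flat indices guarantees no two tokens collide on a (block, offset) pair. A hypothetical pure-Python sketch (the helper name, signature, and block size here are assumptions, not the test file's actual code):

```python
import random

BLOCK_SIZE = 16  # assumed block size; vLLM configures this per deployment

def make_unique_slot_mapping(num_tokens, num_blocks,
                             block_size=BLOCK_SIZE, seed=0):
    """Pick num_tokens distinct flat slot indices. A flat slot is
    block_id * block_size + in-block offset, so uniqueness over the flat
    range [0, num_blocks * block_size) guarantees no (block, offset)
    collisions between tokens."""
    rng = random.Random(seed)
    return rng.sample(range(num_blocks * block_size), num_tokens)

mapping = make_unique_slot_mapping(num_tokens=32, num_blocks=4)
```

Unique slots matter in roundtrip tests: a duplicate slot would let one token's KV silently overwrite another's and still pass a size check.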

@slokesha slokesha force-pushed the slokesha/xpu_layerwise branch 2 times, most recently from d9d23e4 to 4dac694 Compare February 17, 2026 06:39
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a VLLMPagedMemLayerwiseXPUConnector to support layerwise KV cache transfers on Intel XPUs using pure PyTorch operations, along with corresponding tests. The implementation is a great step towards broader hardware support.

My review has identified a few areas for improvement:

  • In batched_from_gpu, the device-side staging buffer for the use_gpu=True path is allocated but never used, which is a bug and a missed performance optimization.
  • The same method also contains an inefficient data transfer pattern for slot_mapping inside a loop.
  • The new tests for the layerwise XPU connector only cover the use_gpu=False case, and coverage for the use_gpu=True path should be added.

Addressing these points will improve the correctness, performance, and robustness of the new connector.
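The second finding (the `slot_mapping` transfer inside a loop) comes down to hoisting a layer-invariant copy out of the per-layer loop: the mapping is identical for every layer, so one host-to-device transfer suffices. A toy model of that refactor, with a counting stub standing in for the actual copy:

```python
# Toy illustration: transfer() just counts calls; in the connector it
# would be a host-to-XPU copy of slot_mapping. All names are illustrative.
transfer_count = 0

def transfer(x):
    global transfer_count
    transfer_count += 1
    return list(x)  # stands in for tensor.to("xpu")

def scatter_naive(layers, slot_mapping):
    for _ in layers:
        dev_slots = transfer(slot_mapping)  # re-copied every layer (bad)

def scatter_hoisted(layers, slot_mapping):
    dev_slots = transfer(slot_mapping)      # copied once, reused (good)
    for _ in layers:
        pass

layers = range(32)
scatter_naive(layers, [1, 2, 3])
naive = transfer_count
transfer_count = 0
scatter_hoisted(layers, [1, 2, 3])
hoisted = transfer_count
```

Hoisting turns O(num_layers) small synchronizing copies into a single one, which matters most for deep models with short per-layer transfer times.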

Comment thread lmcache/v1/gpu_connector/xpu_connectors.py Outdated
Comment thread lmcache/v1/gpu_connector/xpu_connectors.py Outdated
Comment thread tests/v1/test_xpu_connector.py Outdated
@slokesha slokesha force-pushed the slokesha/xpu_layerwise branch from 4dac694 to 7b669e9 Compare February 18, 2026 17:14
@slokesha slokesha changed the title Enable Layerwise XPU Connector [HW: XPU] Enable Layerwise XPU Connector Feb 18, 2026
@slokesha slokesha force-pushed the slokesha/xpu_layerwise branch 2 times, most recently from 88cbb4b to ecc1ce2 Compare February 20, 2026 08:32
@slokesha slokesha marked this pull request as ready for review February 20, 2026 23:34
@slokesha slokesha force-pushed the slokesha/xpu_layerwise branch 4 times, most recently from 239ff2b to ab30d10 Compare February 26, 2026 21:32
Comment thread lmcache/v1/gpu_connector/xpu_connectors.py Outdated
Comment thread lmcache/v1/gpu_connector/xpu_connectors.py Outdated
@slokesha slokesha force-pushed the slokesha/xpu_layerwise branch 2 times, most recently from 08f2b43 to 30546f3 Compare February 27, 2026 00:54
@slokesha slokesha requested a review from sammshen February 27, 2026 00:54
@DongDongJu DongDongJu self-requested a review February 27, 2026 15:33
@slokesha slokesha force-pushed the slokesha/xpu_layerwise branch from 30546f3 to bd8e439 Compare February 27, 2026 17:37
Collaborator

@DongDongJu DongDongJu left a comment


Hello @slokesha , Thanks for the great work!
I left a few questions.
Please add the following tests:

  • Multi-chunk starts/ends (e.g., two or three segments) for both layerwise and non-layerwise
  • A use_gpu=True test, or force use_gpu to False when XPU is enabled.
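A multi-chunk test of the kind requested above would split the token range into (start, end) segments and assert that the segments tile the range exactly, with no gap or overlap, before checking the per-chunk roundtrip. A minimal sketch of that bookkeeping (chunk size and helper name are illustrative):

```python
def chunk_segments(num_tokens, chunk_size):
    """Split [0, num_tokens) into (start, end) pairs of at most
    chunk_size tokens each; the last segment may be short."""
    return [(s, min(s + chunk_size, num_tokens))
            for s in range(0, num_tokens, chunk_size)]

segs = chunk_segments(num_tokens=40, chunk_size=16)
covered = [t for s, e in segs for t in range(s, e)]
```

Exercising a ragged final segment (40 tokens with chunk size 16 here) is what distinguishes a multi-chunk test from the single-chunk roundtrip.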

Comment thread lmcache/v1/gpu_connector/xpu_connectors.py Outdated
Comment thread lmcache/v1/gpu_connector/xpu_connectors.py Outdated
Comment thread lmcache/v1/gpu_connector/xpu_connectors.py Outdated
@slokesha slokesha force-pushed the slokesha/xpu_layerwise branch 2 times, most recently from bf6e158 to 5a8f875 Compare March 9, 2026 03:13
@slokesha
Contributor Author

slokesha commented Mar 9, 2026

@DongDongJu @sammshen ,
Thank you for the comments. I’ve made the requested changes and added additional test coverage.

The new tests cover use_gpu=true and multi-chunk scenarios in test_xpu_connector.py.

I’m not completely sure whether the benchmark test is necessary for this PR, but I’ve kept it for now. Happy to remove it if you think it doesn’t belong in the test suite.

@slokesha slokesha requested a review from DongDongJu March 9, 2026 03:19
Collaborator

@DongDongJu DongDongJu left a comment


Hello @slokesha,
I left a few comments.

Comment thread lmcache/v1/gpu_connector/__init__.py
Comment thread lmcache/v1/gpu_connector/utils.py
Comment thread lmcache/v1/gpu_connector/xpu_connectors.py Outdated
@slokesha slokesha force-pushed the slokesha/xpu_layerwise branch from a8f715b to 3edd105 Compare March 11, 2026 18:45
@slokesha slokesha requested a review from DongDongJu March 11, 2026 18:45
@slokesha slokesha force-pushed the slokesha/xpu_layerwise branch 2 times, most recently from eabd5dc to bf3d18d Compare March 12, 2026 16:34
@slokesha slokesha force-pushed the slokesha/xpu_layerwise branch from f0336b7 to 08be1d2 Compare March 17, 2026 23:50
Comment thread lmcache/v1/memory_management.py
Contributor

@sammshen sammshen left a comment


LGTM

Collaborator

@DongDongJu DongDongJu left a comment


Please make the following changes for the use_xpu cleanup!
LGTM

@DongDongJu DongDongJu enabled auto-merge (squash) March 20, 2026 15:03
@github-actions github-actions Bot added the full Run comprehensive tests on this PR label Mar 20, 2026
@sammshen
Contributor

Also, please run pre-commit.

Copilot AI added a commit to ftian1/LMCache that referenced this pull request Mar 23, 2026
…MPagedMemLayerwiseXPUConnector

Add VLLMPagedMemLayerwiseXPUConnector class (from PR LMCache#2611)
and helper functions (_split_token2d_kv, _get_head_size_view) to enable
apples-to-apples performance comparison between the two layerwise
connector implementations.

Co-authored-by: ftian1 <16394660+ftian1@users.noreply.github.com>
Agent-Logs-Url: https://github.com/ftian1/LMCache/sessions/9d73c790-ce1d-4dad-a3c8-f9746d8f1fd6
…se_gpu to use_xpu in layerwise XPU connector

Signed-off-by: slokesha <slokeshappa@habana.ai>
auto-merge was automatically disabled March 24, 2026 01:23

Head branch was pushed to by a user without write access

@github-actions github-actions Bot removed the full Run comprehensive tests on this PR label Mar 24, 2026
slokesha and others added 4 commits March 23, 2026 21:24
Signed-off-by: Spurthi Lokeshappa <spurthi.lokeshappa@intel.com>
Signed-off-by: slokesha <slokeshappa@habana.ai>
@slokesha
Contributor Author

@sammshen and @DongDongJu, can we merge this PR?

@DongDongJu DongDongJu enabled auto-merge (squash) March 24, 2026 21:25
@github-actions github-actions Bot added the full Run comprehensive tests on this PR label Mar 24, 2026
@sammshen
Contributor

rerunning CI

@DongDongJu DongDongJu merged commit b7d025a into LMCache:dev Mar 25, 2026
26 checks passed
realAaronWu pushed a commit to realAaronWu/LMCache that referenced this pull request Mar 26, 2026
* Added Layerwise XPU Connector

Signed-off-by: slokesha <slokeshappa@habana.ai>

* Addressed PR Comments

Signed-off-by: slokesha <slokeshappa@habana.ai>

* fix k_tok error

Signed-off-by: slokesha <spurthi.lokeshappa@intel.com>
Signed-off-by: slokesha <slokeshappa@habana.ai>

* Added _get_head_size_view() to take GPUKVFormat enum

Signed-off-by: slokesha <slokeshappa@habana.ai>

* Addressed PR comments

Signed-off-by: slokesha <slokeshappa@habana.ai>

* added multi_chunk_test

Signed-off-by: slokesha <slokeshappa@habana.ai>

* xpu: fix CPU staging pin memory, disk retrieve deadlock, and rename use_gpu to use_xpu in layerwise XPU connector

Signed-off-by: slokesha <slokeshappa@habana.ai>

* Fixed Pre-commit

Signed-off-by: slokesha <slokeshappa@habana.ai>

---------

Signed-off-by: slokesha <slokeshappa@habana.ai>
Signed-off-by: slokesha <spurthi.lokeshappa@intel.com>
Signed-off-by: Spurthi Lokeshappa <spurthi.lokeshappa@intel.com>
deng451e pushed a commit to deng451e/LMCache that referenced this pull request Mar 27, 2026
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026
ianliuy added a commit to ianliuy/LMCache that referenced this pull request Apr 12, 2026
In layerwise retrieval with LocalCPU backend, the unpin loop at
cache_engine.py:1040-1042 was designed for disk-loaded staging objects
(added in LMCache#2611). However, LocalCPUBackend.batched_get_non_blocking()
returns the same Python object from hot_cache that lookup(pin=True) had
already pinned, causing retrieve_layer() to unpin it once, and then
wait_for_save() to unpin it again via lookup_unpin() (LMCache#2786).

This double unpin drives pin_count to -1 and, more critically, triggers
a premature free() of the memory object (unpin() calls free() when both
pin_count <= 0 and ref_count <= 0).

Fix: guard the unpin with location != 'LocalCPUBackend' so that only
disk-loaded staging objects (LocalDisk, Maru, etc.) are unpinned here.
LocalCPU objects retain their pin until lookup_unpin() releases them in
wait_for_save(), preserving the correct single-free lifecycle.

Fixes LMCache#2954

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
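The lifecycle bug described in the commit message above can be reproduced with a minimal pin-counted object. The free-when-both-counts-nonpositive rule follows that description; the class itself is an illustrative model, not LMCache's actual memory object:

```python
class MemObj:
    """Minimal model of the pin/ref lifecycle from the commit message:
    unpin() frees the object once both pin_count and ref_count are <= 0.
    Illustrative only; field and method names are assumptions."""
    def __init__(self):
        self.pin_count = 0
        self.ref_count = 0
        self.freed = False

    def pin(self):
        self.pin_count += 1

    def unpin(self):
        self.pin_count -= 1
        if self.pin_count <= 0 and self.ref_count <= 0:
            self.freed = True

# Buggy path: lookup(pin=True) pins once, but both retrieve_layer() and
# wait_for_save() unpin the same hot_cache object.
buggy = MemObj()
buggy.pin()
buggy.unpin()  # retrieve_layer()'s staging-object unpin (premature free)
buggy.unpin()  # wait_for_save()'s lookup_unpin() drives pin_count to -1

# Fixed path: guard the staging unpin by backend location, so only
# disk-loaded staging objects are released early.
fixed = MemObj()
fixed.pin()
location = "LocalCPUBackend"
if location != "LocalCPUBackend":
    fixed.unpin()  # skipped for LocalCPU objects
fixed.unpin()      # single release in wait_for_save()
```

With the guard, each pin is balanced by exactly one unpin, restoring the single-free lifecycle the fix aims for.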
ianliuy added a commit to ianliuy/LMCache that referenced this pull request Apr 13, 2026

Labels

full Run comprehensive tests on this PR

3 participants