Skip to content

[1/N] Support NIXL-based L2 storage in MP mode#2664

Merged
YaoJiayi merged 21 commits intodevfrom
localdev/nixl-adapter
Mar 6, 2026
Merged

[1/N] Support NIXL-based L2 storage in MP mode#2664
YaoJiayi merged 21 commits intodevfrom
localdev/nixl-adapter

Conversation

@YaoJiayi
Copy link
Copy Markdown
Collaborator

@YaoJiayi YaoJiayi commented Mar 2, 2026

What this PR does / why we need it:
Initial support of Nixl-based L2 adapter.

TODOs:

  • Eviction is not handled
  • E2E testing
  • E2E testing with CB
  • Unit test for different nixl backends (currently only have tests for POSIX backend)
  • Documentation on how to enable l2 storage with NIXL in mp mode
  • KV cache persistence

Special notes for your reviewers:

If applicable:

  • this PR contains user facing changes - docs added
  • this PR contains unit tests

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the LMCache's distributed storage capabilities by introducing support for NIXL as an L2 storage backend. It refines memory management by enabling lazy initialization for paged allocators, ensuring that memory is allocated efficiently and correctly based on actual KV cache requirements. These changes facilitate more robust and performant L2 caching in multi-process setups, particularly when dealing with large-scale language models.

Highlights

  • NIXL-based L2 Storage Adapter: Introduced a new L2 adapter, NixlStoreL2Adapter, to support NIXL (NVIDIA In-Memory eXchange Library) for efficient L2 cache storage in multi-process environments. This adapter handles memory registration, data transfers (store/load), and object management using NIXL's capabilities.
  • Lazy Memory Initialization for Paged Allocators: Implemented lazy initialization for paged memory allocators (PagedCpuMemoryAllocator, PagedCpuGpuMemoryAllocator). This allows the CPU memory for these allocators to be allocated only after KV cache shapes and dtypes are known, which is crucial for NIXL's memory registration requirements.
  • New Configuration Options: Added a use_page flag to L1MemoryManagerConfig to enable page-aligned memory allocation. A new NixlStoreL2AdapterConfig was also introduced to configure the NIXL backend, parameters, and pool size.
  • Integration into L1 and Storage Managers: Integrated the lazy_init_memory functionality into L1MemoryManager and StorageManager to ensure that paged allocators are correctly initialized with KV cache metadata during the register_kv_cache process in the multi-process server.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • lmcache/v1/distributed/config.py
    • Added use_page field to L1MemoryManagerConfig to control page-aligned memory allocation.
  • lmcache/v1/distributed/l1_manager.py
    • Imported torch for type hinting.
    • Added lazy_init_memory method to L1MemoryManager to support deferred memory allocation for paged allocators.
  • lmcache/v1/distributed/l2_adapters/base.py
    • Introduced lazy_init_memory as an interface method in L2AdapterInterface for lazy memory initialization.
  • lmcache/v1/distributed/l2_adapters/config.py
    • Defined NixlStoreL2AdapterConfig for configuring NIXL L2 storage, including backend type, parameters, and pool size.
    • Registered NixlStoreL2AdapterConfig with the L2 adapter type system.
  • lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py
    • Added a new file implementing NixlStoreL2Adapter and NixlStorageAgent to provide NIXL-based L2 caching functionality.
  • lmcache/v1/distributed/memory_manager.py
    • Imported MemoryFormat, PagedCpuGpuMemoryAllocator, and PagedCpuMemoryAllocator.
    • Modified create_memory_allocator to instantiate PagedCpuMemoryAllocator when use_page is enabled.
    • Implemented lazy_init_memory in L1MemoryManager to initialize paged allocators with KV cache shapes and dtypes.
    • Updated get_vm_space to handle PagedCpuMemoryAllocator.
    • Adjusted get_memory_usage to correctly report usage for paged allocators and handle uninitialized memory.
  • lmcache/v1/distributed/storage_manager.py
    • Imported torch for type hinting.
    • Added lazy_init_memory method to StorageManager to propagate lazy initialization calls to the L1 manager.
  • lmcache/v1/memory_management.py
    • Introduced PagedAddressManager to provide memory usage statistics for paged allocators.
    • Added address_manager attribute to PagedTensorMemoryAllocator.
    • Implemented PagedCpuMemoryAllocator for paged memory allocation on CPU, supporting lazy initialization.
  • lmcache/v1/multiprocess/server.py
    • Added a call to storage_manager.lazy_init_memory during KV cache registration to ensure proper initialization of paged allocators.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for NIXL-based L2 storage in multi-process mode, including a new NIXL L2 adapter, configuration options, and updates to the memory management system for paged memory allocation. However, several security issues were identified, including potential Denial of Service vulnerabilities due to unvalidated input in the ZMQ server and mismatched configuration keys, as well as insecure file creation in the NIXL storage adapter that could lead to data corruption or symlink attacks. Additionally, the review noted areas for improvement such as silently ignored exceptions, opportunities to use zip(..., strict=True) for robustness, simplification of a configuration class with a dataclass, and a misleading docstring in the memory management code. Addressing these points will significantly improve the robustness, security, and maintainability of the new code.

Comment thread lmcache/v1/multiprocess/server.py Outdated
Comment thread lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py
Comment thread lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py Outdated
Comment thread lmcache/v1/distributed/l2_adapters/config.py Outdated
Comment thread lmcache/v1/distributed/l2_adapters/config.py
Comment thread lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py Outdated
Comment thread lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py Outdated
Comment thread lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py
Comment thread lmcache/v1/memory_management.py Outdated
YaoJiayi and others added 5 commits March 1, 2026 19:33
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>
Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
@YaoJiayi YaoJiayi requested review from ApostaC and KuntaiDu March 2, 2026 01:42
YaoJiayi and others added 2 commits March 1, 2026 19:42
Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
@YaoJiayi YaoJiayi added the full Run comprehensive tests on this PR label Mar 2, 2026
Copy link
Copy Markdown
Contributor

@ApostaC ApostaC left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1

Comment thread lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py Outdated
Comment thread lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py Outdated
Comment thread lmcache/v1/distributed/l2_adapters/base.py Outdated
Comment thread lmcache/v1/distributed/memory_manager.py Outdated
Comment thread lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py Outdated
Comment on lines +209 to +213
for i in range(num_pages):
filename = f"obj_{i}_{uuid.uuid4().hex[0:4]}.bin"
tmp_path = os.path.join(file_path, filename)
fd = os.open(tmp_path, flags)
fds.append(fd)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The IBM storage team complained about this this morning. Let's see how we can improve this.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Persistence will be handled in followup prs

Comment on lines +587 to +588
for i, key in enumerate(keys):
if (obj := self._memory_objects.get(key)) is None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean we will never have a "real" lookup in the storage, but just rely on in-proc states?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently we are doing in-process lookup

YaoJiayi and others added 3 commits March 3, 2026 05:39
Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

@ApostaC ApostaC left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some minor comments. Otherwise LGTM.

self.nixl_agent.deregister_memory(self.mem_reg_descs)


class NixlStoreL2Adapter(L2AdapterInterface):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add the intialization code into l2_adapters/__init__.py?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh right, the code for init is already merged. Updated

raise RuntimeError("NIXL transfer failed")

def get_storage_indices(self, num_objs: int) -> list[int]:
return self.pool.batched_allocate(num_objs)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do allocate here, but where do we free?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will be handled in the eviction pr. Feels we might need to reuse the eviction code across different adapters and also l1 memory.

Comment on lines +559 to +566
except Exception:
success = False

with self._lock:
for key, storage_obj in zip(keys, storage_objs, strict=False):
self._memory_objects[key] = storage_obj

self._completed_store_tasks[task_id] = success
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the success = False, do we want to update the self._memory_obj? Will this cause an incorrect lookup/load?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah the logic here is a bit confused. Transfer errors and allocation failures are handled in the same way. Updated the logic here.

Comment thread lmcache/v1/distributed/l2_adapters/config.py Outdated
Comment thread lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py
Comment thread lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py Outdated
Comment thread lmcache/v1/distributed/api.py Outdated
Comment thread lmcache/v1/memory_management.py
YaoJiayi and others added 7 commits March 4, 2026 07:16
Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
@YaoJiayi YaoJiayi requested a review from ApostaC March 5, 2026 23:48
Copy link
Copy Markdown
Contributor

@KuntaiDu KuntaiDu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Copy Markdown
Contributor

@ApostaC ApostaC left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Comment thread lmcache/v1/distributed/memory_manager.py Outdated
Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
@YaoJiayi YaoJiayi enabled auto-merge (squash) March 6, 2026 20:54
@YaoJiayi YaoJiayi merged commit ecd69ea into dev Mar 6, 2026
27 of 29 checks passed
mauryaavinash95 pushed a commit to mauryaavinash95/LMCache that referenced this pull request Mar 7, 2026
* initial nixl support in mp

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* add tests

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* Update lmcache/v1/memory_management.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* Update lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* Update lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* Update lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* fix misc

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* fix wrong mem allocator

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* not using paged allocator

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* fix tests

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* add init code

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* address several comments

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* add design doc

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* fix tests

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* use lazy import

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* remove get vm space

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

---------

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@YaoJiayi YaoJiayi deleted the localdev/nixl-adapter branch March 10, 2026 17:36
shaoxiawjc pushed a commit to shaoxiawjc/LMCache that referenced this pull request Mar 11, 2026
* initial nixl support in mp

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* add tests

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* Update lmcache/v1/memory_management.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* Update lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* Update lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* Update lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* fix misc

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* fix wrong mem allocator

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* not using paged allocator

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* fix tests

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* add init code

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* address several comments

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* add design doc

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* fix tests

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* use lazy import

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* remove get vm space

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

---------

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: shaoxiawjc <wjc2800@163.com>
realAaronWu pushed a commit to realAaronWu/LMCache that referenced this pull request Mar 20, 2026
* initial nixl support in mp

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* add tests

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* Update lmcache/v1/memory_management.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* Update lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* Update lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* Update lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* fix misc

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* fix wrong mem allocator

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* not using paged allocator

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* fix tests

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* add init code

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* address several comments

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* add design doc

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* fix tests

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* use lazy import

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* remove get vm space

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

---------

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Aaron Wu <aaron.wu@dell.com>
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026
* initial nixl support in mp

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* add tests

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* Update lmcache/v1/memory_management.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* Update lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* Update lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* Update lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* fix misc

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* fix wrong mem allocator

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* not using paged allocator

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* fix tests

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* add init code

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* address several comments

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* add design doc

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* fix tests

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* use lazy import

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* remove get vm space

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

---------

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
jooho-XCENA pushed a commit to xcena-dev/LMCache that referenced this pull request Apr 2, 2026
* initial nixl support in mp

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* add tests

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* Update lmcache/v1/memory_management.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* Update lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* Update lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* Update lmcache/v1/distributed/l2_adapters/nixl_store_l2_adapter.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>

* fix misc

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* fix wrong mem allocator

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* not using paged allocator

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* fix tests

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* add init code

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* address several comments

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* add design doc

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* fix tests

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* use lazy import

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

* remove get vm space

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>

---------

Signed-off-by: YaoJiayi <120040070@link.cuhk.edu.cn>
Signed-off-by: Jiayi Yao <82156730+YaoJiayi@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

full Run comprehensive tests on this PR

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants