Skip to content

[P/D][Nixl] Make kv cache register compatible with hybrid memory allocator#23079

Merged
njhill merged 6 commits into
vllm-project:mainfrom
sfeng33:pd
Aug 22, 2025
Merged

[P/D][Nixl] Make kv cache register compatible with hybrid memory allocator#23079
njhill merged 6 commits into
vllm-project:mainfrom
sfeng33:pd

Conversation

@sfeng33

@sfeng33 sfeng33 commented Aug 18, 2025

Copy link
Copy Markdown
Collaborator

Purpose

This PR refactors the register_kv_caches method for nixl_connector, so that it works with or without hybrid memory allocator (HMA).
Partially fix #22292.

Background

With HMA, for models with hybrid attention, there can be less the number of kv cache's physical tensors, e.g., gemma-3-4b-it's tensor count drops from 34 to 5, where different layers can the same kv cache tensor.

Model Attention Type # of Layers (No HMA) # of Tensors (With HMA) # of Tensors
gpt-2 Full Attention only 12 12 12
gemma-3-4b-it Full + Sliding Window Attention 34 34 5

In nixl_connector's method register_kv_caches(), the related two functionalities are:

  1. It registers KV cache memory regions with NIXL for direct memory access.
  2. It creates transfer descriptors for individual KV cache blocks.

This PR

  1. Refactors KV cache memory regions registration by iterating through kv caches and registering unique memory address, e.g.
seen_base_addresses = set()
for layer_name, kv_cache_tensor in kv_caches:
    base_addr = cache.data_ptr()
    if base_addr not in seen_base_addresses:
        seen_base_addresses.add(base_addr)
        # register address with NIXL ...

The full implementation is on L742-L761 in nixl_connector.py

  1. Remove the complex logic for block shape, block len, slot len, etc. These info are available in kv_cache_config, which is passed in from model runner. It already took care of the logic for different attention backends and shape calculation in _reshape_kv_cache_tensors(). We can also get the needed info from kv_cache_config regardless of HMA being on and off.

Test Plan

# Unit Test
python -m pytest tests/v1/kv_connector/unit/test_nixl_connector.py -v

# Integration Test
# Tested with prefill decode ratio (1,1) (2,2) (1,2) (2,4) (4,4)
./tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh

# Manual Test
vllm serve google/gemma-3-4b-it \
  --port 8100 \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' 

@github-actions

Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify Bot added v1 tpu Related to Google TPUs labels Aug 18, 2025

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the KV cache registration in NixlConnector to support a hybrid memory allocator by using KVCacheConfig. This simplifies the logic by removing device- and backend-specific inferences of the cache layout. The changes are well-aligned with the goal, but I've identified two issues. First, the register_kv_caches method in NixlConnector has a signature that is incompatible with its base class, which could cause runtime errors. Second, there's a critical assumption that all KV cache tensors are of the same size, which might not be true with a hybrid allocator and could lead to silent data corruption. I've provided suggestions to address both points.

Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py Outdated
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py Outdated
@sfeng33 sfeng33 changed the title [P/D][Nixl] Make kv cache register compatible with hybrid memory allocator [WIP][P/D][Nixl] Make kv cache register compatible with hybrid memory allocator Aug 18, 2025
@sfeng33 sfeng33 marked this pull request as ready for review August 18, 2025 21:15
@sfeng33 sfeng33 changed the title [WIP][P/D][Nixl] Make kv cache register compatible with hybrid memory allocator [P/D][Nixl] Make kv cache register compatible with hybrid memory allocator Aug 18, 2025
@sfeng33

sfeng33 commented Aug 18, 2025

Copy link
Copy Markdown
Collaborator Author

cc @robertgshaw2-redhat @NickLucche @njhill PTAL

@heheda12345

Copy link
Copy Markdown
Collaborator

Can you join #feat-hybrid-allocator-kv-connector in slack to collaborate on kv connector + hybrid allocator?

@njhill

njhill commented Aug 18, 2025

Copy link
Copy Markdown
Member

cc @KuntaiDu

@NickLucche NickLucche left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I like this initial refactoring, thanks for the work!
We still need to figure out the best logic+interface for sliding window attention layers. which is probably going to be the main thing about enabling hma here, but the block_len and kv cache sharing logic look good.

Can we add tests for the kv sharing case to the nixl suite?
Also, are all other nixl tests running fine with the changes?

Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py Outdated
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py Outdated
Comment thread vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py Outdated
@mergify mergify Bot removed the tpu Related to Google TPUs label Aug 20, 2025
@sfeng33

sfeng33 commented Aug 20, 2025

Copy link
Copy Markdown
Collaborator Author

I think I like this initial refactoring, thanks for the work! We still need to figure out the best logic+interface for sliding window attention layers. which is probably going to be the main thing about enabling hma here, but the block_len and kv cache sharing logic look good.

Can we add tests for the kv sharing case to the nixl suite? Also, are all other nixl tests running fine with the changes?

Hey @NickLucche, thanks for the review! Before hma can be enabled in nixl connector, there is also work to update the start_load_kv. What is mainly missing is this part from the design doc:

Connector: layout of KV [layer, block_ids, …]
Hybrid allocator: layout of KV [# of groups, block_ids, …]
We need a mapping between [# of groups, block_ids, …] →  [layer, block_ids, …]

I added a unit test. For the integration test, is it mainly run_accuracy_test and run_edge_case_test?

@NickLucche NickLucche left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for adding the test, looks good now!

For the integration test, is it mainly run_accuracy_test and run_edge_case_test

Yep we just have to make sure run_accuracy_test passes. Could you also test out if this PR works fine with heteroTP? Just set PREFILLER_TP_SIZE and DECODER_TP_SIZE.
I have very limited access these days sorry :(

I also left a comment about a small test refactoring in light of upcoming changes, hope that's ok.

Other than that this is LGTM.

Comment thread tests/v1/kv_connector/unit/test_nixl_connector.py Outdated
@sfeng33

sfeng33 commented Aug 20, 2025

Copy link
Copy Markdown
Collaborator Author

Thanks a lot for adding the test, looks good now!

For the integration test, is it mainly run_accuracy_test and run_edge_case_test

Yep we just have to make sure run_accuracy_test passes. Could you also test out if this PR works fine with heteroTP? Just set PREFILLER_TP_SIZE and DECODER_TP_SIZE. I have very limited access these days sorry :(

I also left a comment about a small test refactoring in light of upcoming changes, hope that's ok.

Other than that this is LGTM.

Thanks @NickLucche! The unit test is updated now. I've run run_accuracy_test and validate it passes on prefill decode ratio (1,1) (2,2) (1,2) (2,4) (4,4). Please let me know if there is anything missing.

@NickLucche NickLucche left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the patience @sfeng33 !

@sfeng33

sfeng33 commented Aug 21, 2025

Copy link
Copy Markdown
Collaborator Author

Thanks for the review! Since this PR requires approval from someone with write access, tagging @robertgshaw2-redhat and @njhill for a final look when you get a chance 🙏

@njhill njhill left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@njhill njhill added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 21, 2025
Signed-off-by: sfeng33 <4florafeng@gmail.com>
Signed-off-by: sfeng33 <4florafeng@gmail.com>
Signed-off-by: sfeng33 <4florafeng@gmail.com>
Signed-off-by: sfeng33 <4florafeng@gmail.com>
Signed-off-by: sfeng33 <4florafeng@gmail.com>
Signed-off-by: sfeng33 <4florafeng@gmail.com>
@njhill njhill merged commit 5341565 into vllm-project:main Aug 22, 2025
42 checks passed
Xu-Wenqing pushed a commit to Xu-Wenqing/vllm that referenced this pull request Aug 23, 2025
…cator (vllm-project#23079)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
Signed-off-by: root <xwq391974@alibaba-inc.com>
@sfeng33 sfeng33 deleted the pd branch August 24, 2025 20:44
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
…cator (vllm-project#23079)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
mengxingkongzhouhan pushed a commit to mengxingkongzhouhan/vllm that referenced this pull request Aug 30, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Sep 3, 2025
zhenwei-intel pushed a commit to zhenwei-intel/vllm that referenced this pull request Sep 10, 2025
ABC12345anouys pushed a commit to ABC12345anouys/vllm that referenced this pull request Sep 25, 2025
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
0826joyce pushed a commit to 0826joyce/vllm-serving-optimization that referenced this pull request May 19, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature]: Make KVConnector Compatible with HMA

4 participants