Skip to content

[Hybrid]: Decouple Kernel Block Size from KV Page Size#24486

Merged
vllm-bot merged 63 commits intovllm-project:mainfrom
zhiyuan1i:hybrid-cache-groups
Oct 9, 2025
Merged

[Hybrid]: Decouple Kernel Block Size from KV Page Size#24486
vllm-bot merged 63 commits intovllm-project:mainfrom
zhiyuan1i:hybrid-cache-groups

Conversation

@zhiyuan1i
Copy link
Copy Markdown
Contributor

@zhiyuan1i zhiyuan1i commented Sep 9, 2025

Purpose

This PR introduces a hybrid cache architecture that separates logical kernel block size from
physical page size, enabling more flexible memory management. Key changes include:

  • Added kernel_block_size field to CacheConfig for logical block sizing
  • Enhanced platform-specific configurations for CUDA and ROCm to support hybrid blocks
  • Implemented block table conversion logic between physical and logical representations
  • Added support for different physical/logical block size ratios in V1 worker components

This hybrid model decoupling enables independent development of high-performance operators
without being constrained by linear attention mechanisms like Mamba, addressing performance
bottlenecks discussed in issues #24280 and
#23161.

Test Plan

Added comprehensive tests in tests/v1/worker/test_gpu_model_runner.py to verify:

  • Block table conversion between physical and logical representations
  • Proper handling of different block size ratios
  • Integration with existing GPU model runner functionality
  • Platform-specific configurations for CUDA and ROCm

Test Result

pytest tests/v1/worker/test_gpu_model_runner.py - 20 passes

tests/v1/worker/test_gpu_model_runner.py ....................                                                                                                                        [100%]

===================================================================================== warnings summary =====================================================================================
../../../../opt/conda/envs/vllm-upstream/lib/python3.12/site-packages/torch/cuda/__init__.py:63
  /opt/conda/envs/vllm-upstream/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
    import pynvml  # type: ignore[import]

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================== 20 passed, 3 warnings in 89.20s (0:01:29) =========================================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a hybrid cache architecture to decouple logical and physical block sizes, which is a significant enhancement for memory management. The changes span configuration, platform-specific code, and the core block table management. The implementation in block_table.py appears solid. However, I've identified some critical issues in the tests intended to validate this new functionality. The tests are flawed and do not correctly verify the hybrid block logic, which could mask bugs. Additionally, there's a piece of logic in the GPUModelRunner that could be made more robust. My review focuses on fixing these test and implementation issues to ensure the new feature is reliable and well-tested.

@heheda12345
Copy link
Copy Markdown
Collaborator

Also CC @tdoublep

Copy link
Copy Markdown
Collaborator

@heheda12345 heheda12345 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed with @zhiyuan1i offline. Two major concerns:

  1. I prefer to calculate kernel block size for each attention backend in gpu_model_runner
  2. would be great if BlockTable.block_table and BlockTable.physical_block_table can be merged into one tensor.

@zhiyuan1i
Copy link
Copy Markdown
Contributor Author

@heheda12345 Thanks for the prompt feedback! I’ve addressed suggestion2 and merged BlockTable.block_table and BlockTable.physical_block_table into a single tensor as recommended. :)

@zhiyuan1i zhiyuan1i force-pushed the hybrid-cache-groups branch 2 times, most recently from 6d1735e to 0b544bf Compare September 9, 2025 14:43
@tjtanaa
Copy link
Copy Markdown
Collaborator

tjtanaa commented Sep 11, 2025

CC @gshtras @hongxiayang as this also affect ROCm

Signed-off-by: lizhiyuan <uniartisan2017@gmail.com>
Signed-off-by: lizhiyuan <uniartisan2017@gmail.com>
@zhiyuan1i zhiyuan1i force-pushed the hybrid-cache-groups branch from 5e0a1a0 to 3e70aa4 Compare October 8, 2025 08:43
@mergify mergify bot removed the needs-rebase label Oct 8, 2025
Signed-off-by: lizhiyuan <uniartisan2017@gmail.com>
Signed-off-by: lizhiyuan <uniartisan2017@gmail.com>
@mergify
Copy link
Copy Markdown

mergify bot commented Oct 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @zhiyuan1i.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify
Copy link
Copy Markdown

mergify bot commented Oct 9, 2025

Documentation preview: https://vllm--24486.org.readthedocs.build/en/24486/

Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
Copy link
Copy Markdown
Collaborator

@heheda12345 heheda12345 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for this enhancement. Follow-ups:

  1. more clean-ups @heheda12345
  2. verify the get_supported_kernel_block_size of each attention backend.

else:
self.reorder_batch_threshold = reorder_batch_threshold_i

def _find_compatible_block_sizes(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(not a blocker) this function may be simplified.

num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
kv_manager_block_size = kv_cache_spec.block_size
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(not a blocker) should we use the common block size of all attention groups in the same kv cache group here?


@staticmethod
def get_supported_kernel_block_size() -> list[Union[int, MultipleOf]]:
return [MultipleOf(16)]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically FA3 would support MultipleOf(1) while FA2 would support MultipleOf(16); I dont think its worth handling this though

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build deepseek Related to DeepSeek models documentation Improvements or additions to documentation gpt-oss Related to GPT-OSS models kv-connector performance Performance-related issues qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm speculative-decoding tpu Related to Google TPUs v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.