
[HiCache] Add L2 prefetch-buffer-only memory mode#20535

Open
vladnosiv wants to merge 56 commits into sgl-project:main from vladnosiv:l2-buffer-only
Conversation

@vladnosiv
Contributor

vladnosiv commented Mar 13, 2026

Motivation

Currently, each HiCache worker allocates a large exclusive host memory pool (up to hundreds of GB per single-GPU worker).
Popular prefixes get duplicated in every worker's host cache, wasting most of the memory budget. Meanwhile, the shared storage backend (MoonCake) is allocated a relatively small pool and remains underutilized.
With buffer_only mode, host memory shrinks to a small staging buffer (tens of GB per worker), and the freed budget goes entirely to the shared storage pool. Any worker can read data written by any other worker, eliminating duplication.

Modifications

  1. New buffer_only host memory mode: host memory acts as a small transient staging buffer instead of a persistent cache tier. Pages are freed immediately after the async storage write completes. Controlled via --hicache-host-memory-mode buffer_only and --hicache-buffer-pages.
  2. storage_backed flag on TreeNode: tracks whether a node is durably written to external storage, allowing GPU eviction without losing the node from the radix tree. Unified with backuped via the new storage_ready property.
  3. Pending write queue: handles backpressure when the small host buffer is full; nodes wait in a queue and are drained on each scheduler step.
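To make the interaction between these three pieces concrete, here is a minimal, self-contained sketch. It is not the PR's implementation: TreeNode is reduced to the two flags named above, and PendingWriteQueue with its submit/drain methods is an invented simplification of the backpressure mechanism described in item 3.

```python
from collections import deque


class TreeNode:
    """Simplified radix-tree node; only the fields relevant to this PR."""

    def __init__(self):
        self.backuped = False        # present in the host cache tier
        self.storage_backed = False  # durably written to external storage

    @property
    def storage_ready(self) -> bool:
        # A node can be restored if it lives in either the host cache
        # tier or the shared storage backend.
        return self.backuped or self.storage_backed


class PendingWriteQueue:
    """Sketch of the backpressure queue: nodes wait here when the small
    staging buffer is full, and are drained on each scheduler step."""

    def __init__(self, buffer_pages: int):
        self.free_pages = buffer_pages
        self.queue = deque()

    def submit(self, node: TreeNode, pages: int) -> None:
        self.queue.append((node, pages))

    def drain(self) -> list:
        written = []
        while self.queue and self.queue[0][1] <= self.free_pages:
            node, pages = self.queue.popleft()
            self.free_pages -= pages    # stage the pages into the host buffer
            node.storage_backed = True  # async storage write completes
            self.free_pages += pages    # pages are freed immediately after
            written.append(node)
        return written
```

The key property this models: after drain, a node is storage_ready even though it was never retained in a persistent host cache, so the GPU copy remains evictable without losing the node from the tree.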

Accuracy Tests

Qwen3/Qwen3-32B-FP8

python benchmark/gsm8k/bench_sglang.py --num-questions 500 --num-shots 48 --parallel 100

With a hot cache in MoonCake:

L2 as cache mode:

Accuracy: 0.934
Invalid: 0.000
Latency: 68.659 s
Output throughput: 870.155 token/s

L2 as buffer only:

Accuracy: 0.936
Invalid: 0.000
Latency: 67.984 s
Output throughput: 876.645 token/s

Benchmarking and Profiling

Qwen3-32B-FP8 on 8xH200 with TP1.
Common flags:

--mem-fraction-static 0.8 --page-size 64 --chunked-prefill-size 4096 --max-running-requests 64 --hicache-write-policy write_through --hicache-io-backend direct --hicache-mem-layout page_first_direct --hicache-storage-prefetch-policy wait_complete

Configuration for buffer_only:

  • --hicache-host-memory-mode buffer_only --hicache-buffer-pages 512
  • MoonCake: 1024 GB on the same node

Configuration for cache:

  • --hicache-host-memory-mode cache --hicache-ratio 2.0 (107 GB host memory per worker)
  • MoonCake: 224 GB on the same node
  • Total: 1080 GB

Benchmarked with MoonCake traces (toolagent workload, 0.5 slowdown factor).

Metric          buffer_only    cache          delta
Cache hit rate  45.9%          38.7%          +7.2 pp
Storage hits    6.08M tokens   1.53M tokens   4.0x
GPU hits        19.0M tokens   18.9M tokens
E2E P99         29,910 ms      31,468 ms      −5%
TTFT P99        3,526 ms       3,996 ms       −12%
TPOT P99        432 ms         492 ms         −12%
ITL P99         181 ms         206 ms         −12%

Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant enhancement to HiCache's memory management by offering a 'buffer_only' mode for host memory. This mode reconfigures the typically large, exclusive per-worker host memory pools into smaller, transient staging buffers. This strategic shift aims to drastically reduce memory duplication across workers and improve overall memory utilization, allowing more budget to be allocated to shared storage backends. The changes are designed to boost cache hit rates and improve latency, as evidenced by the provided benchmarks.

Highlights

  • New Host Memory Mode: Introduced a buffer_only host memory mode for HiCache, which configures host memory as a small, transient staging buffer instead of a persistent cache tier, optimizing memory usage by reducing duplication across workers.
  • Enhanced Node Tracking: Implemented a storage_ready property on TreeNode to accurately track whether a node's data has been durably written to external storage, unifying previous backuped and storage_backed states.
  • Pending Write Queue: Added a pending write queue mechanism to handle backpressure when the host buffer is full in buffer_only mode, ensuring that nodes awaiting storage writes are processed efficiently during scheduler steps.


Changelog
  • python/sglang/srt/managers/cache_controller.py
    • Adjusted prefetch capacity limit calculation to accommodate the new buffer_only host memory mode.
    • Passed is_nsa_backend flag to the storage configuration.
  • python/sglang/srt/managers/schedule_batch.py
    • Modified cache breakdown computation to correctly account for storage-prefetched pages in both cache and buffer_only modes.
  • python/sglang/srt/managers/scheduler.py
    • Updated prefetch logic to utilize the new storage_ready property for TreeNode.
  • python/sglang/srt/mem_cache/hicache_storage.py
    • Added is_nsa_model to the HiCacheStorageConfig dataclass.
  • python/sglang/srt/mem_cache/hiradix_cache.py
    • Integrated host_memory_mode and buffer_pages into HiRadixCache initialization.
    • Introduced pending_write_queue and pending_write_node_ids for managing write backpressure.
    • Updated backup and write handling logic to differentiate behavior based on host_memory_mode.
    • Added _flush_pending_writes to process queued write requests when host buffer space becomes available.
    • Replaced node.backuped with node.storage_ready in relevant methods.
    • Implemented _insert_helper_storage_device for direct insertion into device cache from storage in buffer_only mode.
  • python/sglang/srt/mem_cache/memory_pool_host.py
    • Introduced host_memory_mode and buffer_pages parameters to control host memory sizing.
    • Removed the strict assertion that host memory must be larger than device memory for buffer_only mode.
    • Added methods for registering tensors, retrieving storage payload suffixes, and handling allocation growth failures.
  • python/sglang/srt/mem_cache/radix_cache.py
    • Added a storage_backed boolean attribute to TreeNode.
    • Introduced a storage_ready property that combines backuped and storage_backed states.
    • Set storage_backed to true for the root node during reset.
  • python/sglang/srt/mem_cache/storage/eic/eic_storage.py
    • Modified register_mem_pool_host to register all relevant tensors from the host memory pool.
    • Generalized zero-copy key and value handling based on the host memory pool's payload count.
  • python/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
    • Generalized zero-copy key and value handling based on the host memory pool's payload count.
  • python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
    • Updated register_mem_pool_host to register all tensors from the host memory pool.
    • Added _get_generic_buffer_meta to support buffer metadata handling for different model types (e.g., NSA).
    • Adjusted batch processing methods to handle variable payload counts per page.
  • python/sglang/srt/mem_cache/storage/nixl/hicache_nixl.py
    • Updated batch_exists and _get_key_list_from_meta to use the host memory pool's storage payload count and suffixes.
    • Modified _batch_get_postprocess to correctly interpret results based on variable payload counts.
    • Changed _batch_set_preprocess to retrieve flat data pages.
  • python/sglang/srt/server_args.py
    • Added --hicache-host-memory-mode and --hicache-buffer-pages command-line arguments.
    • Implemented validation and auto-computation logic for hicache_buffer_pages when buffer_only mode is enabled.
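The last changelog entry mentions validation and auto-computation for hicache_buffer_pages. A hedged sketch of what that could look like follows; the actual logic in server_args.py may differ, and the auto-computed fraction (device_pages // 32) is an invented placeholder, not the PR's formula.

```python
def validate_hicache_args(host_memory_mode, buffer_pages, device_pages):
    """Hypothetical sketch of the --hicache-host-memory-mode /
    --hicache-buffer-pages validation described in the changelog."""
    if host_memory_mode not in ("cache", "buffer_only"):
        raise ValueError(f"unknown --hicache-host-memory-mode: {host_memory_mode}")
    if host_memory_mode == "buffer_only":
        if buffer_pages is None:
            # Auto-compute a small staging buffer as a fraction of the
            # device pool (placeholder heuristic, for illustration only).
            buffer_pages = max(1, device_pages // 32)
    elif buffer_pages is not None:
        # The staging-buffer size only makes sense in buffer_only mode.
        raise ValueError("--hicache-buffer-pages only applies to buffer_only mode")
    return buffer_pages
```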

Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces a buffer_only mode for HiCache's host memory, a significant feature for optimizing memory usage in high-throughput scenarios. The changes are extensive, touching core caching logic, memory management, and storage backends. The implementation of a pending write queue to handle backpressure in buffer_only mode is a solid design choice. Additionally, the refactoring in the storage backend interaction to use generic methods for handling different KV cache layouts is a good step towards better maintainability. I've found one critical issue related to a memory leak which should be addressed.

Comment thread python/sglang/srt/mem_cache/hiradix_cache.py
Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
@stmatengss
Collaborator

/tag-and-rerun-ci

Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Comment thread python/sglang/srt/server_args.py
Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Comment thread python/sglang/srt/server_args.py Outdated
Resolved conflict in hiradix_cache.py: combined upstream's InitLoadBackParams
refactor with branch's buffer_only early return in init_load_back.
Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
@vladnosiv
Contributor Author

/rerun-failed-ci

@stmatengss
Collaborator

/rerun-stage stage-b-test-1-gpu-large

@github-actions
Contributor

github-actions Bot commented Apr 4, 2026

✅ Triggered stage-b-test-1-gpu-large to run independently (skipping dependencies). View workflow run

@stmatengss
Collaborator

@xiezhq-hermann Can we merge this PR if most of the CI tests have passed? The AMD CI tests always fail.

@vladnosiv
Contributor Author

/rerun-stage stage-b-test-1-gpu-large

@vladnosiv
Contributor Author

/rerun-failed-ci

@vladnosiv
Contributor Author

/rerun-failed-ci

@vladnosiv
Contributor Author

/rerun-failed-ci

@stmatengss
Collaborator

/rerun-failed-ci

@vladnosiv
Contributor Author

/rerun-failed-ci

Comment thread python/sglang/srt/mem_cache/radix_cache.py Outdated
Comment thread python/sglang/srt/mem_cache/hiradix_cache.py
@stmatengss
Collaborator

I checked this path locally and I think the reviewer concern is valid.

In buffer_only mode, last_host_node is protected on the host side, but it is still not protected from GPU eviction because lock_ref is not held. So the current implementation appears to leave the prefetch anchor evictable while the prefetch is still in progress.

It would be safer to explicitly pin the anchor for the full prefetch lifetime and release it on completion/abort, plus add a regression test for that lifecycle.

Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
@stmatengss
Collaborator

Add protect and release methods to keep last_host_node consistent:

    def _protect_prefetch_anchor(self, node: TreeNode) -> None:
        # Pin the anchor on the host side for the prefetch lifetime.
        node.protect_host()
        if self.host_memory_mode == "buffer_only":
            # Also hold the GPU lock so the anchor cannot be evicted
            # from device memory while the prefetch is in progress.
            self.inc_lock_ref(node)

    def _release_prefetch_anchor(self, node: TreeNode) -> None:
        if self.host_memory_mode == "buffer_only":
            self.dec_lock_ref(node)
        node.release_host()

I wrote a demo to verify the correctness. REF: https://github.com/kvcache-ai/sglang/tree/copilot/pr20535-hicache-tests
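The suggested lifecycle can be exercised with a regression-style test against stand-in classes. FakeNode and FakeCache below are invented for illustration; a real test would target HiRadixCache and TreeNode, but the invariant checked is the same one described above.

```python
class FakeNode:
    """Stand-in for TreeNode: tracks host protection and the GPU lock."""

    def __init__(self):
        self.lock_ref = 0  # GPU eviction lock count
        self.host_ref = 0  # host-side protection count

    def protect_host(self):
        self.host_ref += 1

    def release_host(self):
        self.host_ref -= 1


class FakeCache:
    """Stand-in for HiRadixCache, mirroring the suggested helpers."""

    def __init__(self, host_memory_mode: str):
        self.host_memory_mode = host_memory_mode

    def inc_lock_ref(self, node):
        node.lock_ref += 1

    def dec_lock_ref(self, node):
        node.lock_ref -= 1

    def _protect_prefetch_anchor(self, node):
        node.protect_host()
        if self.host_memory_mode == "buffer_only":
            self.inc_lock_ref(node)  # pin against GPU eviction

    def _release_prefetch_anchor(self, node):
        if self.host_memory_mode == "buffer_only":
            self.dec_lock_ref(node)
        node.release_host()
```

In buffer_only mode the anchor holds both the host protection and the GPU lock for the full prefetch; in cache mode only the host protection is taken, matching the pre-existing behavior.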

Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
# Conflicts:
#	python/sglang/srt/mem_cache/hiradix_cache.py
Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
Signed-off-by: Vladislav Nosivskoy <vladnosiv@gmail.com>
@vladnosiv
Contributor Author

Relevant tests passed

(Screenshots of passing test runs: 2026-04-30 at 19:45:22 and 19:44:45)

@vladnosiv
Contributor Author

Hi @stmatengss @xiezhq-hermann !
PTAL
Conflicts keep popping up almost daily thanks to the wonderful development activity in hicache :)

@stmatengss
Collaborator

Hi @stmatengss @xiezhq-hermann !

PTAL

Conflicts keep popping up almost daily thanks to the wonderful development activity in hicache :)

Sorry for the delay. I will review it again and let you know.


Labels

hicache (Hierarchical Caching for SGLang), high priority, run-ci


6 participants