
Support swa HiCache for unified radix cache#23391

Merged
ispobock merged 15 commits into main from hybrid_tree/hicache_integrate_swa
May 6, 2026

Conversation

@ispobock
Collaborator

@ispobock ispobock commented Apr 21, 2026

Motivation

Add SWA support to HiCache on unified radix cache. Follow-up to #23316.

Benchmark

(benchmark results figure)

w/ HiCache:

SGLANG_ENABLE_UNIFIED_RADIX_TREE=1 \
python3 -m sglang.launch_server \
    --model-path openai/gpt-oss-120b \
    --tp 2 --port 30001 \
    --page-size 64 \
    --attention-backend fa3 --decode-attention-backend fa3 \
    --enable-hierarchical-cache \
    --hicache-ratio 2.0 \
    --hicache-write-policy write_through \
    --hicache-io-backend kernel

w/o HiCache:

SGLANG_ENABLE_UNIFIED_RADIX_TREE=1 \
python3 -m sglang.launch_server \
    --model-path openai/gpt-oss-120b \
    --tp 2 --port 30001 \
    --page-size 64 \
    --attention-backend fa3 --decode-attention-backend fa3

bench:

python3 ./benchmark/hicache/bench_multiturn.py \
    --model-path openai/gpt-oss-120b \
    --port 30001 \
    --disable-random-sample \
    --request-length 2048 --output-length 1024 \
    --num-clients 140 --num-rounds 12 \
    --max-parallel 32 --request-rate 4 \
    --ready-queue-policy random --disable-auto-run \
    --enable-round-barrier \
    --log-file metrics.jsonl

Accuracy

python3 benchmark/gsm8k/bench_sglang.py --num-questions 1400 --parallel 1400
Accuracy: 0.841
Invalid: 0.013
Latency: 130.863 s
Output throughput: 3288.482 token/s

@baoskee

baoskee commented Apr 27, 2026

Thank you for making this PR! I am excited for L3 file support.

@hzh0425
Collaborator

hzh0425 commented Apr 28, 2026

> thank you for making this PR! I am excited for L3 file support.

This PR currently supports only L2 SWA HiCache.

There are still several L2-support PRs that haven’t been merged yet.

Once the preceding PRs are merged, L3 support will be enabled immediately. @baoskee

@baoskee

baoskee commented Apr 28, 2026

> thank you for making this PR! I am excited for L3 file support.
>
> This PR currently supports only L2 SWA HiCache.
>
> There are still several L2-support PRs that haven’t been merged yet.
>
> Once the preceding PRs are merged, L3 support will be enabled immediately. @baoskee

Amazing, thank you. This is pretty critical to our business right now and you are saving it. I'm running this branch in production right now 😂 but once L3 is added it will reduce inference costs significantly more.

KV cache is the limiting factor right now for our long context inference setup.

@hzh0425
Collaborator

hzh0425 commented Apr 28, 2026

> thank you for making this PR! I am excited for L3 file support.
>
> This PR currently supports only L2 SWA HiCache.
> There are still several L2-support PRs that haven’t been merged yet.
> Once the preceding PRs are merged, L3 support will be enabled immediately. @baoskee
>
> Amazing, thank you. This is pretty critical to our business right now and you are saving it. I'm running this branch in production right now 😂 but once L3 is added it will reduce inference costs significantly more.
>
> KV cache is the limiting factor right now for our long context inference setup.

Just fixed a small issue; you can go ahead and update. Looking forward to your feedback!

@rarepepi

> thank you for making this PR! I am excited for L3 file support.
>
> This PR currently supports only L2 SWA HiCache.
> There are still several L2-support PRs that haven’t been merged yet.
> Once the preceding PRs are merged, L3 support will be enabled immediately. @baoskee
>
> Amazing, thank you. This is pretty critical to our business right now and you are saving it. I'm running this branch in production right now 😂 but once L3 is added it will reduce inference costs significantly more.
> KV cache is the limiting factor right now for our long context inference setup.
>
> Just fixed a small issue; you can go ahead and update. Looking forward to your feedback!

UnifiedRadixCache fails assertions: replicas die after about 45 minutes of accepting traffic.

Setup

  • Hardware: 8× B200 (single node), running 4 replicas at TP=2
  • Model: nvidia/Gemma-4-31B-IT-NVFP4
  • SGLANG_ENABLE_UNIFIED_RADIX_TREE=1
  • --enable-hierarchical-cache
  • --hicache-ratio 1.0
  • --hicache-write-policy write_through_selective
  • --hicache-storage-prefetch-policy wait_complete
  • --hicache-io-backend kernel
  • --hicache-mem-layout page_first
  • --page-size 64
  • --kv-cache-dtype fp8_e4m3
  • --mem-fraction-static 0.92
  • --context-length 32768
  • --chunked-prefill-size 4096
  • --schedule-policy lpm

Crashes (both in unified_radix_cache.py)

File ".../mem_cache/unified_radix_cache.py", line 537, in cache_unfinished_req
req.cache_protected_len <= len(new_indices) + self.page_size - 1
AssertionError: req.cache_protected_len=4288, len(new_indices)=0, page_aligned_len=4608

File ".../mem_cache/unified_radix_cache.py", line 1786, in sanity_check
AssertionError: Sanity check FAILED (2 violations across 980 nodes):
[INV-2] swa host LRU: +S3=set(), +lru={476}
[INV-5] swa in both device and host LRU: {476}
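
For context on the first assertion: with `--page-size 64`, the cache compares the request's protected prefix length against the newly produced, page-aligned KV indices. A minimal sketch of that arithmetic, plugging in the values from the report above (names here are illustrative, not the actual SGLang code):

```python
# Hypothetical sketch of the page-alignment arithmetic behind the first
# assertion; names and logic are illustrative, not the real SGLang code.
def page_aligned_len(num_tokens: int, page_size: int) -> int:
    """Round num_tokens up to a whole number of pages."""
    return ((num_tokens + page_size - 1) // page_size) * page_size

page_size = 64
cache_protected_len = 4288   # protected prefix length from the report
new_indices_len = 0          # len(new_indices) was 0 when the assert fired

# 4608 is already 72 full pages, so it aligns to itself
print(page_aligned_len(4608, page_size))   # 4608

# The assertion requires the protected prefix to be covered by the new
# indices modulo one partial page; with zero new indices it cannot hold:
print(cache_protected_len <= new_indices_len + page_size - 1)  # False
```

With `len(new_indices) == 0` the right-hand side is at most `page_size - 1 = 63`, so any nonzero protected prefix trips the assert, which matches the reported values.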

Really appreciate your help thank you!

@baoskee

baoskee commented May 3, 2026

Hello! Thank you again for making this PR.

I'm running this branch in production in tp=2 across four workers. This could be a race condition since it runs fine for a while but occasionally crashes (possibly when high rates of eviction are happening).

We're getting:

File "/home/baoskee/root/.venv/lib/python3.12/site-packages/sglang/srt/mem_cache/unified_cache_components/swa_component.py", line 321, in drive_eviction
    self.cache._cascade_evict(x, self, tracker)
  File "/home/baoskee/root/.venv/lib/python3.12/site-packages/sglang/srt/mem_cache/unified_radix_cache.py", line 892, in _cascade_evict
    assert cd.lock_ref == 0
           ^^^^^^^^^^^^^^^^
AssertionError

[2026-05-02 05:55:41 TP1] Scheduler hit an exception: Traceback (most recent call last):
  File "/home/baoskee/root/.venv/lib/python3.12/site-packages/sglang/srt/managers/scheduler.py", line 3807, in run_scheduler_process
    scheduler.run_event_loop()
  File "/home/baoskee/root/.venv/lib/python3.12/site-packages/sglang/srt/managers/scheduler.py", line 1394, in run_event_loop
    dispatch_event_loop(self)
  File "/home/baoskee/root/.venv/lib/python3.12/site-packages/sglang/srt/managers/scheduler.py", line 3677, in dispatch_event_loop
    scheduler.event_loop_overlap()
  File "/home/baoskee/root/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/baoskee/root/.venv/lib/python3.12/site-packages/sglang/srt/managers/scheduler.py", line 1444, in event_loop_overlap
    batch = self.get_next_batch_to_run()
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/baoskee/root/.venv/lib/python3.12/site-packages/sglang/srt/managers/scheduler.py", line 2410, in get_next_batch_to_run
    self.running_batch = self.update_running_batch(self.running_batch)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/baoskee/root/.venv/lib/python3.12/site-packages/sglang/srt/managers/scheduler.py", line 2757, in update_running_batch
    batch.prepare_for_decode()
  File "/home/baoskee/root/.venv/lib/python3.12/site-packages/sglang/srt/managers/schedule_batch.py", line 2225, in prepare_for_decode
    self.out_cache_loc = alloc_for_decode(self, token_per_req=1)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/baoskee/root/.venv/lib/python3.12/site-packages/sglang/srt/mem_cache/common.py", line 444, in alloc_for_decode
    out_cache_loc = alloc_paged_token_slots_decode(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/baoskee/root/.venv/lib/python3.12/site-packages/sglang/srt/mem_cache/common.py", line 405, in alloc_paged_token_slots_decode
    evict_from_tree_cache(tree_cache, num_tokens)
  File "/home/baoskee/root/.venv/lib/python3.12/site-packages/sglang/srt/mem_cache/common.py", line 246, in evict_from_tree_cache
    tree_cache.evict(
  File "/home/baoskee/root/.venv/lib/python3.12/site-packages/sglang/srt/mem_cache/unified_radix_cache.py", line 369, in evict
    component.drive_eviction(params=params, tracker=tracker)
  File "/home/baoskee/root/.venv/lib/python3.12/site-packages/sglang/srt/mem_cache/unified_cache_components/swa_component.py", line 321, in drive_eviction
    self.cache._cascade_evict(x, self, tracker)
  File "/home/baoskee/root/.venv/lib/python3.12/site-packages/sglang/srt/mem_cache/unified_radix_cache.py", line 892, in _cascade_evict
    assert cd.lock_ref == 0
           ^^^^^^^^^^^^^^^^
AssertionError

[2026-05-02 05:55:41] SIGQUIT received. signum=None, frame=None. It usually means one child failed.
Killed
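
The failing `assert cd.lock_ref == 0` expects every eviction candidate to be unreferenced, but under concurrent scheduling an in-flight request can still hold a lock when cascade eviction runs, consistent with the race-condition theory above. A minimal sketch (hypothetical, not SGLang's implementation) of an LRU evictor that skips locked nodes instead of asserting:

```python
# Illustrative LRU eviction with lock refcounts; not SGLang code.
import heapq
from dataclasses import dataclass, field

@dataclass(order=True)
class Node:
    last_access: int                           # LRU ordering key
    key: str = field(compare=False)
    lock_ref: int = field(compare=False, default=0)

def evict_lru(nodes: list[Node], num_to_evict: int) -> list[str]:
    """Evict up to num_to_evict least-recently-used *unlocked* nodes."""
    heap = list(nodes)
    heapq.heapify(heap)                        # min-heap by last_access
    evicted = []
    while heap and len(evicted) < num_to_evict:
        node = heapq.heappop(heap)
        if node.lock_ref > 0:                  # still held by a request
            continue                           # skip instead of asserting
        evicted.append(node.key)
    return evicted

nodes = [Node(1, "a"), Node(2, "b", lock_ref=1), Node(3, "c")]
print(evict_lru(nodes, 2))  # ['a', 'c'] -- 'b' is locked and skipped
```

Whether skipping (versus treating a locked candidate as a hard invariant violation) is the right fix depends on how the unified radix tree maintains its LRU sets; the sketch only illustrates the failure mode.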

@hzh0425
Collaborator

hzh0425 commented May 3, 2026

I'll sync some fixes to this branch later; I've fixed a few issues, so you can give it another try.
@rarepepi @baoskee

@hzh0425
Collaborator

hzh0425 commented May 3, 2026

> Crashes (both in unified_radix_cache.py)
>
> File ".../mem_cache/unified_radix_cache.py", line 537, in cache_unfinished_req
> req.cache_protected_len <= len(new_indices) + self.page_size - 1
> AssertionError: req.cache_protected_len=4288, len(new_indices)=0, page_aligned_len=4608
>
> File ".../mem_cache/unified_radix_cache.py", line 1786, in sanity_check
> AssertionError: Sanity check FAILED (2 violations across 980 nodes):
> [INV-2] swa host LRU: +S3=set(), +lru={476}
> [INV-5] swa in both device and host LRU: {476}
>
> Really appreciate your help thank you!

This is a known issue, and I'll fix it in this branch. @rarepepi

Base automatically changed from hybrid_tree/hicache_integrate to main May 3, 2026 14:13
@ispobock ispobock force-pushed the hybrid_tree/hicache_integrate_swa branch from 2fcd434 to 5aa2023 on May 3, 2026 17:39
@v-shobhit

Hello! Will this also be applicable to deepseek-v4-pro?

@hzh0425
Collaborator

hzh0425 commented May 4, 2026

> Hello! Will this also be applicable to deepseek-v4-pro?

Yes: we developed HiCache for ds-v4 based on the swa_hicache branch. We'll wait until swa_hicache is merged and then rebase the ds branch. @v-shobhit

@ispobock
Collaborator Author

ispobock commented May 4, 2026

/tag-and-rerun-ci

@github-actions bot added the run-ci label on May 4, 2026
@ispobock ispobock merged commit eb5f0fb into main May 6, 2026
330 of 360 checks passed
@ispobock ispobock deleted the hybrid_tree/hicache_integrate_swa branch May 6, 2026 14:19
LLThomas pushed a commit to LLThomas/sglang that referenced this pull request May 8, 2026
Co-authored-by: hzh0425 <hzh0425@apache.org>
@baoskee

baoskee commented May 8, 2026

This doubled our throughput on B200 clusters. Thank you @ispobock and @hzh0425 🙏

@rarepepi

rarepepi commented May 9, 2026

> This doubled our throughput on B200 clusters. Thank you @ispobock and @hzh0425 🙏

Yes thank you!!


Labels

hicache (Hierarchical Caching for SGLang) · high priority · run-ci


6 participants