
[P/D disagg] - support decode side radix cache #19746

Merged

ShangmingCai merged 95 commits into main from ishan/add-radix-cache-decode on May 1, 2026
Conversation

@ishandhanani
Collaborator

@ishandhanani ishandhanani commented Mar 3, 2026

Summary

In PD disaggregation, the decode worker can now use radix cache to reuse shared prefixes and request only the delta KV from prefill instead of transferring the full prefix on every turn.

This is enabled with --disaggregation-decode-enable-radix-cache on the decode server.

For now, this path is supported only with --disaggregation-transfer-backend nixl. server_args.py now rejects other transfer backends early when the decode radix cache flag is enabled. Mooncake support will follow in a separate PR.

Main Changes

  • Decode scheduler (see the sketch after this list)
    • Match incoming requests against the decode-side radix tree.
    • Lock matched prefix nodes for the request lifetime.
    • Pre-allocate only the delta KV pages beyond the matched prefix.
  • Decode -> prefill protocol
    • Plumb decode_prefix_len from decode to prefill for the NIXL path.
    • Allow full-prefix hits where decode may need no KV pages transferred.
  • Prefill transfer path
    • Initialize the sender with only the unsent delta pages.
    • Keep the chunked transfer cursor monotonic when decode already has part of the prefix.
    • Skip empty non-last chunks so the sender/receiver chunk protocol stays consistent.
  • Correctness / cleanup
    • Align matched prefix length to page boundaries for paged KV allocators.
    • Guard lock release / cleanup paths for transfer-failure cases.
    • Batch finished prebuilt frees through the free-group path.
  • CLI / config
    • The user-facing switch is --disaggregation-decode-enable-radix-cache.
    • Current validation requires --disaggregation-transfer-backend nixl when that flag is set.
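
A minimal sketch of the decode-side pre-allocation and prefill-side chunk loop described above. This is illustrative, not the actual SGLang code: helper names such as match_prefix_indices, alloc_kv_pages, and the chunk tuple layout are assumptions; only inc_lock_ref and decode_prefix_len come from this PR.

def prealloc_with_decode_radix_cache(req, tree_cache, page_size):
    # Match the incoming tokens against the decode-side radix tree.
    matched_indices, last_node = match_prefix_indices(tree_cache, req.input_ids)

    # Align the matched prefix down to a page boundary for paged KV allocators.
    matched_len = (len(matched_indices) // page_size) * page_size
    req.prefix_indices = matched_indices[:matched_len]

    # Lock the matched nodes for the request lifetime so they cannot be
    # evicted while the transfer is in flight.
    tree_cache.inc_lock_ref(last_node)
    req.last_node = last_node

    # Pre-allocate only the delta pages beyond the matched prefix; a
    # full-prefix hit may need zero pages transferred.
    req.kv_indices = alloc_kv_pages(len(req.input_ids) - matched_len, page_size)

    # Plumbed to prefill over the NIXL path so the sender skips what
    # decode already holds.
    req.decode_prefix_len = matched_len

def send_kv_chunks(sender, req, chunks):
    # Prefill side: start the cursor at the prefix decode already holds,
    # and keep it monotonic across chunks.
    cursor = req.decode_prefix_len
    for chunk_id, (start, end, is_last) in enumerate(chunks):
        kv_indices = req.kv_indices[max(start, cursor):end]
        # Skip empty non-last chunks so the sender/receiver chunk protocol
        # stays consistent; the last chunk is always sent (it carries aux data).
        if len(kv_indices) == 0 and not is_last:
            continue
        sender.send(chunk_id, kv_indices, is_last)
        cursor = max(cursor, end)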

Interface

Enable decode radix cache on the decode worker with:

--disaggregation-mode decode --disaggregation-transfer-backend nixl --disaggregation-decode-enable-radix-cache

Prefill continues to run with --disaggregation-transfer-backend nixl.
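
For example, a pair of launch commands might look like the following (model, TP size, and everything other than the disaggregation flags are placeholders, not taken from this PR):

# decode worker
python3 -m sglang.launch_server --model-path Qwen/Qwen3-32B --tp 2 \
    --disaggregation-mode decode \
    --disaggregation-transfer-backend nixl \
    --disaggregation-decode-enable-radix-cache

# prefill worker
python3 -m sglang.launch_server --model-path Qwen/Qwen3-32B --tp 2 \
    --disaggregation-mode prefill \
    --disaggregation-transfer-backend nixl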

Note: DP attention is still experimental here. The flag is allowed, but good cache hit rates require prefix-aware DP routing.

Benchmark

Setup

  • Hardware: 1x NVIDIA B200 node (8 GPUs), single-node PD disaggregation via NIXL
  • Model: Qwen/Qwen3-32B, FP8 KV cache, 3P1D, TP=2 per worker
  • Workload: 20 unique ~50K-token prefixes + ~4.5K suffix (~91% prefix reuse), 1000 requests, concurrency 128
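
(The ~91% reuse figure follows from the request shape: about 50K shared prefix tokens out of roughly 54.5K input tokens per request, i.e. 50,000 / 54,500 ≈ 0.917.)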

Results

| Metric | Baseline | Decode Radix Cache | Improvement |
| --- | --- | --- | --- |
| Request throughput (req/s) | 1.21 | 1.59 | 1.32x |
| Output token throughput (tok/s) | 430 | 566 | 1.32x |
| TTFT p50 (s) | 73.2 | 9.0 | 8.1x |
| TTFT avg (s) | 77.7 | 31.6 | 2.5x |
| Request latency p50 (s) | 99.1 | 73.4 | 1.35x |
| ITL avg (ms) | 65.6 | 130.6 | 0.50x |
| Benchmark duration (s) | 827 | 628 | 1.32x |

Decode-side logs show the reason for the throughput gain: baseline decode ran near KV capacity (token_usage ~ 0.99) and only fit ~37 running requests, while decode radix cache reduced duplicate prefix residency (token_usage ~ 0.75) and fit roughly 104-126 running requests. The ITL regression is expected from the larger decode batch.

Test Plan

  • Qwen3-0.6B local PD disagg sanity runs
  • MiniMax-M2.5 1P1D on B200
  • Qwen3-32B 3P1D on B200 (results above)
  • Guard decode radix cache behind nixl in server_args.py
  • Multi-node cross-host testing (RDMA transport)
  • Mooncake transfer backend support (separate PR)


@ishandhanani ishandhanani changed the title from "[Draft] [P/D disagg] - support decode side radix cache" to "[P/D disagg] - support decode side radix cache" on Mar 3, 2026
@dongyibo

dongyibo commented Mar 3, 2026

@ishandhanani Can this feature be understood as follows:
In a multi-turn dialogue scenario, the first round takes tokens 1, 2, 3 as input and outputs tokens 4, 5, 6.
The second round takes tokens 1, 2, 3, 4, 5, 6, 7, 8, 9 as input and outputs tokens 10, 11, 12.

Current status of pd-disagg:
In the first round, for the decode worker, the generated tokens 4, 5, 6 are not cached, and the KV cache of the input tokens 1, 2, 3 is not saved.
In the second round, the prefill worker needs to send the KV cache for all tokens 1, 2, 3, 4, 5, 6, 7, 8, and 9 to the decode worker.

Based on this PR's implementation:
In the first round, the decode worker saves the KV cache for tokens 1, 2, 3, 4, 5, and 6.
In the second round, the prefill worker only needs to send the KV cache for tokens 7, 8, and 9 to the decode worker.
Is my understanding correct?

@ishandhanani
Collaborator Author

ishandhanani commented Mar 3, 2026

> @ishandhanani Can this feature be understood as follows: In a multi-turn dialogue scenario, the first round takes tokens 1, 2, 3 as input and outputs tokens 4, 5, 6. The second round takes tokens 1, 2, 3, 4, 5, 6, 7, 8, 9 as input and outputs tokens 10, 11, 12.
>
> Current status of pd-disagg: In the first round, for the decode worker, the generated tokens 4, 5, 6 are not cached, and the KV cache of the input tokens 1, 2, 3 is not saved. In the second round, the prefill worker needs to send the KV cache for all tokens 1, 2, 3, 4, 5, 6, 7, 8, and 9 to the decode worker.
>
> Based on this PR's implementation: In the first round, the decode worker saves the KV cache for tokens 1, 2, 3, 4, 5, and 6. In the second round, the prefill worker only needs to send the KV cache for tokens 7, 8, and 9 to the decode worker. Is my understanding correct?

Yep. This is correct.

@ishandhanani
Collaborator Author

/gemini review


- Set req.prefix_indices in _pre_alloc so init_next_round_input(None) computes extend_input_len correctly from the cached prefix length. Without this, prepare_for_prebuilt runs a full-length extend instead of a delta extend.

- Always call inc_lock_ref on the matched node (even on an empty match) to match aggregated scheduler behavior. This prevents lock_ref underflow when cache_finished_req unconditionally calls dec_lock_ref.
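
A sketch of the pairing the second bullet relies on (the scaffolding around the calls is illustrative; inc_lock_ref, dec_lock_ref, and cache_finished_req are the names from the commit message):

# Admission path: always take a reference on the matched node, even when
# the match is empty, mirroring the aggregated scheduler.
tree_cache.inc_lock_ref(req.last_node)

# ... request runs to completion ...

# Completion path: cache_finished_req releases unconditionally, so the
# unconditional inc above keeps lock_ref from underflowing.
tree_cache.dec_lock_ref(req.last_node)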

@ishandhanani
Collaborator Author

ishandhanani commented Mar 4, 2026

Next step is testing with a larger model on B200. The step after that (maybe in a follow-up) is to do the same for Mooncake.

Comment thread on python/sglang/srt/disaggregation/prefill.py (outdated)
@dongyibo

dongyibo commented Mar 4, 2026

@ishandhanani There seems to be a constraint here: for multiple decode workers, such as when decode runs with DP, it's best if the same DP rank serves the entire conversation; otherwise the cached KV cannot be reused?

@nananall

nananall commented Mar 4, 2026

Could you share the exact command you used to run this? I'd like to reproduce it and test it on my side.

@ishandhanani
Collaborator Author

ishandhanani commented Mar 4, 2026

> @ishandhanani There seems to be a constraint here: for multiple decode workers, such as when decode runs with DP, it's best if the same DP rank serves the entire conversation; otherwise the cached KV cannot be reused?

There are a few things here.

  1. When running with multiple decode workers (standard data parallelism of workers), I expect the router to pick the right decode worker based on KV load. The Dynamo router handles this well and performantly out of the box (rough sketch below).
  2. For DP attention, agreed. I have not added support for that yet; it still needs to be done.
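
Not part of this PR, but to illustrate point 1, a KV-load-aware router might pick the decode worker along these lines (every name here is hypothetical):

def pick_decode_worker(workers, input_ids):
    # Prefer the worker with the longest cached prefix for this request;
    # break ties by choosing the lowest KV token usage.
    return max(
        workers,
        key=lambda w: (w.cached_prefix_len(input_ids), -w.kv_token_usage),
    )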

@ishandhanani
Collaborator Author

/rerun-stage stage-c-test-8-gpu-h20

@github-actions
Contributor

✅ Triggered stage-c-test-8-gpu-h20 to run independently (skipping dependencies). View workflow run

@ishandhanani
Collaborator Author

CI passed for this job

@ishandhanani
Collaborator Author

/rerun-test test/registered/distributed/test_disaggregation_decode_radix_cache.py

@github-actions
Contributor

8-gpu-h20 (1 test): View workflow run

cd test/ && python3 registered/distributed/test_disaggregation_decode_radix_cache.py

Comment on lines +44 to +48
def maybe_cache_unfinished_req(req: Req, tree_cache: BasePrefixCache, **kwargs):
    # Requests flagged to skip radix-cache insertion are left out of the tree.
    if getattr(req, "skip_radix_cache_insert", False):
        return

    tree_cache.cache_unfinished_req(req, **kwargs)
Collaborator

Just wondering if we should replace all tree_cache.cache_unfinished_req(req, **kwargs) calls with maybe_cache_unfinished_req?

CC: @xiezhq-hermann
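
For reference, a call site using the wrapper would presumably look like the following (which requests set the flag is an assumption, not taken from this PR):

# Requests whose partial KV should not land in the radix tree opt out via
# the flag; every other call site behaves exactly as before.
req.skip_radix_cache_insert = True
maybe_cache_unfinished_req(req, tree_cache)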

Comment on lines +109 to +112
if match_result.mamba_branching_seqlen is not None:
    req.mamba_branching_seqlen = match_result.mamba_branching_seqlen
if match_result.cache_protected_len is not None:
    req.cache_protected_len = match_result.cache_protected_len
Collaborator

These look like new logic, but are they currently unused?

Collaborator

@ShangmingCai ShangmingCai left a comment

Otherwise LGTM, if CI passes.

@ShangmingCai
Collaborator

/rerun-failed-ci

@ShangmingCai
Collaborator

CC: @ByronHsu please help review

@ishandhanani
Collaborator Author

/rerun-stage stage-c-test-8-gpu-h200

@github-actions
Contributor

✅ Triggered stage-c-test-8-gpu-h200 to run independently (skipping dependencies).
⚠️ Could not retrieve workflow run URL. Check the Actions tab for progress.

@ShangmingCai
Collaborator

@ByronHsu Please check this PR if you have time, I think it is good to merge.

@ShangmingCai
Collaborator

If there are no further comments or suggestions, we will merge this PR today.
CC: @cctry @ByronHsu @xiezhq-hermann

@ByronHsu
Collaborator

ByronHsu commented May 1, 2026

LGTM. Excited to try this feature for long context PD!

@ShangmingCai
Collaborator

CI has passed.


@ShangmingCai ShangmingCai merged commit 5b7ce41 into main May 1, 2026
264 of 294 checks passed
@ShangmingCai ShangmingCai deleted the ishan/add-radix-cache-decode branch May 1, 2026 13:55
vguduruTT pushed a commit to vguduruTT/sglang that referenced this pull request May 2, 2026
# Aux data is still sent below when is_last=True.
if len(kv_indices) > 0:
    notif = (
        f"{req.room}_kv_{chunk_id}_{int(is_last)}_{self.kv_args.pp_rank}"
Contributor

This looks like a rebase issue; it reverted a fix for the hang when TP P > D. I will fix it in #23967.

@llc-kc
Contributor

llc-kc commented May 7, 2026

Have you tested low-concurrency scenarios (e.g., 30 concurrent requests)? As you mentioned, the baseline decode can only accommodate ~37 running requests, so a large number of requests queue during the decode pre-allocation phase, which inflates the baseline's TTFT.
Tests under low concurrency would more fairly reflect the performance improvement brought by delta KV cache transmission.
In contrast, the high-concurrency results primarily highlight the higher concurrency enabled by the decode radix cache, while the advantages of delta KV cache transmission remain obscured.
