[Disagg] Layer-pipelined KV transfer: overlap RDMA with GPU compute#23515
[Disagg] Layer-pipelined KV transfer: overlap RDMA with GPU compute#23515michael7193 wants to merge 17 commits intosgl-project:mainfrom
Conversation
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
|
CC: @UNIDY2002 Could you check this? I haven't gone through this PR carefully yet, but this seems like a cleaner implementation. |
|
Nice work. We took a different approach in #19931 (callback-driven, per-layer notifications from inside We've been working on Qwen3.5-397B-A17B (hybrid linear attention + GQA + VL), and there are a couple of gaps we can help fill:
We'd like to collaborate on getting Qwen3.5 support into this PR (or a follow-up). |
|
Thanks for the thoughtful review and kind words, @UNIDY2002! Great to hear about your experience with #19931. The callback-driven approach is interesting — glad we converged on similar goals from different angles. Both issues you raised are very practical: forward_split_prefill for Qwen3.5 — Makes total sense. The current implementation assumes models provide forward_split_prefill, so hybrid models like Qwen3.5-397B-A17B would indeed need that. Would love to see your implementation — a follow-up PR sounds perfect. Multimodal fallback — Good catch. Adding a multimodal guard in _get_pipeline_group_size() to fall back to the normal path is straightforward and the right thing to do. Happy to include it in this PR if you'd like to send a patch, or we can handle it in the follow-up together. Very much looking forward to collaborating on Qwen3.5 support. Feel free to ping me anytime! |
39d680d to
155f9b7
Compare
|
@UNIDY2002 Thanks for the catch — applied your @ShangmingCai Gentle ping — this PR is ready for review whenever you have a chance. Summary of what's been done since your last look:
Happy to address any further feedback! |
Overlap RDMA KV transfer with GPU compute by splitting prefill into layer groups and enqueuing per-layer transfers after each group finishes. Transfer(N) overlaps with compute(N+1), reducing TTFT by 14-68% for long prompts (>=3K tokens) in production benchmarks. Key changes: - scheduler.py: add run_batch_pipelined() with grouped forward + KV send - tp_worker.py: add split_init/split_layer/split_sample for layer-wise forward - mooncake/conn.py: add send_kvcache_layer() for single-layer RDMA (MHA+MLA) - prefill.py: per-batch dispatch via _should_use_pipelined() + result handler - fake/conn.py: send_layer stub for warmup requests Gated by SGLANG_PIPELINED_KV_TRANSFER=1 (default off). Configurable via SGLANG_PIPELINE_GROUP_SIZE (default 10) and SGLANG_PIPELINE_MIN_TOKENS (default 3072). Short prompts below threshold use the normal path. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Register SGLANG_PIPELINED_KV_TRANSFER, SGLANG_PIPELINE_GROUP_SIZE, and SGLANG_PIPELINE_MIN_TOKENS in the central environ.py registry. Migrate os.environ.get() calls to envs.XXX.get() for consistency with the rest of the codebase. Add documentation to environment_variables.md. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Instead of a fixed SGLANG_PIPELINE_GROUP_SIZE, automatically compute group_size to keep pipeline iterations in [6, 10] range: - short prompts (<4K): 10 iterations (more overlap) - medium prompts (4K-8K): 8 iterations (sweet spot) - long prompts (>8K): 6 iterations (reduce loop overhead) User can still override via SGLANG_PIPELINE_GROUP_SIZE env var. Rename _should_use_pipelined -> _get_pipeline_group_size (returns 0 to skip, >0 for the group_size to use). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Extend send_kvcache_layer() with optional dst_tp_rank, dst_attn_tp_size, and dst_kv_item_len parameters. When prefill TP != decode TP for MHA models, apply per-token head slicing using vectorized numpy addressing (same math as send_kvcache_slice). MLA remains TP-invariant. Update transfer_worker dispatch to detect TP mismatch in the layer- pipelined branch and forward the extra parameters. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Layer-pipelined KV transfer currently skips Mamba/SWA/NSA state transfer (state is only sent on is_last_chunk, which the per-layer path never triggers). This would cause silent data loss for hybrid models like Jamba, FalconH1, and DeepSeek-R1 with SWA. Add a safety guard in _get_pipeline_group_size() that falls back to the normal (non-pipelined) path when state_type != "none". Per-layer state pipelining will be implemented in a follow-up. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The layer-pipelined path was silently skipping Mamba/SWA/NSA state transfer because send_layer() never passed state_indices. Fix: 1. Add state_indices param to MooncakeKVSender.send_layer() and FakeKVSender.send_layer(). On is_last=True, state_indices are forwarded to add_transfer_request(), which lets transfer_worker call maybe_send_extra() on the last chunk. 2. Add _prepare_pipelined_state_indices() in prefill.py that mirrors the state_indices computation from send_kv_chunk() and attaches the result to each req before the layer loop. 3. Remove the safety guard that forced Mamba/SWA/NSA models to fall back to the normal path — no longer needed. State transfer still happens as a bulk operation on the last chunk (not per-layer), but now overlaps with the last KV group transfer. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Skip layer-pipelined KV transfer when batch contains multimodal inputs (images/audio), as split-prefill is incompatible with general_mm_embed_routine. Falls back to normal path to avoid crashes or silent data corruption. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Xun Sun <UNIDY2002@outlook.com>
Add forward_split_prefill support for Qwen3.5-397B-A17B (hybrid linear attention + GQA + VL), enabling layer-pipelined KV transfer. Changes: - Qwen3_5ForCausalLM.forward_split_prefill: text model split forward - Qwen3_5MoeForConditionalGeneration.forward_split_prefill: VL wrapper with general_mm_embed_routine - Fix is_last -> is_last_chunk parameter name in conn.py Tested on 2x8xH20 with Qwen3.5-397B-A17B-FP8, PD + Mooncake TCP. Authored-by: UNIDY2002
…ward_split_prefill Previously the guard blocked ALL multimodal batches from pipeline mode. Now it only blocks when the model lacks forward_split_prefill (meaning it can't handle multimodal inputs in split mode). Models like Qwen3.5-VL that implement multimodal-aware forward_split_prefill are allowed through. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add `general_mm_embed_routine` import from `sglang.srt.managers.mm_utils` to fix ruff F821 (undefined name) in `forward_split_prefill` - Remove extra blank lines to satisfy black formatter Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add test_disaggregation_pipelined.py covering: - GSM8K eval correctness with pipelined transfer enabled - Basic single-request generation - Long prompt exercising deeper pipeline overlap - Concurrent request handling (8 parallel prefills) - Fixed group_size configuration path Tests enable SGLANG_PIPELINED_KV_TRANSFER=1 with a low min_tokens threshold to exercise the pipeline path on standard CI eval prompts. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Replace multimodal-only guard with universal hasattr check so that models without forward_split_prefill (e.g. Mamba/hybrid) safely fallback to the normal transfer path instead of crashing. - Implement forward_split_prefill for FalconH1ForCausalLM, enabling layer-pipelined KV+state transfer for Mamba hybrid models. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add copy_done.synchronize() for staging buffer correctness - Add routed_experts_output/indexer_topk_output finalize() to prevent resource leak in A2A MoE configurations - Use maybe_cache_unfinished_req() instead of direct cache_unfinished_req() to handle HiCache conditional logic correctly Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The forward_split_prefill method uses LogitsProcessorOutput in its return type annotation but it was not imported, causing CI lint failure. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
b5267cb to
2547641
Compare
|
@ShangmingCai Friendly ping — this PR has been rebased onto the latest main (no conflicts). Would appreciate your review when you get a chance. Also, could a maintainer add the |
|
Great! Too busy lately, let me trigger the CI first, will start to review next week. Thank you so much for the PR. |
|
/tag-and-rerun-ci |
There was a problem hiding this comment.
This file has a lint error. Also, is this modification mis-added by cc?
There was a problem hiding this comment.
The falcon_h1.py change is intentional — FalconH1 is a Mamba/Attention hybrid model where SSM conv states need special handling during layer-pipelined transfer (sent once at the final group via maybe_send_extra()). I'll fix the lint error in the next push.
There was a problem hiding this comment.
Does is means that we need to impl this forward_split_prefill for every single model? This might not be a robust design. Will dive in next week.
There was a problem hiding this comment.
Good question! Actually this is not a new pattern we're introducing — there are already 15 models in the upstream codebase that implement forward_split_prefill (llama, qwen, qwen2, qwen3, gemma, gemma2, gemma3, glm4, exaone4, sarvam_moe, qwen2_moe, qwen3_moe, etc.), added for chunked prefill / PP support.
Our layer-pipelined feature simply reuses this existing interface. The design has two layers of safety:
- Guard fallback: If a model doesn't have
forward_split_prefill, the pipelined path is automatically skipped and the request goes through the normal path (no crash, no regression). - Pattern is mechanical: For standard transformer models, the implementation is identical —
embed → layers[start:end] → norm → logits. Only hybrid models (Mamba SSM, hybrid linear attention) need custom logic.
That said, if you'd prefer a more robust approach, we could add a default generic implementation in a base class that works for any standard transformer model, so new models get pipelined support for free without writing any code. Happy to explore that direction if you think it's worthwhile.
Run ruff format to fix lint errors flagged in review. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
|
Fixed the |
Motivation
In PD disaggregation mode, KV cache transfer happens after full prefill computation completes. For long prompts (≥1K tokens), this creates a significant TTFT bottleneck — the decode side must wait for all layers to be computed and then transferred sequentially.
This PR implements layer-pipelined KV transfer: instead of computing all layers then transferring all KV at once, we split layers into groups and transfer each group incrementally. Transfer of group N overlaps with GPU compute of group N+1, significantly reducing TTFT.
Related: #19931 (same direction, different approach)
Key Results
TTFT (ms) — Prompt Length Sweep (C=32, output=256)
TTFT p95 (ms)
Throughput (output tok/s)
Multi-turn Dialogue (16 sessions × 10 turns)
Extreme Stress (C=64, prompt=4096, output=1024)
Design
The feature is controlled by three environment variables (registered in `environ.py`), disabled by default:
How it works
`_get_pipeline_group_size(batch)` — per-batch decision: returns adaptive group_size (>0) or 0 to skip pipeline. A universal guard ensures models without `forward_split_prefill` safely fallback to the normal path. Short prompts also fall back with zero overhead.
`run_batch_pipelined(batch, group_size)` in `Scheduler` — splits forward into layer groups using `model_runner.forward_split_prefill()`, enqueues per-layer KV transfer via `send_layer()` after each group. CUDA events synchronize GPU→transfer ordering. Pre-computes state indices for hybrid models via `_prepare_pipelined_state_indices()`.
`process_batch_result_pipelined_prefill()` — result handler that dispatches to `run_batch_pipelined` instead of `run_batch`, then follows the same downstream logic (including EAGLE spec_info propagation, staging sync, A2A MoE finalization).
`MooncakeKVManager.send_kvcache_layer()` — single-layer RDMA transfer supporting both MHA and MLA architectures via `get_mha_kv_ptrs_with_pp` / `get_mla_kv_ptrs_with_pp`.
`TpModelWorker.forward_batch_generation_split_{init,layer,sample}()` — three-phase split forward: init attention backend → run N layers per call → sample after last group.
Call chain
```
event_loop_normal_disagg_prefill
→ _get_pipeline_group_size(batch)
→ >0: run_batch_pipelined → split_init → [split_layer + send_layer] × N → split_sample
→ 0: run_batch (unchanged)
```
Adaptive group_size (E1)
Instead of a fixed `SGLANG_PIPELINE_GROUP_SIZE`, group_size is automatically tuned based on prompt length to keep pipeline iterations in [6, 10]:
User can still override via `SGLANG_PIPELINE_GROUP_SIZE` env var.
Different TP support (E2)
`send_kvcache_layer()` supports MHA head slicing when prefill TP ≠ decode TP, using vectorized numpy addressing (same math as `send_kvcache_slice`). MLA is TP-invariant and needs no slicing.
Mamba/SWA/NSA state support (E4)
Hybrid models (Jamba, FalconH1, DeepSeek-R1 with SWA) are fully supported. `_prepare_pipelined_state_indices()` pre-computes state indices before the layer loop, then passes them through `send_layer(state_indices=...)` on the last layer to trigger `maybe_send_extra()`. This covers:
No decode-side changes needed — decode already waits for all data (KV + state) before starting.
Universal guard + FalconH1 support (E7)
A universal `hasattr(model, "forward_split_prefill")` guard replaces the previous multimodal-only guard. This ensures:
FalconH1 (Mamba hybrid) now has `forward_split_prefill`, enabling layer-pipelined transfer. Each layer's attention produces KV cache (transferred per-layer via pipeline), while SSM state is sent once at the end via `maybe_send_extra()` (SSM state is fixed-size, independent of sequence length — no benefit from per-layer pipelining).
MTP/EAGLE compatibility (E8)
Reviewed and confirmed that `process_batch_result_pipelined_prefill` correctly propagates EAGLE `spec_info` (`topk_p`, `topk_index`, `hidden_states`) to requests — identical to the normal path. MTP decode-side rollback is purely a decode-phase operation with no interaction with prefill-time pipelined transfer. Also aligned `copy_done.synchronize()`, `routed_experts_output.finalize()`, and `maybe_cache_unfinished_req` with the normal result handler.
Zero regression guarantee
When `SGLANG_PIPELINED_KV_TRANSFER=false` (default):
Code equivalence (v0.4.10.post2 → this PR)
Benchmarks were collected on v0.4.10.post2. This PR ports the same logic to upstream main with these adaptations:
Checklist
Modified Files
Future Work