
[Disagg] Layer-pipelined KV transfer: overlap RDMA with GPU compute#23515

Open
michael7193 wants to merge 17 commits into sgl-project:main from michael7193:feature/layer-pipelined-kv-transfer

Conversation

@michael7193 michael7193 commented Apr 23, 2026

Motivation

In PD disaggregation mode, KV cache transfer happens after full prefill computation completes. For long prompts (≥1K tokens), this creates a significant TTFT bottleneck — the decode side must wait for all layers to be computed and then transferred sequentially.

This PR implements layer-pipelined KV transfer: instead of computing all layers then transferring all KV at once, we split layers into groups and transfer each group incrementally. Transfer of group N overlaps with GPU compute of group N+1, significantly reducing TTFT.

Related: #19931 (same direction, different approach)

Key Results

Benchmark environment: SGLang v0.4.10.post2, torch 2.7.1+cu126, Qwen2.5-72B-Instruct, 2×8 H20 GPUs (TP=4 each), 4×400G IB (RDMA), PD + Mooncake backend.
The pipelined KV transfer logic is algorithmically identical between v0.4.10.post2 and this PR — differences are limited to upstream API adaptation (environ.py registration, PP-aware pointers, EAGLE/staging guards). See "Code equivalence" section below.

TTFT (ms) — Prompt Length Sweep (C=32, output=256)

| Prompt | Baseline | Pipelined | Δ% |
|---|---|---|---|
| 256 | 229 | 269 | +18% (normal path, below threshold) |
| 1024 | 696 | 274 | -61% |
| 4096 | 1192 | 378 | -68% |
| 8192 | 1234 | 637 | -48% |
| 16384 | 1070 | 922 | -14% |

TTFT p95 (ms)

| Prompt | Baseline | Pipelined | Δ% |
|---|---|---|---|
| 1024 | 3378 | 589 | -83% |
| 4096 | 5553 | 859 | -85% |
| 8192 | 5344 | 1832 | -66% |
| 16384 | 6164 | 3029 | -51% |

Throughput (output tok/s)

| Prompt | Baseline | Pipelined | Δ% |
|---|---|---|---|
| 1024 | 888 | 923 | +4% |
| 4096 | 723 | 858 | +19% |
| 16384 | 136 | 636 | +367% |

Multi-turn Dialogue (16 sessions × 10 turns)

| Metric | Baseline | Pipelined | Δ |
|---|---|---|---|
| Completed | 160 | 160 | same |
| Throughput | 485 t/s | 481 t/s | -1% |
| TTFT avg | 311 ms | 314 ms | +1% |

Extreme Stress (C=64, prompt=4096, output=1024)

| Metric | Baseline | Pipelined | Δ |
|---|---|---|---|
| Completed reqs | 128 | 128 | same |
| TPOT avg | 46.1 ms | 46.0 ms | 0% |
| Throughput | 1352 t/s | 1347 t/s | 0% |

Design

The feature is controlled by three environment variables (registered in `environ.py`), disabled by default:

  • `SGLANG_PIPELINED_KV_TRANSFER=true` — enable the feature
  • `SGLANG_PIPELINE_GROUP_SIZE=10` — layers per group (override adaptive default)
  • `SGLANG_PIPELINE_MIN_TOKENS=3072` — threshold; short prompts use normal path
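A minimal launch sketch using the three variables above; the server command and its flags are illustrative and depend on your deployment, only the env vars come from this PR:

```shell
# Enable layer-pipelined KV transfer on the prefill side (sketch).
export SGLANG_PIPELINED_KV_TRANSFER=true
export SGLANG_PIPELINE_GROUP_SIZE=10     # optional: fixed layers per group
export SGLANG_PIPELINE_MIN_TOKENS=3072   # prompts below this use the normal path
# Illustrative launch command; adapt flags to your deployment:
# python -m sglang.launch_server --disaggregation-mode prefill ...
```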

How it works

  1. `_get_pipeline_group_size(batch)` — per-batch decision: returns an adaptive group_size (>0) or 0 to skip the pipeline. A universal guard ensures models without `forward_split_prefill` safely fall back to the normal path. Short prompts also fall back with zero overhead.

  2. `run_batch_pipelined(batch, group_size)` in `Scheduler` — splits forward into layer groups using `model_runner.forward_split_prefill()`, enqueues per-layer KV transfer via `send_layer()` after each group. CUDA events synchronize GPU→transfer ordering. Pre-computes state indices for hybrid models via `_prepare_pipelined_state_indices()`.

  3. `process_batch_result_pipelined_prefill()` — result handler that dispatches to `run_batch_pipelined` instead of `run_batch`, then follows the same downstream logic (including EAGLE spec_info propagation, staging sync, A2A MoE finalization).

  4. `MooncakeKVManager.send_kvcache_layer()` — single-layer RDMA transfer supporting both MHA and MLA architectures via `get_mha_kv_ptrs_with_pp` / `get_mla_kv_ptrs_with_pp`.

  5. `TpModelWorker.forward_batch_generation_split_{init,layer,sample}()` — three-phase split forward: init attention backend → run N layers per call → sample after last group.
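The overall shape of steps 2-5 can be sketched in plain Python, with a background worker standing in for the RDMA transfer thread and a queue standing in for CUDA-event synchronization. All names here are illustrative stand-ins, not the PR's actual code:

```python
# Sketch: compute group i, enqueue its transfer, then move on to group
# i+1 while a background worker drains the transfer queue. In the real
# implementation CUDA events order GPU compute against RDMA enqueue.
import queue
import threading

def run_batch_pipelined(num_layers: int, group_size: int,
                        compute_group, send_layer) -> list:
    transfers: queue.Queue = queue.Queue()
    sent_order: list = []

    def transfer_worker():
        while True:
            group = transfers.get()
            if group is None:  # sentinel: no more groups
                break
            for layer_id in group:
                send_layer(layer_id)  # per-layer RDMA enqueue
                sent_order.append(layer_id)

    worker = threading.Thread(target=transfer_worker)
    worker.start()
    for start in range(0, num_layers, group_size):
        group = list(range(start, min(start + group_size, num_layers)))
        compute_group(group)   # GPU compute for this layer group
        transfers.put(group)   # overlaps with the next group's compute
    transfers.put(None)
    worker.join()
    return sent_order
```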

Call chain

```
event_loop_normal_disagg_prefill
→ _get_pipeline_group_size(batch)
→ >0: run_batch_pipelined → split_init → [split_layer + send_layer] × N → split_sample
→ 0: run_batch (unchanged)
```

Adaptive group_size (E1)

Instead of a fixed `SGLANG_PIPELINE_GROUP_SIZE`, group_size is automatically tuned based on prompt length to keep pipeline iterations in [6, 10]:

  • Short prompts (<4K tokens): 10 iterations (maximize overlap)
  • Medium prompts (4K-8K): 8 iterations (sweet spot)
  • Long prompts (>8K): 6 iterations (reduce loop overhead)

User can still override via `SGLANG_PIPELINE_GROUP_SIZE` env var.
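A minimal sketch of this policy, assuming `num_layers` is known per model; the function name and signature are hypothetical, not the PR's actual code:

```python
# Sketch of the adaptive group-size policy: keep pipeline iterations in
# [6, 10] by prompt length, honoring the env-var override and the
# min-tokens threshold below which the normal path is used.
import math
import os

def adaptive_group_size(prompt_tokens: int, num_layers: int,
                        min_tokens: int = 3072) -> int:
    """Return layers per group (>0), or 0 to skip the pipelined path."""
    if prompt_tokens < min_tokens:
        return 0  # short prompts: normal path, zero overhead
    override = os.environ.get("SGLANG_PIPELINE_GROUP_SIZE")
    if override:
        return int(override)
    if prompt_tokens < 4096:
        target_iters = 10   # short: maximize overlap
    elif prompt_tokens <= 8192:
        target_iters = 8    # medium: sweet spot
    else:
        target_iters = 6    # long: reduce loop overhead
    return max(1, math.ceil(num_layers / target_iters))
```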

Different TP support (E2)

`send_kvcache_layer()` supports MHA head slicing when prefill TP ≠ decode TP, using vectorized numpy addressing (same math as `send_kvcache_slice`). MLA is TP-invariant and needs no slicing.
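The head-overlap arithmetic behind that slicing can be illustrated as follows; this is a simplified stand-in for the vectorized addressing, with hypothetical names:

```python
# Sketch: when prefill TP != decode TP, each decode rank owns a
# contiguous range of KV heads, and a prefill rank must send only the
# heads that overlap its own local slice to each decode rank.
def decode_rank_head_range(num_kv_heads: int, prefill_tp: int,
                           decode_tp: int, prefill_rank: int,
                           decode_rank: int):
    """Local head range (within this prefill rank's slice) destined for
    decode_rank; None if the two ranks share no heads."""
    heads_per_prefill = num_kv_heads // prefill_tp
    heads_per_decode = num_kv_heads // decode_tp
    p_start = prefill_rank * heads_per_prefill
    p_end = p_start + heads_per_prefill
    d_start = decode_rank * heads_per_decode
    d_end = d_start + heads_per_decode
    lo, hi = max(p_start, d_start), min(p_end, d_end)
    if lo >= hi:
        return None
    return lo - p_start, hi - p_start  # offsets into the local slice
```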

Mamba/SWA/NSA state support (E4)

Hybrid models (Jamba, FalconH1, DeepSeek-R1 with SWA) are fully supported. `_prepare_pipelined_state_indices()` pre-computes state indices before the layer loop, then passes them through `send_layer(state_indices=...)` on the last layer to trigger `maybe_send_extra()`. This covers:

  • HybridLinearKVPool (Mamba SSM): `req_index_to_mamba_index_mapping`
  • SWAKVPool (Sliding Window): windowed page indices via `translate_loc_from_full_to_swa`
  • NSATokenToKVPool (Native Sparse Attention): full sequence page indices

No decode-side changes needed — decode already waits for all data (KV + state) before starting.
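The pass-through described above can be sketched like this: state indices are pre-computed once, then attached only to the final `send_layer()` call so the extra-state send fires after all KV layers. The function here is an illustrative stand-in, not the PR's actual code:

```python
# Sketch: attach state_indices only on the very last layer, mirroring
# how maybe_send_extra() is triggered once at the end of the pipeline.
def send_all_layers(num_layers: int, group_size: int,
                    send_layer, state_indices) -> None:
    for start in range(0, num_layers, group_size):
        end = min(start + group_size, num_layers)
        for layer_id in range(start, end):
            last = layer_id == num_layers - 1
            send_layer(layer_id, state_indices if last else None)
```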

Universal guard + FalconH1 support (E7)

A universal `hasattr(model, "forward_split_prefill")` guard replaces the previous multimodal-only guard. This ensures:

  • Models without `forward_split_prefill` (e.g. DeepSeek-V2/V3) safely fall back with no crash
  • Models with `forward_split_prefill` (LLaMA, Qwen, Gemma, FalconH1, etc.) use pipelined path

FalconH1 (Mamba hybrid) now has `forward_split_prefill`, enabling layer-pipelined transfer. Each layer's attention produces KV cache (transferred per-layer via pipeline), while SSM state is sent once at the end via `maybe_send_extra()` (SSM state is fixed-size, independent of sequence length — no benefit from per-layer pipelining).

MTP/EAGLE compatibility (E8)

Reviewed and confirmed that `process_batch_result_pipelined_prefill` correctly propagates EAGLE `spec_info` (`topk_p`, `topk_index`, `hidden_states`) to requests — identical to the normal path. MTP decode-side rollback is purely a decode-phase operation with no interaction with prefill-time pipelined transfer. Also aligned `copy_done.synchronize()`, `routed_experts_output.finalize()`, and `maybe_cache_unfinished_req` with the normal result handler.

Zero regression guarantee

When `SGLANG_PIPELINED_KV_TRANSFER=false` (default):

  • `_get_pipeline_group_size()` returns `0` on the first line
  • `run_batch_pipelined`, `process_batch_result_pipelined_prefill`, and all `split_*` methods are never called
  • `TransferKVChunk.layer_id` defaults to `None`, so `transfer_worker` always takes the existing path
  • `add_transfer_request` new parameters have default values — existing callers unaffected

Code equivalence (v0.4.10.post2 → this PR)

Benchmarks were collected on v0.4.10.post2. This PR ports the same logic to upstream main with these adaptations:

  • `os.environ.get()` → `envs.XXX.get()` (central environ.py registry)
  • `send_kvcache_layer` uses PP-aware `get_mha_kv_ptrs_with_pp` / `get_mla_kv_ptrs_with_pp` (upstream helpers, equivalent at PP=1)
  • `process_batch_result_pipelined_prefill` adapted for upstream's EAGLE spec info, staging buffer, and `report_prefill_stats` APIs
  • Core algorithm (grouped layer forward loop, CUDA event sync, per-layer RDMA enqueue) is identical

Checklist

| # | Item | Status |
|---|---|---|
| 1 | GQA, same TP, no MTP | ✅ Done (LLaMA, Qwen, Gemma verified) |
| 2 | Mamba hybrid (FalconH1), same TP | ✅ Done: `forward_split_prefill` implemented, SSM state via `maybe_send_extra()` |
| 3 | GQA + Mamba, different TP | ✅ Done (E2 head slicing + E4 state transfer) |
| 4 | MLA (DeepSeek-V2/V3) | ⚠️ Transport layer ready (`send_kvcache_layer` MLA path), but DeepSeek models lack `forward_split_prefill` due to complex TBO/NSA/PP interactions; they safely fall back via the universal guard |
| 5 | MTP/EAGLE compatibility | ✅ Done: `spec_info` propagation verified, result handler aligned with normal path |
| 6 | PP support | ❌ Not started |

Modified Files

| File | Changes |
|---|---|
| `sglang/srt/disaggregation/prefill.py` | `_get_pipeline_group_size` (universal guard + adaptive), `_prepare_pipelined_state_indices`, pipelined event loop branch, `process_batch_result_pipelined_prefill` (with EAGLE/staging/A2A alignment) |
| `sglang/srt/managers/scheduler.py` | `run_batch_pipelined` (with state_indices pass-through), info log in `dispatch_event_loop` |
| `sglang/srt/disaggregation/mooncake/conn.py` | `send_kvcache_layer` (MHA+MLA+head slicing), `_send_kvcache_layer_head_slice`, `TransferKVChunk` extension, `transfer_worker` dispatch, `MooncakeKVSender.send_layer` (with state_indices) |
| `sglang/srt/disaggregation/fake/conn.py` | `FakeKVSender.send_layer` stub (with state_indices) |
| `sglang/srt/managers/tp_worker.py` | `forward_batch_generation_split_{init,layer,sample}` |
| `sglang/srt/models/falcon_h1.py` | `FalconH1ForCausalLM.forward_split_prefill` |
| `sglang/srt/models/qwen3_5.py` | `forward_split_prefill` for Qwen3.5 + VL (contributed by @UNIDY2002) |
| `sglang/srt/environ.py` | Register 3 env vars |
| `docs/references/environment_variables.md` | Document 3 env vars |
| `test/registered/disaggregation/test_disaggregation_pipelined.py` | CI tests for pipelined transfer |

Future Work

  • MLA `forward_split_prefill` — DeepSeek-V2/V3 has complex interactions with TBO, NSA context parallel, and A2A MoE. Deferred as separate PR by someone familiar with DeepSeek internals.
  • PP support — Cross-PP-stage pipeline coordination (long-term)


@github-actions github-actions Bot added the documentation Improvements or additions to documentation label Apr 23, 2026
@ShangmingCai
Collaborator

CC: @UNIDY2002 Could you check this? I haven't gone through this PR carefully yet, but this seems like a cleaner implementation.

@UNIDY2002
Contributor

Nice work. We took a different approach in #19931 (callback-driven, per-layer notifications from inside HybridLinearAttnBackend.forward()), but your split_init/layer/sample + send_layer design is cleaner — the scheduler controls layer ranges and Mooncake just executes transfers, which is the right separation.

We've been working on Qwen3.5-397B-A17B (hybrid linear attention + GQA + VL), and there are a couple of gaps we can help fill:

  1. Qwen3.5 lacks forward_split_prefill — Upstream qwen3_5.py and its parent qwen3_vl.py only have forward(), so tp_worker.forward_batch_generation_split_layer() → model_runner.forward_split_prefill() would fail for this model. We have a working implementation that handles capture_aux_hidden_states / _is_layer_to_capture for DFLASH. Happy to port it as a follow-up.

  2. Multimodal fallback — Qwen3.5-VL inherits from Qwen3VLForConditionalGeneration and needs general_mm_embed_routine for multimodal inputs, which is incompatible with split-prefill. It'd be useful to add a multimodal guard in _get_pipeline_group_size() so those batches fall back to the normal path. We hit this in our testing.

We'd like to collaborate on getting Qwen3.5 support into this PR (or a follow-up).

@michael7193
Author

Thanks for the thoughtful review and kind words, @UNIDY2002!

Great to hear about your experience with #19931. The callback-driven approach is interesting — glad we converged on similar goals from different angles.

Both issues you raised are very practical:

forward_split_prefill for Qwen3.5 — Makes total sense. The current implementation assumes models provide forward_split_prefill, so hybrid models like Qwen3.5-397B-A17B would indeed need that. Would love to see your implementation — a follow-up PR sounds perfect.

Multimodal fallback — Good catch. Adding a multimodal guard in _get_pipeline_group_size() to fall back to the normal path is straightforward and the right thing to do. Happy to include it in this PR if you'd like to send a patch, or we can handle it in the follow-up together.

Very much looking forward to collaborating on Qwen3.5 support. Feel free to ping me anytime!

Review thread on python/sglang/srt/disaggregation/mooncake/conn.py (outdated)
@michael7193 michael7193 force-pushed the feature/layer-pipelined-kv-transfer branch from 39d680d to 155f9b7 Compare May 6, 2026 06:46
@michael7193
Author

@UNIDY2002 Thanks for the catch — applied your is_last_chunk fix and also resolved the lint issues (missing import + formatting in qwen3_5.py). All pre-commit checks are passing now. ✅

@ShangmingCai Gentle ping — this PR is ready for review whenever you have a chance. Summary of what's been done since your last look:

  • Rebased onto latest main (conflict resolved cleanly)
  • Merged UNIDY2002's Qwen3.5 forward_split_prefill support
  • Added model-aware multimodal guard (VL models with forward_split_prefill can still use pipelining)
  • Fixed is_last_chunk param name bug (UNIDY2002's suggestion)
  • Lint all green

Happy to address any further feedback!

root and others added 13 commits May 8, 2026 18:38
Overlap RDMA KV transfer with GPU compute by splitting prefill into
layer groups and enqueuing per-layer transfers after each group finishes.
Transfer(N) overlaps with compute(N+1), reducing TTFT by 14-68% for
long prompts (>=3K tokens) in production benchmarks.

Key changes:
- scheduler.py: add run_batch_pipelined() with grouped forward + KV send
- tp_worker.py: add split_init/split_layer/split_sample for layer-wise forward
- mooncake/conn.py: add send_kvcache_layer() for single-layer RDMA (MHA+MLA)
- prefill.py: per-batch dispatch via _should_use_pipelined() + result handler
- fake/conn.py: send_layer stub for warmup requests

Gated by SGLANG_PIPELINED_KV_TRANSFER=1 (default off). Configurable via
SGLANG_PIPELINE_GROUP_SIZE (default 10) and SGLANG_PIPELINE_MIN_TOKENS
(default 3072). Short prompts below threshold use the normal path.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Register SGLANG_PIPELINED_KV_TRANSFER, SGLANG_PIPELINE_GROUP_SIZE, and
SGLANG_PIPELINE_MIN_TOKENS in the central environ.py registry. Migrate
os.environ.get() calls to envs.XXX.get() for consistency with the rest
of the codebase. Add documentation to environment_variables.md.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Instead of a fixed SGLANG_PIPELINE_GROUP_SIZE, automatically compute
group_size to keep pipeline iterations in [6, 10] range:
- short prompts (<4K): 10 iterations (more overlap)
- medium prompts (4K-8K): 8 iterations (sweet spot)
- long prompts (>8K): 6 iterations (reduce loop overhead)

User can still override via SGLANG_PIPELINE_GROUP_SIZE env var.
Rename _should_use_pipelined -> _get_pipeline_group_size (returns 0
to skip, >0 for the group_size to use).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Extend send_kvcache_layer() with optional dst_tp_rank, dst_attn_tp_size,
and dst_kv_item_len parameters. When prefill TP != decode TP for MHA
models, apply per-token head slicing using vectorized numpy addressing
(same math as send_kvcache_slice). MLA remains TP-invariant.

Update transfer_worker dispatch to detect TP mismatch in the layer-
pipelined branch and forward the extra parameters.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Layer-pipelined KV transfer currently skips Mamba/SWA/NSA state
transfer (state is only sent on is_last_chunk, which the per-layer
path never triggers). This would cause silent data loss for hybrid
models like Jamba, FalconH1, and DeepSeek-R1 with SWA.

Add a safety guard in _get_pipeline_group_size() that falls back to
the normal (non-pipelined) path when state_type != "none". Per-layer
state pipelining will be implemented in a follow-up.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The layer-pipelined path was silently skipping Mamba/SWA/NSA state
transfer because send_layer() never passed state_indices. Fix:

1. Add state_indices param to MooncakeKVSender.send_layer() and
   FakeKVSender.send_layer(). On is_last=True, state_indices are
   forwarded to add_transfer_request(), which lets transfer_worker
   call maybe_send_extra() on the last chunk.

2. Add _prepare_pipelined_state_indices() in prefill.py that mirrors
   the state_indices computation from send_kv_chunk() and attaches
   the result to each req before the layer loop.

3. Remove the safety guard that forced Mamba/SWA/NSA models to fall
   back to the normal path — no longer needed.

State transfer still happens as a bulk operation on the last chunk
(not per-layer), but now overlaps with the last KV group transfer.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Skip layer-pipelined KV transfer when batch contains multimodal inputs
(images/audio), as split-prefill is incompatible with general_mm_embed_routine.
Falls back to normal path to avoid crashes or silent data corruption.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Xun Sun <UNIDY2002@outlook.com>
Add forward_split_prefill support for Qwen3.5-397B-A17B (hybrid linear attention + GQA + VL), enabling layer-pipelined KV transfer.

Changes:
- Qwen3_5ForCausalLM.forward_split_prefill: text model split forward
- Qwen3_5MoeForConditionalGeneration.forward_split_prefill: VL wrapper with general_mm_embed_routine
- Fix is_last -> is_last_chunk parameter name in conn.py

Tested on 2x8xH20 with Qwen3.5-397B-A17B-FP8, PD + Mooncake TCP.

Authored-by: UNIDY2002
…ward_split_prefill

Previously the guard blocked ALL multimodal batches from pipeline mode.
Now it only blocks when the model lacks forward_split_prefill (meaning it
can't handle multimodal inputs in split mode). Models like Qwen3.5-VL that
implement multimodal-aware forward_split_prefill are allowed through.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add `general_mm_embed_routine` import from `sglang.srt.managers.mm_utils`
  to fix ruff F821 (undefined name) in `forward_split_prefill`
- Remove extra blank lines to satisfy black formatter

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add test_disaggregation_pipelined.py covering:
- GSM8K eval correctness with pipelined transfer enabled
- Basic single-request generation
- Long prompt exercising deeper pipeline overlap
- Concurrent request handling (8 parallel prefills)
- Fixed group_size configuration path

Tests enable SGLANG_PIPELINED_KV_TRANSFER=1 with a low min_tokens
threshold to exercise the pipeline path on standard CI eval prompts.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
root and others added 3 commits May 8, 2026 18:38
- Replace multimodal-only guard with universal hasattr check so that
  models without forward_split_prefill (e.g. Mamba/hybrid) safely
  fallback to the normal transfer path instead of crashing.
- Implement forward_split_prefill for FalconH1ForCausalLM, enabling
  layer-pipelined KV+state transfer for Mamba hybrid models.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add copy_done.synchronize() for staging buffer correctness
- Add routed_experts_output/indexer_topk_output finalize() to prevent
  resource leak in A2A MoE configurations
- Use maybe_cache_unfinished_req() instead of direct cache_unfinished_req()
  to handle HiCache conditional logic correctly

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The forward_split_prefill method uses LogitsProcessorOutput in its
return type annotation but it was not imported, causing CI lint failure.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@michael7193 michael7193 force-pushed the feature/layer-pipelined-kv-transfer branch from b5267cb to 2547641 Compare May 9, 2026 01:38
@michael7193
Author

@ShangmingCai Friendly ping — this PR has been rebased onto the latest main (no conflicts). Would appreciate your review when you get a chance.

Also, could a maintainer add the run-ci label so the GPU tests can run? Lint is passing. Thanks!

@ShangmingCai
Collaborator

Great! I've been too busy lately; let me trigger the CI first and start reviewing next week. Thank you so much for the PR.

@ShangmingCai
Collaborator

/tag-and-rerun-ci

@github-actions github-actions Bot added the run-ci label May 9, 2026
Collaborator

This file has a lint error. Also, was this modification mis-added by CC?

Author

The falcon_h1.py change is intentional — FalconH1 is a Mamba/Attention hybrid model where SSM conv states need special handling during layer-pipelined transfer (sent once at the final group via maybe_send_extra()). I'll fix the lint error in the next push.

Collaborator

Does this mean that we need to implement forward_split_prefill for every single model? That might not be a robust design. Will dive in next week.

Author

Good question! Actually this is not a new pattern we're introducing — there are already 15 models in the upstream codebase that implement forward_split_prefill (llama, qwen, qwen2, qwen3, gemma, gemma2, gemma3, glm4, exaone4, sarvam_moe, qwen2_moe, qwen3_moe, etc.), added for chunked prefill / PP support.

Our layer-pipelined feature simply reuses this existing interface. The design has two layers of safety:

  1. Guard fallback: If a model doesn't have forward_split_prefill, the pipelined path is automatically skipped and the request goes through the normal path (no crash, no regression).
  2. Pattern is mechanical: For standard transformer models, the implementation is identical — embed → layers[start:end] → norm → logits. Only hybrid models (Mamba SSM, hybrid linear attention) need custom logic.

That said, if you'd prefer a more robust approach, we could add a default generic implementation in a base class that works for any standard transformer model, so new models get pipelined support for free without writing any code. Happy to explore that direction if you think it's worthwhile.
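For reference, the "mechanical" pattern described in point 2 looks roughly like this; a toy stand-in, since real implementations route through forward_batch and the attention backend:

```python
# Sketch of the generic split-prefill shape: embed on the first group,
# run layers[start:end], and produce logits only after the last group.
class SplitPrefillModel:
    def __init__(self, embed, layers, norm, logits_head):
        self.embed, self.layers = embed, layers
        self.norm, self.logits_head = norm, logits_head
        self._hidden = None  # carried across groups

    def forward_split_prefill(self, input_ids, start: int, end: int):
        if start == 0:
            self._hidden = self.embed(input_ids)
        for layer in self.layers[start:end]:
            self._hidden = layer(self._hidden)
        if end == len(self.layers):
            return self.logits_head(self.norm(self._hidden))
        return None  # intermediate groups produce no logits
```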

Run ruff format to fix lint errors flagged in review.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@michael7193
Author

Fixed the falcon_h1.py formatting issue (commit 489bf9d). Could you re-trigger CI when you get a chance? Thanks!
