
[Draft] feat: Mooncake support layerwise kv cache transfer #19931

Open

zhangxiaolei123456 wants to merge 4 commits into sgl-project:main from bytedance-iaas:main_qwen3.5_0305_per_layer

Conversation

@zhangxiaolei123456 (Contributor) commented Mar 5, 2026

Motivation

Co-authored-by: UNIDY2002

  • Per-layer transfer supports GQA and Mamba KV cache, same TP, without MTP
  • Per-layer transfer supports GQA and Mamba KV cache, same TP, with MTP
  • Per-layer transfer supports GQA and Mamba KV cache, different TP, with MTP
  • Per-layer transfer supports MLA KV cache, different TP, with MTP
  • Support for other parallelism schemes (PP, etc.)

Modifications

Models: Qwen3.5

Prefill

SGLANG_ASYNC_KV_MISSING_WAIT_MS=100 SGLANG_ASYNC_KV_GQA_PER_LAYER_EVENT_SYNC=1 \
SGLANG_ASYNC_KV_MAMBA_PER_LAYER_EVENT_SYNC=1 SGLANG_MOONCAKE_ASYNC_KV=1 \
GLOO_SOCKET_IFNAME=eth0 NCCL_MIN_NCHANNELS=24 NCCL_IB_QPS_PER_CONNECTION=8 \
SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600 SGLANG_DISAGGREGATION_THREAD_POOL_SIZE=128 SGLANG_DISAGGREGATION_QUEUE_SIZE=128 \
python -m sglang.launch_server --model-path /data00/models/Qwen3.5-397B-A17B-FP8 \
  --tp-size 8 --mem-fraction-static 0.85 --reasoning-parser qwen3 --tool-call-parser qwen3_coder \
  --mamba-ssm-dtype float16 --kv-cache-dtype fp8_e4m3 --disaggregation-mode prefill \
  --disaggregation-ib-device "mlx5_1,mlx5_2,mlx5_3,mlx5_4" --host 0.0.0.0 --port 30300 \
  --disable-radix-cache --max-running-requests 64 --chunked-prefill-size 0 --max-prefill-tokens 16384 --page-size 64

Decode without MTP

SGLANG_ASYNC_KV_GQA_PER_LAYER_EVENT_SYNC=1 \
SGLANG_ASYNC_KV_MAMBA_PER_LAYER_EVENT_SYNC=1 SGLANG_MOONCAKE_ASYNC_KV=1 \
GLOO_SOCKET_IFNAME=eth0 NCCL_MIN_NCHANNELS=24 NCCL_IB_QPS_PER_CONNECTION=8 \
SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=128 \
SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600 SGLANG_DISAGGREGATION_THREAD_POOL_SIZE=128 SGLANG_DISAGGREGATION_QUEUE_SIZE=128 \
python -m sglang.launch_server --model-path /data00/models/Qwen3.5-397B-A17B-FP8 \
  --tp-size 8 --ep-size 8 --mem-fraction-static 0.75 --context-length 131072 \
  --reasoning-parser qwen3 --tool-call-parser qwen3_coder --cuda-graph-bs 1 8 16 32 64 \
  --max-running-requests 256 --mamba-ssm-dtype float16 --kv-cache-dtype fp8_e4m3 \
  --disaggregation-mode decode --disaggregation-ib-device "mlx5_1,mlx5_2,mlx5_3,mlx5_4" \
  --moe-runner-backend deep_gemm --moe-a2a-backend deepep --deepep-mode low_latency \
  --host 0.0.0.0 --port 30300 --enable-metrics --disable-radix-cache --page-size 64

Decode with MTP

SGLANG_ASYNC_KV_GQA_PER_LAYER_EVENT_SYNC=1 \
SGLANG_ASYNC_KV_MAMBA_PER_LAYER_EVENT_SYNC=1 SGLANG_MOONCAKE_ASYNC_KV=1 \
GLOO_SOCKET_IFNAME=eth0 NCCL_MIN_NCHANNELS=24 NCCL_IB_QPS_PER_CONNECTION=8 \
SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=128 \
SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600 SGLANG_DISAGGREGATION_THREAD_POOL_SIZE=128 SGLANG_DISAGGREGATION_QUEUE_SIZE=128 \
python -m sglang.launch_server --model-path /data00/models/Qwen3.5-397B-A17B-FP8 \
  --tp-size 8 --ep-size 8 --mem-fraction-static 0.75 --context-length 131072 \
  --reasoning-parser qwen3 --tool-call-parser qwen3_coder --cuda-graph-bs 1 8 16 32 64 \
  --max-running-requests 256 --mamba-ssm-dtype float16 --kv-cache-dtype fp8_e4m3 \
  --disaggregation-mode decode --disaggregation-ib-device "mlx5_1,mlx5_2,mlx5_3,mlx5_4" \
  --moe-runner-backend deep_gemm --moe-a2a-backend deepep --deepep-mode low_latency \
  --host 0.0.0.0 --port 30300 --enable-metrics --disable-radix-cache --page-size 64 \
  --speculative-algo EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4

Accuracy Tests

Dataset: gsm8k
PD without MTP

python3 bench_sglang.py --host http://localhost  --port 8090 --data-path /data00 --num-questions 500 --parallel 10
100%|██████████████| 500/500 [03:16<00:00,  2.54it/s]
Accuracy: 0.950
Invalid: 0.016
Latency: 196.941 s
Output throughput: 413.301 token/s

PD with MTP

 python3 bench_sglang.py --host http://localhost  --port 8090 --data-path /data00 --num-questions 500 --parallel 10
100%|███████████████████████████| 500/500 [01:52<00:00,  4.43it/s]
Accuracy: 0.944
Invalid: 0.018
Latency: 112.831 s
Output throughput: 710.290 token/s

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist (Contributor) commented
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant enhancement to the Mooncake disaggregation system by enabling asynchronous, per-layer KV cache transfer. This change aims to optimize the data transfer pipeline, particularly benefiting models with architectures like GQA and Mamba by allowing KV and state tensors to be transferred as soon as they are ready, rather than waiting for an entire batch. This approach leverages CUDA events for fine-grained synchronization, potentially reducing latency and improving overall throughput in disaggregated inference setups.

Highlights

  • Asynchronous KV Cache Transfer: Introduced MooncakeAsyncKVManager to enable asynchronous, per-layer KV cache transfer, improving performance and efficiency for specific model architectures.
  • Per-Layer Synchronization: Implemented per-layer event synchronization using CUDA events for both GQA (KV cache) and Mamba (state tensors) to ensure data readiness during asynchronous transfers.
  • Mamba State Tensor Handling: Added explicit support for transferring Mamba state tensors per-layer, including mechanisms to track and resend missing state tensors if needed.
  • Dynamic KV Manager Selection: Updated the prefill process to dynamically select the MooncakeAsyncKVManager when the SGLANG_MOONCAKE_ASYNC_KV environment variable is enabled.
  • Scheduler and Attention Backend Integration: Modified the scheduler to prepare batches for asynchronous KV transfer and the hybrid attention backend to trigger layer-ready callbacks during forward passes.
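
To make the overlap concrete, here is a minimal sketch of the per-layer event pattern described above. All names here (layer_queue, layer_ready_callback, transfer_worker, submit_layer) are illustrative, not the PR's actual API: the forward pass records a CUDA event after each layer's KV write, and a background thread waits on that event before handing the layer to the transfer engine.

```python
import queue
import threading

import torch

# Queue of (layer_id, event) pairs handed from the forward pass to the
# transfer thread; a CUDA event marks when the layer's KV write is done.
layer_queue: "queue.Queue[tuple[int, torch.cuda.Event]]" = queue.Queue()

def layer_ready_callback(layer_id: int) -> None:
    # Called right after save_kv_cache for `layer_id` (requires a CUDA device).
    event = torch.cuda.Event()
    event.record()  # recorded on the current stream, after the KV write
    layer_queue.put((layer_id, event))

def transfer_worker(submit_layer) -> None:
    # Background thread: waits until each layer's KV is readable, then
    # submits it, overlapping transfer with later layers' compute.
    while True:
        layer_id, event = layer_queue.get()
        event.synchronize()
        submit_layer(layer_id)

threading.Thread(
    target=transfer_worker,
    args=(lambda i: print(f"transfer layer {i}"),),
    daemon=True,
).start()
```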


Changelog
  • python/sglang/srt/disaggregation/mooncake/async_kv_manager.py
    • Added new file implementing MooncakeAsyncKVManager for asynchronous KV cache transfer.
    • Implemented StreamAsyncSubmitter for managing asynchronous submission of transfer tasks.
    • Added logic for per-layer event synchronization for GQA and Mamba state tensors.
    • Included functionality to detect and resend missing state tensors for robustness.
  • python/sglang/srt/disaggregation/prefill.py
    • Updated _init_kv_manager to conditionally instantiate MooncakeAsyncKVManager based on SGLANG_MOONCAKE_ASYNC_KV environment variable.
  • python/sglang/srt/environ.py
    • Added SGLANG_MOONCAKE_ASYNC_KV as a new environment variable to enable or disable asynchronous KV cache transfer.
  • python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
    • Initialized Mamba layer and state tensor counts within the attention backend.
    • Modified the forward method to trigger layer_ready_callback for KV and Mamba state tensors during save_kv_cache operations.
  • python/sglang/srt/managers/scheduler.py
    • Imported set_layer_ready_callback for asynchronous KV management.
    • Updated run_batch to prepare batches for asynchronous KV transfer and set the layer-ready callback if the async feature is enabled and conditions are met.
  • python/sglang/srt/model_executor/forward_batch_info.py
    • Added layer_ready_callback and async_kv_batch_started fields to the ForwardBatch dataclass to support asynchronous KV transfer callbacks.
  • python/sglang/srt/model_executor/model_runner.py
    • Imported get_layer_ready_callback for asynchronous KV management.
    • Modified get_spec_info and forward methods to pass the layer_ready_callback to the ForwardBatch instance.
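
The env-gated selection in prefill.py can be pictured like this; the class names come from the changelog above, but the stub bodies and exact condition are assumptions, not the PR's code.

```python
import os

class MooncakeKVManager:
    """Stand-in for the existing synchronous Mooncake KV manager."""
    def __init__(self, args):
        self.args = args

class MooncakeAsyncKVManager(MooncakeKVManager):
    """Stand-in for the new layerwise async manager added by this PR."""

def _init_kv_manager(args):
    # Pick the async manager only when the feature flag is set, as the
    # changelog describes; the real selection logic may differ.
    if os.environ.get("SGLANG_MOONCAKE_ASYNC_KV", "0") == "1":
        return MooncakeAsyncKVManager(args)
    return MooncakeKVManager(args)
```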
Activity
  • The pull request is currently in a 'Draft' state, indicating ongoing development.
  • The author has provided detailed motivation for supporting per-layer KV cache transfer for various model types and configurations.
  • Example bash commands for both prefill and decode modes have been included to demonstrate the usage of the new feature.
  • An accuracy test result (0.950 accuracy, 0.016 invalid, 196.941s latency, 413.301 token/s throughput) has been provided.
  • The checklist indicates that code formatting, unit tests, documentation, and further benchmarking are still pending.

@zhangxiaolei123456 zhangxiaolei123456 changed the title [Draft] feat: Mooncake support per-layer kv cache transfer [Draft] feat: Mooncake support per-layers kv cache transfer Mar 5, 2026
@zhangxiaolei123456 zhangxiaolei123456 changed the title [Draft] feat: Mooncake support per-layers kv cache transfer [Draft] feat: Mooncake support per-layer kv cache transfer Mar 5, 2026
@zhangxiaolei123456 zhangxiaolei123456 changed the title [Draft] feat: Mooncake support per-layer kv cache transfer [Draft] feat: Mooncake support layerwise kv cache transfer Mar 5, 2026
@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces asynchronous, per-layer KV cache transfer for Mooncake, which is a significant feature for improving performance in disaggregated serving. The implementation uses an asynchronous submission mechanism and hooks into the model's forward pass to trigger layer-wise transfers, which is a solid approach to overlap computation and communication. The code is generally well-structured, but my review includes suggestions to improve robustness and maintainability, particularly around error handling and code duplication.

Note: Security Review did not run due to the size of the PR.

Comment on lines +173 to +180

except Exception as e:
    import traceback

    traceback.print_exc()
    logger.info(f"Error in put_kvcache_thread: {e}")
    import os

    os._exit(1)

Severity: high

Using os._exit(1) in a worker thread is unsafe for a server application. It terminates the entire process abruptly, bypassing cleanup handlers, which can lead to resource leaks, corrupted state, and difficult debugging. A more graceful shutdown should be implemented, for example by signaling the main thread. For now, I'll suggest replacing this with proper exception logging to avoid crashing the whole server process on a single thread's error.

Suggested change

-except Exception as e:
-    import traceback
-    traceback.print_exc()
-    logger.info(f"Error in put_kvcache_thread: {e}")
-    import os
-    os._exit(1)
+except Exception:
+    logger.exception("Unhandled exception in _put_kvcache_func worker thread.")

Comment on lines +258 to +264
try:
import torch

if torch.cuda.is_available():
event.synchronize()
except Exception:
pass

Severity: medium

Swallowing exceptions with a broad except Exception: pass is risky. If event.synchronize() fails, it will be silently ignored, which could lead to race conditions or hard-to-debug issues. The exception should be logged to provide visibility into potential problems.

except Exception as e:
    logger.warning(f"Failed to synchronize CUDA event: {e}")

Comment on lines +742 to +750

for field in vars(mamba_cache):
    if field in ("intermediate_ssm", "intermediate_conv_window"):
        continue
    value = getattr(mamba_cache, field)
    if isinstance(value, list):
        state_tensors.extend(value)
    else:
        state_tensors.append(value)
self._mamba_state_tensors_per_layer = len(state_tensors)

Severity: medium

This block of code for counting mamba state tensors appears to be duplicated in python/sglang/srt/disaggregation/mooncake/async_kv_manager.py (lines 698-709). To improve maintainability and reduce redundancy, this logic should be extracted into a shared helper function.
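
For example, a shared helper might look like the following sketch (the function name and location are suggestions, not code from the PR):

```python
def collect_mamba_state_tensors(
    mamba_cache,
    skip=("intermediate_ssm", "intermediate_conv_window"),
):
    # Flatten all per-layer Mamba state tensors except the skipped fields,
    # so both call sites share one definition of "state tensors".
    state_tensors = []
    for field in vars(mamba_cache):
        if field in skip:
            continue
        value = getattr(mamba_cache, field)
        if isinstance(value, list):
            state_tensors.extend(value)
        else:
            state_tensors.append(value)
    return state_tensors
```

Both call sites could then reduce to `len(collect_mamba_state_tensors(mamba_cache))`.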

@stmatengss (Collaborator) commented

/tag-and-rerun-ci

@github-actions github-actions Bot added the run-ci label Mar 6, 2026
@dongyibo commented

@zhangxiaolei123456 Hello, could you please provide the performance test data? I understand that enabling this feature will reduce transmission time.

@UNIDY2002 (Contributor) commented Apr 10, 2026

I re-ran the PD layerwise-KV pipeline experiment under TCP transport (MC_FORCE_TCP=true), and the TTFT reduction is real.

Qwen3.5-397B-A17B-FP8 results:

Input   Output  Load                                   Async OFF TTFT  Async ON TTFT  Delta
3500    1       max_concurrency=1                      370.74 ms       316.69 ms      -54.05 ms (-14.6%)
3500    1       request_rate=1.5, max_concurrency=32   457.69 ms       424.71 ms      -32.98 ms (-7.2%)
16000   1       max_concurrency=1                      1074.93 ms      913.47 ms      -161.46 ms (-15.0%)

For attribution, the runtime measurements are also consistent on the 16000x1, c=1 case: in sync mode, prefill compute reaches transfer submission at about 838 ms, and the critical-path sequential TCP transfer tail is about 167 ms, plus some fixed downstream overhead. With async enabled, that transfer tail is overlapped into prefill, which is why the measured TTFT gain is large on TCP. The effect is expected to be smaller on RDMA because the transfer tail itself is smaller there.
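As a rough sanity check on those numbers: 1074.93 - 838 - 167 ≈ 70 ms of fixed overhead in sync mode, and if async fully hides the 167 ms transfer tail, the predicted TTFT is about 838 + 70 ≈ 908 ms, close to the measured 913.47 ms.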

Overall, I think this PR is worth looking into, especially for deployments where only TCP transport is available, since that is exactly where the overlap benefit is the clearest.

@ShangmingCai (Collaborator) commented

> I re-ran the PD layerwise-KV pipeline experiment under TCP transport (MC_FORCE_TCP=true), and the TTFT reduction is real. […]

@UNIDY2002 yeah, I agree. Did you test this PR under extreme heavy workload?

Comment thread: python/sglang/srt/managers/scheduler.py (outdated)
Comment on lines +2331 to +2332

try:
    if self.is_generation:
Collaborator:

I'm just not quite sure about this part. I'm worried that it will make the scheduler code hard to read and maintain (and hard to debug, because of the try block here), even when async transfer is not enabled. Is there a cleaner solution to achieve this?

Contributor:

I agree. I think we need to improve the PR's implementation.
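
One possible shape, sketched here with hypothetical names (on_batch_launched is not an existing SGLang hook): give the scheduler a no-op hook and let only the async path override it, so run_batch needs neither the try block nor feature checks.

```python
class Scheduler:
    def on_batch_launched(self, batch) -> None:
        # Default hook: no-op, so the common scheduler path is untouched.
        pass

class AsyncKVScheduler(Scheduler):
    def __init__(self, layer_ready_cb):
        self.layer_ready_cb = layer_ready_cb

    def on_batch_launched(self, batch) -> None:
        # Only the async variant wires the layerwise callback into the batch.
        batch.layer_ready_callback = self.layer_ready_cb
```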

@UNIDY2002 (Contributor) commented

> @UNIDY2002 yeah, I agree. Did you test this PR under extreme heavy workload?

Will test with heavier workloads.

@UNIDY2002 (Contributor) commented

Thanks for the suggestion. We did run an extreme heavy-load comparison with matched config.

Setup:

  • Qwen3.5-397B-A17B-FP8, PD + TCP
  • input_len=16000, output_len=1, num_prompts=128, max_concurrency=32
  • qps=0.25/1.0/1.25/1.5
  • baseline config: max_prefill_tokens=16384, chunked_prefill_size=0

Result (mean TTFT):

  • qps=0.25: OFF 1295.31 ms, ON 1135.11 ms (-12.37%)
  • qps=1.0: OFF 3845.58 ms, ON 3530.48 ms (-8.19%)
  • qps=1.25: OFF 10081.66 ms, ON 9725.94 ms (-3.53%)
  • qps=1.5: OFF 15996.56 ms, ON 15617.57 ms (-2.37%)

So async-ON is consistently better under this heavy long-prompt workload.
At high QPS, queueing dominates (prefill/input-throughput-limited regime), so relative gain shrinks even though async still improves the critical path.
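For scale: at qps=1.5 the offered prefill load is 1.5 × 16000 = 24,000 input tokens/s, so once arrivals approach the prefill throughput ceiling, TTFT is dominated by queueing delay, and the few-hundred-millisecond saving from overlap is a small fraction of a 15+ second TTFT, consistent with the -2.37% figure.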

@github-actions github-actions Bot added labels: documentation, quant, amd, dependencies, lora, multi-modal, deepseek, speculative-decoding, hicache, sgl-kernel, blackwell, npu, piecewise-cuda-graph, diffusion (Apr 20, 2026)
zhangxiaolei123456 and others added 4 commits April 20, 2026 11:26

  • Co-authored-by: UNIDY2002 <unidy2002@outlook.com>
  • These flags are effectively fixed to defaults; remove dead env parsing and stale comments.
  • Implement layerwise async KV as an optional path in MooncakeKVManager and expose a standard scheduler hook, removing the async-manager subclass and ad-hoc capability checks.
  • Move the layerwise async Mooncake transfer state and helpers into a dedicated mixin so the base KV manager stays smaller while preserving the validated async overlap path.