
[feature][PD disaggregation] KV cache transfer for speculative decoding with different target/draft model architectures (MLA and MHA combinations)#20698

Open
yz-tang wants to merge 4 commits into sgl-project:main from yz-tang:tyz_eagle3

Conversation


@yz-tang yz-tang commented Mar 16, 2026

Motivation

When using EAGLE3 with DeepSeek, the draft model is a Llama-style model, so the KV cache transfer path must support mixed MLA and MHA layouts. When DP attention is additionally enabled in this configuration, we also need to decide how each DP rank sends and receives this data.
@duzeyan @yz-tang

Modifications

  • Separated target and draft model KV data structures with dedicated pointers

  • Added is_draft_mla_backend flag to distinguish draft model architecture type

  • Introduced is_send_target flag for selective KV data transfer to avoid redundant sends

  • Refactored KV cache transfer logic to handle both target and draft models independently

  • MLA Target + MHA Draft: Only rank 0 transfers MLA KV data; all ranks transfer MHA draft KV

  • MLA Target + MLA Draft: Single rank handles all KV transfer with dummy connections for others

  • Bootstrap Protocol: Extended to include draft KV pointers and selective send flags

  • Response Counting: Correctly calculates required_prefill_response_num based on model types
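The per-rank transfer rules in the bullets above can be sketched as follows. This is a hypothetical illustration only: the names (is_mla_backend, is_draft_mla_backend, attn_tp_rank) are assumptions and may not match sglang's actual implementation.

```python
def transfer_roles(is_mla_backend: bool, is_draft_mla_backend: bool, attn_tp_rank: int):
    """Return (send_target, send_draft) for a prefill attention-TP rank."""
    if is_mla_backend and not is_draft_mla_backend:
        # MLA target + MHA draft: the replicated MLA KV is sent by rank 0
        # only; every rank sends its own shard of the MHA draft KV.
        return attn_tp_rank == 0, True
    if is_mla_backend and is_draft_mla_backend:
        # MLA target + MLA draft: a single rank sends everything; the other
        # ranks hold dummy connections.
        is_sender = attn_tp_rank == 0
        return is_sender, is_sender
    # MHA target: each rank transfers its own shards of both pools.
    return True, True
```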

Accuracy Tests

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the speculative decoding capabilities within the disaggregated system by introducing robust support for scenarios where target and draft models utilize different KV cache architectures, specifically MLA and MHA combinations. The changes involve separating KV cache management for target and draft models, introducing flags to identify draft model architecture, and refactoring the data transfer logic to handle these mixed configurations efficiently. This ensures correct and optimized KV cache transmission across different ranks, particularly when dealing with models like DeepSeek that might use a Llama draft model.

Highlights

  • KV Data Structure Separation: Introduced distinct data structures (draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens) within KVArgs to manage KV cache for draft models independently from target models.
  • Draft Model Architecture Flag: Added is_draft_mla_backend flag to differentiate between MLA (Multi-head Latent Attention) and MHA (Multi-Head Attention) architectures for draft models, enabling architecture-specific handling during KV cache transfer.
  • Selective KV Data Transfer: Implemented an is_send_target flag in the bootstrap protocol to control whether target model KV data needs to be sent by a specific rank, optimizing transfer in mixed MLA/MHA scenarios (e.g., only rank 0 sends MLA KV data for target model when draft is MHA).
  • Refactored KV Cache Transfer Logic: The core KV cache transfer mechanism was refactored to handle both target and draft models separately, adapting to various combinations of MLA and MHA architectures for both models.
  • Updated Response Counting: Adjusted the calculation of required_prefill_response_num to correctly account for different target and draft model types, especially in MLA configurations.
  • Bootstrap Protocol Extension: Extended the bootstrap information (KVArgsRegisterInfo) to include draft KV pointers and the is_send_target flag, facilitating proper setup for disaggregated speculative decoding.
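As a rough illustration of the "Updated Response Counting" highlight above, a decode rank might derive the number of expected prefill responses as sketched below. The parameter names are assumptions; the real calculation lives in _init_kv_manager.

```python
def required_prefill_response_num(
    prefill_attn_tp_size: int,
    is_mla_backend: bool,
    is_draft_mla_backend: bool,
    has_draft_kv: bool,
) -> int:
    """Hypothetical sketch: how many prefill ranks this decode rank waits on."""
    if is_mla_backend and (not has_draft_kv or is_draft_mla_backend):
        # Pure-MLA transfer: KV is replicated across ranks, so a single
        # sender (and a single response) suffices.
        return 1
    # Any MHA pool (target or draft) is sharded across attention TP ranks,
    # so a response is expected from each of them.
    return prefill_attn_tp_size
```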


Changelog
  • python/sglang/srt/disaggregation/base/conn.py
    • KVArgs: Added draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens for draft model KV cache data, and draft_kv_head_num for draft model KV head count.
    • CommonKVManager.init: Added is_draft_mla_backend parameter to the constructor.
  • python/sglang/srt/disaggregation/common/conn.py
    • CommonKVManager.init: Initialized self.is_draft_mla_backend and passed it to the superclass constructor.
    • _init_kv_manager: Modified the logic for setting required_prefill_response_num to consider is_draft_mla_backend and the presence of draft KV data.
    • _setup_bootstrap_infos: Updated the logic for setting is_dummy and introduced is_send_target flag based on MLA/MHA combinations for selective KV data transfer.
  • python/sglang/srt/disaggregation/decode.py
    • DecodeWorker.init: Initialized self.is_draft_mla_backend based on the draft token KV pool.
    • _init_kv_manager: Separated the assignment of target and draft KV data pointers/lengths to kv_args and added initialization for empty draft KV lists if no draft model is present.
    • _init_kv_manager: Passed self.is_draft_mla_backend to the CommonKVManager constructor.
    • _send_kvcache_generic: Added is_mla_backend as a parameter and updated its usage.
    • send_kvcache: Added is_mla_backend parameter and passed it to _send_kvcache_generic.
    • send_kvcache_slice: Modified parameters to include src_kv_ptrs, src_kv_item_len, and num_kv_heads for more flexible KV slice transfer.
    • _send_kv_cache: Introduced a new helper method to encapsulate the logic for sending either target or draft KV cache, handling both generic and slice-based transfers.
    • transfer_worker: Refactored the KV cache sending logic to utilize the new _send_kv_cache method for both target and draft models, incorporating is_send_target and error handling.
    • maybe_send_extra: Updated logic to include checks for is_mla_backend and is_send_target for hybrid models.
    • _register_kv_args: Extended the bootstrap message to include draft_packed_kv_data_ptrs, draft_dst_kv_item_len, and is_send_target_str.
  • python/sglang/srt/disaggregation/mooncake/conn.py
    • KVArgsRegisterInfo: Added draft_kv_ptrs, draft_dst_kv_item_len, and is_send_target fields.
    • KVArgsRegisterInfo.from_zmq: Updated to unpack the new draft KV and is_send_target fields from the ZMQ message.
    • MooncakeKVManager.init: Passed is_draft_mla_backend to the superclass constructor.
    • register_buffer_to_engine: Added logic to batch register draft KV data buffers if present.
    • _send_kvcache_generic: Modified to accept is_mla_backend as a parameter and use it for conditional logic.
    • send_kvcache_slice: Updated to use src_kv_ptrs instead of self.kv_args.kv_data_ptrs for flexibility.
  • python/sglang/srt/disaggregation/prefill.py
    • PrefillWorker.init: Initialized self.is_draft_mla_backend based on the draft token KV pool.
    • _init_kv_manager: Separated the assignment of target and draft KV data pointers/lengths to kv_args, added initialization for empty draft KV lists if no draft model is present, and set draft_kv_head_num if applicable.
    • _init_kv_manager: Passed self.is_draft_mla_backend to the CommonKVManager constructor.
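The changelog mentions that _register_kv_args packs draft KV pointers (draft_packed_kv_data_ptrs) into the bootstrap message and that KVArgsRegisterInfo.from_zmq unpacks them. A minimal sketch of such pointer packing is shown below; the "Q" format assumes 64-bit device addresses, and the actual wire layout in sglang may differ.

```python
import struct

def pack_ptrs(ptrs: list[int]) -> bytes:
    # Pack each buffer address as an unsigned 64-bit integer.
    return struct.pack(f"{len(ptrs)}Q", *ptrs)

def unpack_ptrs(blob: bytes) -> list[int]:
    # Recover the address list; an empty payload yields an empty list.
    return list(struct.unpack(f"{len(blob) // 8}Q", blob))
```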

@ShangmingCai ShangmingCai self-assigned this Mar 16, 2026
@ShangmingCai
Collaborator

Can you share the acceptance rate of your case? The suitable scenarios sound very niche.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces comprehensive changes to support KV cache transfer for speculative decoding with mixed target/draft model architectures (MLA and MHA combinations) and DP-Attention. The changes involve separating KV data structures for target and draft models, adding flags to distinguish draft model architecture, and refactoring the transfer logic to handle these new complexities. The introduction of the is_send_target flag for selective KV data transfer is a key improvement for optimizing communication. Overall, the refactoring improves flexibility and addresses the specific needs of mixed model architectures in a disaggregated environment.

Comment on lines 141 to 149

     dst_state_item_lens=(
-        list(struct.unpack(f"{len(msg[10])//4}I", msg[10]))
-        if len(msg) > 10 and len(msg[10]) > 0
+        list(struct.unpack(f"{len(msg[13])//4}I", msg[13]))
+        if len(msg) > 10 and len(msg[13]) > 0
         else []
     ),
     dst_state_dim_per_tensor=(
-        list(struct.unpack(f"{len(msg[11])//4}I", msg[11]))
-        if len(msg) > 11 and len(msg[11]) > 0
+        list(struct.unpack(f"{len(msg[14])//4}I", msg[14]))
+        if len(msg) > 11 and len(msg[14]) > 0
         else []
     ),

Severity: high

The conditions for unpacking dst_state_item_lens and dst_state_dim_per_tensor in KVArgsRegisterInfo.from_zmq appear to be incorrect. After adding draft_kv_ptrs, draft_dst_kv_item_len, and is_send_target (which occupy indices 10, 11, and 12 respectively), the dst_state_item_lens should now be at index 13 and dst_state_dim_per_tensor at index 14.

Therefore, the length checks len(msg) > 10 and len(msg) > 11 should be updated to len(msg) > 13 and len(msg) > 14 respectively, to ensure that these fields are only accessed if they are actually present in the message. Otherwise, it might lead to an IndexError if the message is shorter than expected.

Suggested change

-    dst_state_item_lens=(
-        list(struct.unpack(f"{len(msg[13])//4}I", msg[13]))
-        if len(msg) > 10 and len(msg[13]) > 0
-        else []
-    ),
-    dst_state_dim_per_tensor=(
-        list(struct.unpack(f"{len(msg[14])//4}I", msg[14]))
-        if len(msg) > 11 and len(msg[14]) > 0
-        else []
-    ),
+    dst_state_item_lens=(
+        list(struct.unpack(f"{len(msg[13])//4}I", msg[13]))
+        if len(msg) > 13 and len(msg[13]) > 0
+        else []
+    ),
+    dst_state_dim_per_tensor=(
+        list(struct.unpack(f"{len(msg[14])//4}I", msg[14]))
+        if len(msg) > 14 and len(msg[14]) > 0
+        else []
+    ),
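To make the failure mode concrete, here is a minimal, self-contained repro of the guard bug the review flags: checking `len(msg) > 10` while indexing `msg[13]` evaluates `msg[13]` even when the message has only 11 fields, raising IndexError instead of falling back to an empty list. The helper name `unpack_guarded` is illustrative, not from the PR.

```python
import struct

def unpack_guarded(msg, idx):
    # Correct pattern: the length guard uses the same index being accessed.
    return (
        list(struct.unpack(f"{len(msg[idx]) // 4}I", msg[idx]))
        if len(msg) > idx and len(msg[idx]) > 0
        else []
    )

short_msg = [b""] * 11            # indices 0..10 exist; index 13 does not
assert unpack_guarded(short_msg, 13) == []   # safe fallback to []

try:
    # Buggy guard from the diff: len(msg) > 10 passes, so msg[13] is
    # evaluated and raises IndexError on an 11-field message.
    _ = short_msg[13] if len(short_msg) > 10 else b""
    raise AssertionError("expected IndexError")
except IndexError:
    pass
```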

@yxs

yxs commented Apr 19, 2026

Hi @yz-tang, are you still driving this? Happy to pick it up.
cc @ShangmingCai
