[feature][PD disaggregation] KV cache transfer for speculative decoding with different target/draft model architectures (MLA and MHA combinations)#20698
Conversation
fix rebase error for pd transfer
Reviewed By: duzeyan
CR Link: https://kunpeng.xiaojukeji.com/view/revision/5827599
Summary of Changes

This pull request enhances speculative decoding in the disaggregated system by adding support for scenarios where the target and draft models use different KV cache architectures, specifically MLA and MHA combinations. The changes separate KV cache management for the target and draft models, introduce flags identifying the draft model architecture, and refactor the data transfer logic to handle these mixed configurations efficiently. This ensures correct and optimized KV cache transmission across ranks, particularly for models like deepseek paired with a Llama draft model.
Can you share the acceptance rate of your case? The suitable scenarios sound very niche.
Code Review
This pull request introduces comprehensive changes to support KV cache transfer for speculative decoding with mixed target/draft model architectures (MLA and MHA combinations) and DP-Attention. The changes involve separating KV data structures for target and draft models, adding flags to distinguish draft model architecture, and refactoring the transfer logic to handle these new complexities. The introduction of the is_send_target flag for selective KV data transfer is a key improvement for optimizing communication. Overall, the refactoring improves flexibility and addresses the specific needs of mixed model architectures in a disaggregated environment.
The diff hunk under review (removed lines use the old indices, added lines the new ones):

```diff
 dst_state_item_lens=(
-    list(struct.unpack(f"{len(msg[10])//4}I", msg[10]))
-    if len(msg) > 10 and len(msg[10]) > 0
+    list(struct.unpack(f"{len(msg[13])//4}I", msg[13]))
+    if len(msg) > 10 and len(msg[13]) > 0
     else []
 ),
 dst_state_dim_per_tensor=(
-    list(struct.unpack(f"{len(msg[11])//4}I", msg[11]))
-    if len(msg) > 11 and len(msg[11]) > 0
+    list(struct.unpack(f"{len(msg[14])//4}I", msg[14]))
+    if len(msg) > 11 and len(msg[14]) > 0
     else []
```
The conditions for unpacking `dst_state_item_lens` and `dst_state_dim_per_tensor` in `KVArgsRegisterInfo.from_zmq` appear to be incorrect. After adding `draft_kv_ptrs`, `draft_dst_kv_item_len`, and `is_send_target` (which occupy indices 10, 11, and 12 respectively), `dst_state_item_lens` should now be at index 13 and `dst_state_dim_per_tensor` at index 14.

Therefore, the length checks `len(msg) > 10` and `len(msg) > 11` should be updated to `len(msg) > 13` and `len(msg) > 14` respectively, to ensure these fields are only accessed when they are actually present in the message. Otherwise, an `IndexError` can occur if the message is shorter than expected.
Suggested change:

```python
dst_state_item_lens=(
    list(struct.unpack(f"{len(msg[13])//4}I", msg[13]))
    if len(msg) > 13 and len(msg[13]) > 0
    else []
),
dst_state_dim_per_tensor=(
    list(struct.unpack(f"{len(msg[14])//4}I", msg[14]))
    if len(msg) > 14 and len(msg[14]) > 0
    else []
),
```
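The guard pattern the review asks for (check `len(msg)` against the same index you read) can be sketched as a small standalone helper. The function name `unpack_u32_list` and the message layout below are illustrative, not the PR's actual code:

```python
import struct

def unpack_u32_list(msg: list, idx: int) -> list:
    """Unpack the bytes field at msg[idx] into a list of uint32 values.

    Returns [] when the field is absent or empty -- note the length
    check uses the SAME index that is subsequently accessed, which is
    exactly the fix the review suggests.
    """
    if len(msg) > idx and len(msg[idx]) > 0:
        return list(struct.unpack(f"{len(msg[idx]) // 4}I", msg[idx]))
    return []

# Illustrative message: indices 10-12 hold the newly added draft fields,
# indices 13-14 carry the packed state lists.
msg = [b""] * 13 + [struct.pack("3I", 1, 2, 3), struct.pack("2I", 7, 8)]
print(unpack_u32_list(msg, 13))       # [1, 2, 3]
print(unpack_u32_list(msg, 14))       # [7, 8]
print(unpack_u32_list(msg[:13], 13))  # [] -- short message, no IndexError
```

With the original `len(msg) > 10` guard, the short-message case above would raise `IndexError` instead of returning `[]`.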
Hi @yz-tang, are you still driving this? Happy to pick it up.
Motivation
When using EAGLE3 with deepseek, the draft model is a Llama model, so transferring the KV cache requires supporting mixed transmission of MLA and MHA. If DP-Attention is enabled in this case, we also need to consider how this data is sent and received across different DP ranks.
@duzeyan @yz-tang
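For intuition on why mixed transmission needs separate handling: MLA caches one compressed latent (plus a decoupled RoPE key) per token per layer, while MHA caches full K and V tensors per head, so the per-token KV item sizes on the two sides differ. A rough sizing sketch, with illustrative dimensions rather than the actual deepseek/Llama configs:

```python
def mla_kv_bytes_per_token(kv_lora_rank: int, qk_rope_head_dim: int,
                           dtype_bytes: int = 2) -> int:
    # MLA stores a compressed latent plus a decoupled RoPE key per token/layer
    return (kv_lora_rank + qk_rope_head_dim) * dtype_bytes

def mha_kv_bytes_per_token(num_kv_heads: int, head_dim: int,
                           dtype_bytes: int = 2) -> int:
    # MHA stores full K and V per KV head per token/layer
    return 2 * num_kv_heads * head_dim * dtype_bytes

# Illustrative numbers only (fp16 entries)
print(mla_kv_bytes_per_token(512, 64))  # 1152
print(mha_kv_bytes_per_token(8, 128))   # 4096
```

Because the item lengths differ, a single set of KV pointers and item sizes cannot describe both caches, which is why the PR separates the target and draft KV data structures.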
Modifications
- Separated target and draft model KV data structures with dedicated pointers
- Added `is_draft_mla_backend` flag to distinguish the draft model architecture type
- Introduced `is_send_target` flag for selective KV data transfer to avoid redundant sends
- Refactored KV cache transfer logic to handle target and draft models independently
- MLA Target + MHA Draft: only rank 0 transfers the MLA KV data; all ranks transfer the MHA draft KV
- MLA Target + MLA Draft: a single rank handles all KV transfer, with dummy connections for the others
- Bootstrap Protocol: extended to include draft KV pointers and selective send flags
- Response Counting: correctly calculates `required_prefill_response_num` based on model types

Accuracy Tests
Benchmarking and Profiling
Checklist
Review Process
`/tag-run-ci-label`, `/rerun-failed-ci`, `/tag-and-rerun-ci`