
[BugFix] Fix initialization of draft model. #29319

Merged
tlrmchlsmth merged 3 commits into vllm-project:main from halyavin:fix-mtp-comm
Nov 25, 2025
Conversation

@halyavin
Contributor

@halyavin halyavin commented Nov 24, 2025

This initialization is needed to make MTP for DeepSeek V3 work with high-throughput backend.

The DeepSeek V3 draft model also has a MoE layer. DeviceCommunicationBase.prepare_communication_buffer_for_model calls the FusedMoEMethodBase.init_prepare_finalize method, which in turn sets the fused_experts field of the layer. Without this field, during the MTP calculation the FusedMoE.forward_impl method sees that the using_modular_kernel property is false, sets the do_naive_dispatch_combine flag, and as a consequence calls get_ep_group().dispatch(). But the dispatch method is not implemented in the DeepEPHTAll2AllManager class, which raises an exception.
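To illustrate the failure path described above, here is a minimal, self-contained sketch (not vLLM's actual code; the class and attribute names are simplified stand-ins): without initialization, fused_experts stays unset, the layer takes the naive dispatch/combine path, and the high-throughput all2all manager raises because it does not implement dispatch.

```python
# Hedged sketch of the bug: names mirror the vLLM classes mentioned above,
# but the bodies are illustrative stand-ins, not the real implementations.

class DeepEPHTAll2AllManager:
    """Stand-in for the HT manager: naive dispatch is intentionally absent."""
    def dispatch(self, hidden_states):
        raise NotImplementedError("HT backend has no naive dispatch")


class FusedMoELayer:
    def __init__(self):
        # Set only by init_prepare_finalize(), which is triggered by
        # prepare_communication_buffer_for_model().
        self.fused_experts = None

    @property
    def using_modular_kernel(self):
        return self.fused_experts is not None

    def forward_impl(self, hidden_states, a2a):
        do_naive_dispatch_combine = not self.using_modular_kernel
        if do_naive_dispatch_combine:
            # Raises on the high-throughput backend -- the bug this PR fixes.
            return a2a.dispatch(hidden_states)
        return self.fused_experts(hidden_states)


layer = FusedMoELayer()
try:
    layer.forward_impl([0.0], DeepEPHTAll2AllManager())
except NotImplementedError as e:
    print(f"draft MoE failed: {e}")

# After initialization, the modular-kernel path is taken instead:
layer.fused_experts = lambda h: h
print(layer.forward_impl([1.0], DeepEPHTAll2AllManager()))
```

Running the sketch shows the exception on the uninitialized layer and a clean forward pass once fused_experts is set.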

Calling prepare_communication_buffer_for_model on the draft model makes this exception go away and makes MTP work.


Signed-off-by: Andrey Khalyavin <halyavin@yandex-team.ru>
@mergify mergify bot added the v1 label Nov 24, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug related to the initialization of draft models with Mixture of Experts (MoE) layers, ensuring compatibility with the high-throughput backend. The fix involves correctly calling prepare_communication_buffer_for_model on the draft model. My review includes a suggestion to make the condition for this call more robust to prevent potential NoneType errors during model loading.

Signed-off-by: Andrey Khalyavin <halyavin@yandex-team.ru>
Comment on lines +3348 to +3349
if (drafter := getattr(self, "drafter", None)) and (
    drafter_model := getattr(drafter, "model", None)
):
    prepare_communication_buffer_for_model(drafter_model)
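A self-contained sketch of why this guarded lookup is robust (the worker class and helper below are illustrative stand-ins, not vLLM's actual objects): a worker with no drafter, or a drafter whose model is still None during loading, skips the call cleanly instead of raising an AttributeError or NoneType error.

```python
# Hedged sketch of the guarded-lookup pattern from the suggestion above.
# Worker and maybe_prepare_drafter are hypothetical names for illustration.

def prepare_communication_buffer_for_model(model):
    # Stand-in for the real buffer-initialization call.
    print(f"prepared communication buffers for {model}")


class Worker:
    pass


def maybe_prepare_drafter(worker):
    # getattr(..., None) handles both a missing attribute and a None value;
    # the walrus assignments keep the happy path to a single expression.
    if (drafter := getattr(worker, "drafter", None)) and (
        drafter_model := getattr(drafter, "model", None)
    ):
        prepare_communication_buffer_for_model(drafter_model)
        return True
    return False


w = Worker()                      # no drafter attribute at all
assert maybe_prepare_drafter(w) is False

w.drafter = Worker()              # drafter exists but its model is unset
w.drafter.model = None
assert maybe_prepare_drafter(w) is False

w.drafter.model = "draft-model"   # both present: buffers get prepared
assert maybe_prepare_drafter(w) is True
```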
Member


I think we need a way to share the All2All state between self.model and the drafter -- this may duplicate state

Member


@bnellnm @varun-sundar-rabindranath could you take a look?

Member


From an offline conversation: the DeepEP buffers will be cached, so this won't involve any extra state for those All2All backends. That might not be the case for the FlashInfer All2Alls -- @pavanimajety, would you know if this causes any overhead in that case?

I think we should go ahead and land this for now to get MTP working on main

Contributor


We cache the all2all handles for DeepEP high-throughput here:

def get_handle(self, kwargs):

This, however, hashes on some model and DP/EP properties like hidden_size, num_local_experts, and num_global_experts. This means that if the draft model's properties differ from the base model's, which is likely, we will create a new all2all handle. But this is necessary.

I think this is good to land to unbreak main.

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 25, 2025
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) November 25, 2025 16:10
@tlrmchlsmth
Member

Thanks for the fix @halyavin !

@tlrmchlsmth tlrmchlsmth merged commit de75b0b into vllm-project:main Nov 25, 2025
46 of 47 checks passed
@halyavin halyavin deleted the fix-mtp-comm branch November 26, 2025 08:43
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
Signed-off-by: Andrey Khalyavin <halyavin@yandex-team.ru>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025
Signed-off-by: Andrey Khalyavin <halyavin@yandex-team.ru>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
Signed-off-by: Andrey Khalyavin <halyavin@yandex-team.ru>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>

Labels

ready ONLY add when PR is ready to merge/full CI is needed v1

3 participants