Skip to content

[feat] Add P/D attention select for draft model#9755

Merged
zhyncs merged 11 commits intosgl-project:mainfrom
Ximingwang-09:pd_attention_select_for_draft
Sep 3, 2025
Merged

[feat] Add P/D attention select for draft model#9755
zhyncs merged 11 commits intosgl-project:mainfrom
Ximingwang-09:pd_attention_select_for_draft

Conversation

@Ximingwang-09
Copy link
Copy Markdown
Contributor

@Ximingwang-09 Ximingwang-09 commented Aug 28, 2025

Motivation

#9573 has already enabled compatibility of Hybrid Attention in Speculative Decoding. However, for the draft model, the functionality to select the attention backend separately for the extend mode and decode mode has not been fully implemented—this PR supplements and completes the implementation of this feature.

Modifications

Accuracy Tests

Launch the server

python3 -m sglang.launch_server --model /mnt/Qwen3-8B --trust-remote-code --tp-size 8 --enable-cache-report --dtype bfloat16 --log-level info  --max-running-requests 16 --mem-fraction-static 0.85 --host 0.0.0.0 --port 9122 --speculative-algorithm EAGLE3 --speculative-num-steps 5 --speculative-eagle-topk 2 --speculative-num-draft-tokens 8 --speculative-draft-model-path /mnt/qwen3_8b_eagle3 --prefill-attention-backend fa3 --decode-attention-backend flashinfer

Test accuracy

python3 benchmark/gsm8k/bench_sglang.py  --num-shots 8 --num-questions 1319 --parallel 1319 --port 9122

Accuracy: 0.901

Benchmarking and Profiling

Checklist

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @Ximingwang-09, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request completes the implementation for selecting distinct attention backends for the draft model within speculative decoding, specifically allowing separate configurations for its prefill (extend) and decode phases. This enhances the flexibility and performance tuning capabilities for the draft model's attention mechanisms.

Highlights

  • Granular Attention Backend Control: Enables independent selection of attention backends for the draft model's prefill (draft_extend_attn_backend) and decode (draft_attn_backend) operations, moving beyond a single global attention_backend setting.
  • Code Refactoring for Clarity: The init_attention_backend method has been refactored into smaller, more manageable private helper methods (_create_decode_backend, _create_draft_extend_backend, and specific create*_backend functions for each attention type), improving code organization and maintainability.
  • Dynamic Backend Initialization: Implements a dynamic mapping system to initialize the correct attention backend based on the decode_attention_backend and prefill_attention_backend arguments provided to the server.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the attention backend initialization for the draft model, allowing separate backends for prefill and decode modes. The changes significantly improve code structure and maintainability by breaking down a large if/elif block into smaller, more focused methods. My review includes a couple of suggestions to further reduce code duplication and improve clarity.

Comment thread python/sglang/srt/speculative/eagle_worker.py Outdated
Comment on lines +309 to +322
def _create_trtllm_mla_decode_backend(self):
if not global_server_args_dict["use_mla_backend"]:
raise ValueError(
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
)
self.draft_extend_attn_backend = TRTLLMMLABackend(
self.draft_model_runner,
skip_prefill=False,

from sglang.srt.layers.attention.trtllm_mla_backend import (
TRTLLMMLAMultiStepDraftBackend,
)

self.has_prefill_wrapper_verify = True
return TRTLLMMLAMultiStepDraftBackend(
self.draft_model_runner, self.topk, self.speculative_num_steps
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The check for use_mla_backend is duplicated in _create_trtllm_mla_decode_backend and _create_trtllm_mla_prefill_backend. To improve maintainability, you can extract this logic into a helper method.

For example, you can add this helper method to the class:

def _check_mla_backend(self):
    if not global_server_args_dict["use_mla_backend"]:
        raise ValueError(
            "trtllm_mla backend requires MLA model (use_mla_backend=True)."
        )

Then you can call self._check_mla_backend() here and in _create_trtllm_mla_prefill_backend.

    def _create_trtllm_mla_decode_backend(self):
        self._check_mla_backend()

        from sglang.srt.layers.attention.trtllm_mla_backend import (
            TRTLLMMLAMultiStepDraftBackend,
        )

        self.has_prefill_wrapper_verify = True
        return TRTLLMMLAMultiStepDraftBackend(
            self.draft_model_runner, self.topk, self.speculative_num_steps
        )

Copy link
Copy Markdown
Collaborator

@Qiaolin-Yu Qiaolin-Yu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the multiple _create_xxx_decode_backend functions, would writing it following the code here makes the implementation simpler?

@Ximingwang-09
Copy link
Copy Markdown
Contributor Author

For the multiple _create_xxx_decode_backend functions, would writing it following the code here makes the implementation simpler?

Thanks for the suggestion! I think multiple _create_xxx_decode_backend functions avoid large if/else that mixes multiple backend details together. And keeps the main method short and focused. I currently feel the separate methods give us cleaner structure as the number of backends grows.That said, do you think the potential clarity of a single if/else still outweighs these benefits in our case?

Copy link
Copy Markdown
Collaborator

@Qiaolin-Yu Qiaolin-Yu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm. wait for ci.

@Qiaolin-Yu
Copy link
Copy Markdown
Collaborator

For the multiple _create_xxx_decode_backend functions, would writing it following the code here makes the implementation simpler?

Thanks for the suggestion! I think multiple _create_xxx_decode_backend functions avoid large if/else that mixes multiple backend details together. And keeps the main method short and focused. I currently feel the separate methods give us cleaner structure as the number of backends grows.That said, do you think the potential clarity of a single if/else still outweighs these benefits in our case?

hmm. I think maybe using if/else here could make it easier to read, but just a nit.

@zhyncs zhyncs merged commit df397a7 into sgl-project:main Sep 3, 2025
125 of 135 checks passed
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025
Co-authored-by: 纬杭 <ximing.wxm@antgroup.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants