
[RFC] GLM-5 Spec. Decode + PP in Disaggregation Decode Mode #23162

@MMuzzammil1

Description


Checklist

Motivation

I'd like to briefly share my plan for supporting Speculative Decoding + PP in Disagg Decode Mode. Feel free to drop your comments/suggestions.

Model Weights Layout

  • The speculative model weights will be loaded in duplicate on both first_pp_rank and last_pp_rank:
    • This implies a duplicate lm_head on the first_pp_rank and a duplicate embed on the last_pp_rank.
  • The target model weights will be set up in the regular PP fashion.
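As a rough sketch, the per-rank weight layout above can be expressed as a predicate over the PP rank. All names below (`extra_draft_modules`, the module strings) are illustrative, not actual SGLang APIs:

```python
def extra_draft_modules(pp_rank: int, pp_size: int) -> set[str]:
    """Illustrative: extra draft-model modules each PP rank loads."""
    modules = set()
    if pp_rank == 0:
        # first_pp_rank already holds embed for the target model; it
        # duplicates lm_head so draft() can sample draft tokens locally.
        modules |= {"mtp_layer", "lm_head"}
    if pp_rank == pp_size - 1:
        # last_pp_rank already holds lm_head; it duplicates embed so
        # forward_draft_extend_after_decode() can run locally.
        modules |= {"mtp_layer", "embed"}
    return modules
```

Middle ranks load no draft-model modules at all, which keeps the duplication cost bounded to two ranks regardless of PP size.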

Flow

  • The Draft Phase would run on the first_pp_rank -> draft_tokens and scores would be sent to each PP rank via the already-implemented PPProxyTensors.
  • Each PP rank would materialize build_tree_kernel_efficient() to get its EagleVerifyInput.
  • The last PP rank would run verify() to get the accepted_tokens -> verify_metadata would be sent back to the first_pp_rank.
  • The last PP rank would also run forward_draft_extend_after_decode().
  • On the first PP rank, _prep_batch_result would run post_process_after_verify(verify_metadata) to update request states (finished/unfinished) and free unused KV slots.
  • verify_metadata would then be passed along to the middle ranks (PP-1..PP-N-2) so each can adjust its own request states.
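The steps above can be sketched as an ordered per-rank step list. This is a toy sketch: the step names mirror the flow above, but the function itself is not a real scheduler hook:

```python
def microbatch_steps(pp_rank: int, pp_size: int) -> list[str]:
    """Illustrative: step order for one microbatch on a given PP rank."""
    steps = []
    if pp_rank == 0:
        steps.append("draft")  # produces draft_tokens, scores
    steps += [
        "build_tree_kernel_efficient",  # materialize EagleVerifyInput
        "target_forward",               # hidden_states for the next rank
    ]
    if pp_rank == pp_size - 1:
        steps += [
            "verify",                              # -> verify_metadata
            "forward_draft_extend_after_decode",   # draft-extend on last rank
        ]
    if pp_rank == 0:
        # Runs once verify_metadata arrives back from the last rank.
        steps.append("post_process_after_verify")
    return steps
```

Note the asymmetry this makes explicit: only PP0 drafts and post-processes, only the last rank verifies and draft-extends, and the middle ranks do tree building plus target forward only.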

Diagram

Speculative Decoding + Pipeline Parallelism in Disagg Decode Mode (Flow of a Microbatch)

sequenceDiagram
    participant C as Client
    participant PP0 as first_pp_rank<br/>(PP0)
    participant PPn as middle_pp_ranks<br/>(PP-1..PP-N-2)
    participant PPL as last_pp_rank<br/>(PP-N-1)

    note over PP0: Hosts: embed + mtp_layer + lm_head<br/>+ target model (first PP shard)
    note over PPL: Hosts: embed + mtp_layer + lm_head<br/>+ target model (last PP shard)
    note over PPn: Hosts: target model (middle PP shard)

    C->>PP0: Decode Request

    note over PP0: Draft Phase
    PP0->>PP0: set_kv_cache(kv_buffer, index_k_with_scale_buffer) <br/> received from last_pp_rank (PP-N-1)
    PP0->>PP0: draft() → EagleDraftOutput<br/> (also includes the kv_buffer, index_k_with_scale_buffer values for new tokens populated)

    note over PP0: Tree Building
    PP0->>PP0: build_tree_kernel_efficient() → EagleVerifyInput

    note over PP0: Target Forward
    PP0->>PP0: forward() -> hidden_states

    PP0->>PPn: PPProxyTensors(hidden_states, EagleDraftOutput)

    note over PPn: Tree Building
    PPn->>PPn: build_tree_kernel_efficient() → EagleVerifyInput

    note over PPn: Target Forward
    PPn->>PPn: forward() -> hidden_states

    PPn->>PPL: PPProxyTensors(hidden_states, EagleDraftOutput)

    note over PPL: Tree Building
    PPL->>PPL: build_tree_kernel_efficient() → EagleVerifyInput

    note over PPL: Target Forward + Verify
    PPL->>PPL: forward() → hidden_states, logits
    PPL->>PPL: verify() -> verify_metadata[accepted_tokens, pages_to_free,...]

    note over PPL: Draft Extend
    PPL->>PPL: set_kv_cache(kv_buffer, index_k_with_scale_buffer)
    PPL->>PPL: forward_draft_extend_after_decode() -> EagleDraftOutput <br/> (kv_buffer, index_k_with_scale_buffer values,...)
    PPL-->>PP0: PPProxyTensors(verify_metadata, EagleDraftOutput, ...)

    note over PP0: Post-Process
    PP0->>PP0: _prep_batch_result()<br/>post_process_after_verify(verify_metadata)<br/>→ update request states, free KV-slots

Target Worker

  • I plan to extend the EagleWorker module for this support (since PP doesn't support overlap mode anyway, there's no point in working with EagleWorker2).
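A hypothetical shape of that extension, with the rank roles derived from the PP rank. The class and property names here are my own; in practice this would subclass SGLang's EagleWorker rather than stand alone:

```python
class PPEagleWorker:
    """Illustrative sketch; would subclass EagleWorker in practice."""

    def __init__(self, pp_rank: int, pp_size: int):
        self.pp_rank = pp_rank
        self.pp_size = pp_size

    @property
    def runs_draft(self) -> bool:
        # draft() and post_process_after_verify() live on first_pp_rank.
        return self.pp_rank == 0

    @property
    def runs_verify(self) -> bool:
        # verify() and forward_draft_extend_after_decode() live on last_pp_rank.
        return self.pp_rank == self.pp_size - 1
```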

Concerns

  • On why the mtp layer needs to be loaded on both the first and last PP ranks:
    • For the target model to verify, we need the drafted tokens before the target forward run -> the draft tokens must be produced on the first PP rank itself.
    • The batch can get modified by get_next_disagg_decode_batch_to_run() on the first PP rank. Both draft() and forward_draft_extend_after_decode() need to run on the same batch, one before the target model forward and one after it.
  • Since draft() happens on the first PP rank and forward_draft_extend_after_decode() happens on the last PP rank, incremental KV values must be passed across PP ranks -> I need feedback on the feasibility of this.
  • Performance: would an implementation like this be performant? I need feedback on this as well.
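On the incremental-KV concern, one option is to ship only the rows written for newly drafted tokens alongside the usual proxy tensors. A minimal sketch, modeling the PPProxyTensors payload as a plain dict and the buffers as indexable rows; the key names and slot-indexing scheme are assumptions, not the real wire format:

```python
def pack_draft_kv_delta(kv_buffer, index_k_with_scale_buffer, new_token_slots):
    """Illustrative: pack only the draft-KV rows written this step.

    kv_buffer / index_k_with_scale_buffer are per-slot rows (tensors in
    practice, plain lists here). Sending only new_token_slots keeps the
    payload proportional to the number of draft tokens, not the full
    sequence length.
    """
    return {
        "draft_kv_delta": [kv_buffer[i] for i in new_token_slots],
        "draft_index_k_delta": [index_k_with_scale_buffer[i] for i in new_token_slots],
        "draft_kv_slots": list(new_token_slots),  # so the receiver can scatter back
    }
```

The receiving rank would scatter the delta rows back into its own buffers at `draft_kv_slots` before its set_kv_cache() step.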

Related resources

Relevant Files:

  • python/sglang/srt/speculative/eagle_worker.py
  • python/sglang/srt/managers/scheduler_pp_mixin.py
  • python/sglang/srt/models/deepseek_nextn.py, python/sglang/srt/models/deepseek_v2.py (GLMMoeDsaForCausalLM's base class)
