sequenceDiagram
participant C as Client
participant PP0 as first_pp_rank<br/>(PP0)
participant PPn as middle_pp_ranks<br/>(PP-1..PP-N-2)
participant PPL as last_pp_rank<br/>(PP-N-1)
note over PP0: Hosts: embed + mtp_layer + lm_head<br/>+ target model (first PP shard)
note over PPL: Hosts: embed + mtp_layer + lm_head<br/>+ target model (last PP shard)
note over PPn: Hosts: target model (middle PP shard)
C->>PP0: Decode Request
note over PP0: Draft Phase
PP0->>PP0: set_kv_cache(kv_buffer, index_k_with_scale_buffer) <br/> received from the last rank (PP-N-1)
PP0->>PP0: draft() → EagleDraftOutput<br/> (also includes the kv_buffer, index_k_with_scale_buffer values populated for the new tokens)
note over PP0: Tree Building
PP0->>PP0: build_tree_kernel_efficient() → EagleVerifyInput
note over PP0: Target Forward
PP0->>PP0: forward() → hidden_states
PP0->>PPn: PPProxyTensors(hidden_states, EagleDraftOutput)
note over PPn: Tree Building
PPn->>PPn: build_tree_kernel_efficient() → EagleVerifyInput
note over PPn: Target Forward
PPn->>PPn: forward() → hidden_states
PPn->>PPL: PPProxyTensors(hidden_states, EagleDraftOutput)
note over PPL: Tree Building
PPL->>PPL: build_tree_kernel_efficient() → EagleVerifyInput
note over PPL: Target Forward + Verify
PPL->>PPL: forward() → hidden_states, logits
PPL->>PPL: verify() → verify_metadata[accepted_tokens, pages_to_free, ...]
note over PPL: Draft Extend
PPL->>PPL: set_kv_cache(kv_buffer, index_k_with_scale_buffer)
PPL->>PPL: forward_draft_extend_after_decode() → EagleDraftOutput <br/> (kv_buffer, index_k_with_scale_buffer values, ...)
PPL-->>PP0: PPProxyTensors(verify_metadata, EagleDraftOutput, ...)
note over PP0: Post-Process
PP0->>PP0: _prep_batch_result()<br/>post_process_after_verify(verify_metadata)<br/>→ update request states, free KV-slots
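The per-rank control flow in the diagram above can be condensed into a short sketch. All function names (`draft()`, `build_tree_kernel_efficient()`, `verify()`, `forward_draft_extend_after_decode()`, `PPProxyTensors`) mirror the diagram; the signatures and the string payloads are placeholders, not sglang's real API:

```python
from dataclasses import dataclass, field


@dataclass
class PPProxyTensors:
    """Payload forwarded between PP ranks (contents are illustrative)."""
    tensors: dict = field(default_factory=dict)


def run_microbatch_step(rank: int, world: int, inbound: PPProxyTensors) -> PPProxyTensors:
    """Run one decode microbatch step on a single PP rank and return the
    payload shipped to the next rank (or back to PP0 from the last rank)."""
    first, last = rank == 0, rank == world - 1
    out = PPProxyTensors(dict(inbound.tensors))
    if first:
        # Draft phase: runs only on PP0, after restoring the draft KV
        # received from the last rank.
        out.tensors["eagle_draft_output"] = "draft()"
    # Every rank rebuilds the verification tree and runs its target shard.
    out.tensors["eagle_verify_input"] = "build_tree_kernel_efficient()"
    out.tensors["hidden_states"] = f"forward(rank={rank})"
    if last:
        # Only the last rank has logits, so it verifies and draft-extends.
        out.tensors["verify_metadata"] = "verify()"
        out.tensors["eagle_draft_output"] = "forward_draft_extend_after_decode()"
    return out


# Walk one microbatch through a 3-rank pipeline (PP0 -> PP1 -> PP2).
pkt = PPProxyTensors()
for r in range(3):
    pkt = run_microbatch_step(r, 3, pkt)
```

After the loop, `pkt` holds what the last rank sends back to PP0 for post-processing: the `verify_metadata` and the fresh `EagleDraftOutput` from draft-extend.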
Motivation
I'd like to briefly share my plan of action for supporting Speculative Decoding + PP in Disagg Decode Mode. Feel free to drop your comments/suggestions.
Model Weights Layout
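As a concrete illustration of the placement this section proposes (the helper below is hypothetical; the module names follow the diagram notes, with middle ranks hosting only their target-model shard):

```python
def modules_on_rank(rank: int, world_size: int) -> set[str]:
    """Return the modules a given PP rank would host under the proposed
    layout: embed + mtp_layer + lm_head duplicated on the first and last
    ranks, every rank hosting its own target-model shard."""
    modules = {f"target_shard_{rank}"}
    if rank == 0 or rank == world_size - 1:
        modules |= {"embed", "mtp_layer", "lm_head"}
    return modules


# Example with 4 PP ranks: only ranks 0 and 3 carry the draft-model pieces.
layout = {r: modules_on_rank(r, 4) for r in range(4)}
```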
Host `embed` + `mtp_layer` + `lm_head` on both the `first_pp_rank` and `last_pp_rank`: duplicate `lm_head` on the `first_pp_rank` and duplicate `embed` on the `last_pp_rank`.
Flow
- Draft runs on the `first_pp_rank` → `draft_tokens`, `scores` would be sent to each PP-rank via the `PPProxyTensors` already implemented.
- Each PP-rank runs `build_tree_kernel_efficient()` to get the `EagleVerifyInput`.
- The last PP-rank runs `verify()` to get the `accepted_tokens` → `verify_metadata` is sent to the `first_pp_rank`.
- The last PP-rank then runs `forward_draft_extend_after_decode()`.
- `_prep_batch_result` would run the `post_process_after_verify(verify_metadata)` to update request states (finished/unfinished) and free unused KV-slots on the first PP-rank.
- `verify_metadata` would then be passed along to PP-1..PP-2 and so on for request state adjustment on each PP-rank.
Diagram
Speculative Decoding + Pipeline Parallelism in Disagg Decode Mode (Flow of a Microbatch)
Target Worker
Reuse the `EagleWorker` module for this support (since PP doesn't support overlap mode anyway, there's no point working with `EagleWorker2`).
Concerns
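The trickiest concern in this section is moving the draft model's incremental KV values across PP-ranks, since `draft()` and `draft_extend_after_decode()` write to the draft KV cache on different ranks. A rough sketch of what that transfer could look like, with the new-token entries packed into the `PPProxyTensors` payload (the packing scheme and helper names are assumptions; `set_kv_cache` mirrors the diagram):

```python
def pack_incremental_kv(kv_buffer: dict[int, list[float]],
                        new_slots: list[int]) -> dict[int, list[float]]:
    """Gather only the KV entries written during this step, keyed by slot
    index, so they can travel inside the PPProxyTensors payload."""
    return {slot: kv_buffer[slot] for slot in new_slots}


def set_kv_cache(kv_buffer: dict[int, list[float]],
                 packed: dict[int, list[float]]) -> None:
    """Scatter the received entries into the local draft KV cache."""
    kv_buffer.update(packed)


# The last rank writes slots 7 and 8 during draft-extend...
sender_kv = {7: [0.1, 0.2], 8: [0.3, 0.4]}
packed = pack_incremental_kv(sender_kv, [7, 8])
# ...and the first rank restores them before its next draft() call.
receiver_kv: dict[int, list[float]] = {}
set_kv_cache(receiver_kv, packed)
```

In practice the per-slot lists would be tensor slices of `kv_buffer` / `index_k_with_scale_buffer`, and only the slots touched in the current microbatch would be shipped, keeping the extra PP traffic proportional to the number of newly drafted tokens.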
- The `batch` can get modified when we call `get_next_disagg_decode_batch_to_run()` on the first PP-rank. Both `draft()` and `draft_extend_after_decode()` need to run on the same `batch`, one before the target model forward and the other afterwards.
- Since `draft()` happens on the first PP-rank and `draft_extend_after_decode()` happens on the last PP-rank, the incremental KV values need to be passed across PP-ranks → need some feedback on the feasibility of doing this.
Related resources
Relevant Files:
- `python/sglang/srt/speculative/eagle_worker.py`
- `python/sglang/srt/managers/scheduler_pp_mixin.py`
- `python/sglang/srt/models/deepseek_nextn.py`, `python/sglang/srt/models/deepseek_v2.py` (`GLMMoeDsaForCausalLM`'s base class)