[diffusion] disaggregated diffusion #21701

Merged
mickqian merged 17 commits into sgl-project:main from yhyang201:disaggregated-diffusion
Apr 16, 2026

Conversation

@yhyang201
Collaborator

yhyang201 commented Mar 30, 2026

Motivation

The performance gains are uncertain.

Modifications

Accuracy Tests

Speed Tests and Profiling

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

github-actions Bot added the documentation (Improvements or additions to documentation) and diffusion (SGLang Diffusion) labels Mar 30, 2026
Contributor

@gemini-code-assist Bot left a comment

Code Review

This pull request introduces a disaggregated diffusion pipeline architecture, allowing the Encoder, Denoiser, and Decoder roles to run on independent GPU instances. It includes a central DiffusionServer for request routing, a P2P transfer engine for tensor movement, and updated scheduler logic to support these roles. My review identified several areas for improvement, including fixing an incorrect average latency calculation, cleaning up redundant code and dataclass initializations, and addressing potential issues with the transfer protocol's handling of result frames.

Comment on lines +108 to +109
total = self._completed + self._failed
avg_latency = self._total_latency / total if total > 0 else 0.0
Contributor

high

The average latency calculation seems to include failed requests in the denominator (total = self._completed + self._failed), but _total_latency is only updated for completed requests. This will result in an incorrect (underestimated) average latency. The denominator should probably just be self._completed.

Suggested change:

  - total = self._completed + self._failed
  + total = self._completed
    avg_latency = self._total_latency / total if total > 0 else 0.0
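The reviewer's point can be made concrete with a small runnable sketch. The class and method names below are hypothetical stand-ins for the PR's stats tracker, not its actual code:

```python
class TransferStats:
    """Minimal stats tracker illustrating the average-latency fix."""

    def __init__(self) -> None:
        self._completed = 0
        self._failed = 0
        self._total_latency = 0.0

    def record(self, latency: float, ok: bool) -> None:
        if ok:
            self._completed += 1
            self._total_latency += latency  # latency accumulated only on success
        else:
            self._failed += 1

    def avg_latency(self) -> float:
        # Divide by completed requests only; failed requests contribute
        # no latency, so including them would underestimate the average.
        return self._total_latency / self._completed if self._completed else 0.0


s = TransferStats()
s.record(2.0, ok=True)
s.record(4.0, ok=True)
s.record(0.0, ok=False)
assert s.avg_latency() == 3.0  # (2 + 4) / 2, not (2 + 4) / 3
```

With the original denominator the same sequence would report 2.0, silently flattering the latency figure as failures pile up.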

else:
raise TypeError(f"Cannot encode transfer message: {type(msg)}")

d.pop("result_frames", None)
Contributor

high

The result_frames field is popped from the dictionary before JSON serialization in encode_transfer_msg. This means that msg.get("result_frames") in diffusion_server.py's _handle_transfer_done will always be None, and the branch that handles result_frames (and calls _transfer_return_to_client) is effectively dead code. If the intention is to transfer result frames via this message, they should not be popped here.
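A self-contained sketch of the failure mode (the helper below mirrors the described behavior with generic names; it is not the actual protocol module):

```python
import json


def encode_transfer_msg(msg: dict) -> bytes:
    # Mirrors the reported behavior: result_frames is dropped
    # before JSON serialization.
    d = dict(msg)
    d.pop("result_frames", None)
    return json.dumps(d).encode()


# The receiver side then never sees the frames:
sent = {"request_id": "r1", "result_frames": ["deadbeef"]}
received = json.loads(encode_transfer_msg(sent))
assert received.get("result_frames") is None  # the result_frames branch is dead code
assert received["request_id"] == "r1"
```

Any receiver branching on `msg.get("result_frames")` therefore never fires, which is exactly why `_transfer_return_to_client` is unreachable as written.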

Comment on lines +52 to +65
manifest: dict = None
scalar_fields: dict = None
receiver_session_id: str = ""
receiver_pool_ptr: int = 0
receiver_slot_offset: int = 0
sender_instance: int = -1
receiver_instance: int = -1
prealloc_slot_id: int | None = None

def __post_init__(self):
    if self.manifest is None:
        self.manifest = {}
    if self.scalar_fields is None:
        self.scalar_fields = {}
Contributor

medium

The __post_init__ method is used here to initialize manifest and scalar_fields to empty dictionaries if they are None. A more idiomatic way to handle mutable default arguments in dataclasses is to use field(default_factory=dict). This also simplifies other parts of the code, for example line 748 could become scalar_fields = dict(p2p.scalar_fields) without the conditional check.

Suggested change:

Before:

    manifest: dict = None
    scalar_fields: dict = None
    receiver_session_id: str = ""
    receiver_pool_ptr: int = 0
    receiver_slot_offset: int = 0
    sender_instance: int = -1
    receiver_instance: int = -1
    prealloc_slot_id: int | None = None

    def __post_init__(self):
        if self.manifest is None:
            self.manifest = {}
        if self.scalar_fields is None:
            self.scalar_fields = {}

After:

    manifest: dict = field(default_factory=dict)
    scalar_fields: dict = field(default_factory=dict)
    receiver_session_id: str = ""
    receiver_pool_ptr: int = 0
    receiver_slot_offset: int = 0
    sender_instance: int = -1
    receiver_instance: int = -1
    prealloc_slot_id: int | None = None
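To illustrate the reviewer's point in isolation, here is a minimal sketch with a generic class name (not the PR's actual dataclass):

```python
from dataclasses import dataclass, field


@dataclass
class TransferMeta:
    # field(default_factory=dict) gives each instance its own dict,
    # so no __post_init__ None-check is needed. Writing `manifest: dict = {}`
    # directly would raise ValueError: dataclasses reject mutable defaults.
    manifest: dict = field(default_factory=dict)
    scalar_fields: dict = field(default_factory=dict)


a = TransferMeta()
b = TransferMeta()
a.manifest["k"] = 1
assert b.manifest == {}       # instances do not share state
assert a.manifest == {"k": 1}
```

Because the factory runs per instance, downstream code can also assume the fields are always dicts and drop `is None` guards entirely.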

"pool_size": msg.get("pool_size", 0),
}
prealloc = msg.get("preallocated_slots", [])
info["free_preallocated_slots"] = list(prealloc) if prealloc else []
Contributor

medium

The expression list(prealloc) if prealloc else [] is redundant. list(prealloc) will produce an empty list if prealloc is an empty list. You can simplify this.

Suggested change:

  - info["free_preallocated_slots"] = list(prealloc) if prealloc else []
  + info["free_preallocated_slots"] = list(prealloc)

encode_transfer_msg(alloc_msg)
)

def _transfer_return_to_client(self, request_id: str, result_frames: list) -> None:
Contributor

medium

This function _transfer_return_to_client appears to be dead code. It is only called from _handle_transfer_done if result_frames is present in the message from the decoder. However, encode_transfer_msg in protocol.py explicitly removes result_frames before serialization, so this path is never taken. The main result path for decoders seems to be _handle_decoder_result_frames for non-transfer messages.

if scheduler_mod is not None and num_steps is not None:
device = torch.device(f"cuda:{self.worker.local_rank}")
extra_kwargs = {}
mu = req.extra.get("mu") if hasattr(req, "extra") else None
Contributor

medium

The check hasattr(req, "extra") is redundant because the Req dataclass initializes extra with a default_factory=dict, so it will always be present. You can simplify this line.

Suggested change:

  - mu = req.extra.get("mu") if hasattr(req, "extra") else None
  + mu = req.extra.get("mu")

if scheduler_mod is not None and num_steps is not None:
device = torch.device(local_device)
extra_kwargs = {}
mu = req.extra.get("mu") if hasattr(req, "extra") else None
Contributor

medium

The check hasattr(req, "extra") is redundant because the Req dataclass initializes extra with a default_factory=dict, so it will always be present. You can simplify this line.

Suggested change:

  - mu = req.extra.get("mu") if hasattr(req, "extra") else None
  + mu = req.extra.get("mu")

Comment on lines +40 to +47
manifest: dict = None
session_id: str = ""
pool_ptr: int = 0
slot_offset: int = 0

def __post_init__(self):
    if self.manifest is None:
        self.manifest = {}
Contributor

medium

The __post_init__ method is used to initialize manifest to an empty dictionary if it is None. A more idiomatic way to handle mutable default arguments in dataclasses is to use field(default_factory=dict). This makes the __post_init__ method unnecessary.

Suggested change:

Before:

    manifest: dict = None
    session_id: str = ""
    pool_ptr: int = 0
    slot_offset: int = 0

    def __post_init__(self):
        if self.manifest is None:
            self.manifest = {}

After:

    manifest: dict = field(default_factory=dict)
    session_id: str = ""
    pool_ptr: int = 0
    slot_offset: int = 0

Comment on lines +87 to +95
manifest: dict = None
slot_offset: int = 0
scalar_fields: dict = None

def __post_init__(self):
    if self.manifest is None:
        self.manifest = {}
    if self.scalar_fields is None:
        self.scalar_fields = {}
Contributor

medium

The __post_init__ method is used here to initialize manifest and scalar_fields to empty dictionaries if they are None. A more idiomatic way to handle mutable default arguments in dataclasses is to use field(default_factory=dict). This makes the __post_init__ method unnecessary.

Suggested change:

Before:

    manifest: dict = None
    slot_offset: int = 0
    scalar_fields: dict = None

    def __post_init__(self):
        if self.manifest is None:
            self.manifest = {}
        if self.scalar_fields is None:
            self.scalar_fields = {}

After:

    manifest: dict = field(default_factory=dict)
    slot_offset: int = 0
    scalar_fields: dict = field(default_factory=dict)

@yhyang201 yhyang201 marked this pull request as ready for review April 1, 2026 14:39
@yhyang201 yhyang201 force-pushed the disaggregated-diffusion branch from 8cd36d9 to 85bdc28 Compare April 1, 2026 15:05
@mickqian
Collaborator

mickqian commented Apr 1, 2026

/tag-and-rerun-ci

@yhyang201
Collaborator Author

/rerun-failed-ci

2 similar comments
@yhyang201
Collaborator Author

/rerun-failed-ci

@yhyang201
Collaborator Author

/rerun-failed-ci

| `--disagg-role` | What it runs |
|----------------|--------------|
| `monolithic` | (Default) Standard single-server mode |
| `encoder` | Encoder role instance (text/image encoding) |
Collaborator

Better to specify the other info as well, e.g., TimestepPrepationStage is also included.

@@ -0,0 +1,235 @@
# Disaggregated Diffusion Pipeline
Collaborator

Consider moving this to docs/diffusion; the current folder is deprecated.

return order

@staticmethod
def _next_power_of_2(n: int) -> int:
Collaborator

Could we cache this with lru_cache?


@property
def role_affinity(self):
from sglang.multimodal_gen.runtime.disaggregation.roles import RoleType
Collaborator

please avoid lazy import

ib_device = getattr(sa, "disagg_ib_device", None)
engine = create_transfer_engine(
hostname=hostname,
gpu_id=self.gpu_id,
Collaborator

are we expecting something like a physical rank number (local_rank) here? gpu_id is the rank number within the process

transfer_state: _TransferRequestState | None = None


class DiffusionServer:
Collaborator

should we change the name of this file to something like orchestrator too?

pass

# Fast path: use pre-allocated slot if available and large enough
peer_info = self._denoiser_peers.get(denoiser_idx, {})
Collaborator

  • _denoiser_peers follows the order of registration in _handle_transfer_register
  • _denoiser_pushes follows the order of passed --denoiser-urls, passed from parse_url_string -> endpoints list

as a result, when we're doing something like:

   peer_info = self._denoiser_peers.get(denoiser_idx, {})
   self._denoiser_pushes[denoiser_idx].send_multipart(encode_transfer_msg(alloc_msg))

here, we're mixing two instances: sending the control msg to A while using the session id / pool_ptr from B as the RDMA address.

This could be fixed by forcing _denoiser_pushes to follow the order in which denoiser-urls are passed.
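One way to realize that fix, sketched with hypothetical names (it assumes registration messages carry the instance's URL, which may not match the actual protocol):

```python
# Build an index from the CLI-provided order, then key registrations by it,
# so peer metadata and push sockets always refer to the same instance.
denoiser_urls = ["tcp://host-a:5555", "tcp://host-b:5555"]  # order from --denoiser-urls
url_to_idx = {url: i for i, url in enumerate(denoiser_urls)}

denoiser_peers: dict[int, dict] = {}


def handle_transfer_register(msg: dict) -> None:
    # Registrations may arrive in any order; index by URL, not arrival order.
    idx = url_to_idx[msg["url"]]
    denoiser_peers[idx] = {"session_id": msg["session_id"], "pool_ptr": msg["pool_ptr"]}


# host-b registers first, but still lands at index 1:
handle_transfer_register({"url": "tcp://host-b:5555", "session_id": "s-b", "pool_ptr": 2})
handle_transfer_register({"url": "tcp://host-a:5555", "session_id": "s-a", "pool_ptr": 1})
assert denoiser_peers[0]["session_id"] == "s-a"
assert denoiser_peers[1]["session_id"] == "s-b"
```

With both `_denoiser_peers` and `_denoiser_pushes` keyed by the same URL-derived index, the control message and the RDMA address can no longer refer to different instances.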

Collaborator

Also, nit: we should extract the logic after this fast path into dedicated helper functions.

@yhyang201 yhyang201 force-pushed the disaggregated-diffusion branch from 04592b6 to 01df50d Compare April 13, 2026 15:25
@yhyang201
Collaborator Author

/tag-and-rerun-ci

@yhyang201
Collaborator Author

/rerun-failed-ci

3 similar comments
@yhyang201
Collaborator Author

/rerun-failed-ci

@yhyang201
Collaborator Author

/rerun-failed-ci

@yhyang201
Collaborator Author

/rerun-failed-ci

- Add ready event to DiffusionServer so launch waits for sockets to bind
  before starting the HTTP server (fixes race condition)
- Handle encoder/denoiser error results that arrive as non-transfer
  messages instead of silently dropping them (fixes silent timeout)
- Log response body on HTTP errors and dump role log tails in
  tearDownClass for CI debugging

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@mickqian mickqian merged commit 9da998a into sgl-project:main Apr 16, 2026
96 of 109 checks passed
jmamou pushed a commit to jmamou/sglang that referenced this pull request Apr 20, 2026
yhyang201 added a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
zhangying098 pushed a commit to zhangying098/sglang that referenced this pull request Apr 23, 2026
kyx1999 pushed a commit to KMSorSMS/sglang that referenced this pull request Apr 27, 2026

Labels

diffusion (SGLang Diffusion), documentation (Improvements or additions to documentation), high priority, run-ci

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants