[diffusion] disaggregated diffusion #21701
Conversation
Code Review
This pull request introduces a disaggregated diffusion pipeline architecture, allowing the Encoder, Denoiser, and Decoder roles to run on independent GPU instances. It includes a central DiffusionServer for request routing, a P2P transfer engine for tensor movement, and updated scheduler logic to support these roles. My review identified several areas for improvement, including fixing an incorrect average latency calculation, cleaning up redundant code and dataclass initializations, and addressing potential issues with the transfer protocol's handling of result frames.
```python
total = self._completed + self._failed
avg_latency = self._total_latency / total if total > 0 else 0.0
```
The average latency calculation seems to include failed requests in the denominator (`total = self._completed + self._failed`), but `_total_latency` is only updated for completed requests. This will result in an incorrect (underestimated) average latency. The denominator should probably just be `self._completed`.
Suggested change:
```python
total = self._completed
avg_latency = self._total_latency / total if total > 0 else 0.0
```
```python
else:
    raise TypeError(f"Cannot encode transfer message: {type(msg)}")

d.pop("result_frames", None)
```
The `result_frames` field is popped from the dictionary before JSON serialization in `encode_transfer_msg`. This means that `msg.get("result_frames")` in `diffusion_server.py`'s `_handle_transfer_done` will always be `None`, and the branch that handles `result_frames` (and calls `_transfer_return_to_client`) is effectively dead code. If the intention is to transfer result frames via this message, they should not be popped here.
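To make the dead-code path concrete, here is a minimal, self-contained sketch; the function bodies are simplified stand-ins for `encode_transfer_msg` and `_handle_transfer_done`, not the actual implementations:

```python
import json

def encode_transfer_msg(msg: dict) -> bytes:
    # Simplified: the real encoder strips result_frames before JSON serialization.
    d = dict(msg)
    d.pop("result_frames", None)
    return json.dumps(d).encode()

def handle_transfer_done(raw: bytes) -> None:
    msg = json.loads(raw.decode())
    if msg.get("result_frames"):          # always falsy because of the pop above
        print("would return frames to client")  # effectively a dead branch
    else:
        print("result_frames missing:", msg)

handle_transfer_done(encode_transfer_msg({"request_id": "r1", "result_frames": ["<frame>"]}))
```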
```python
manifest: dict = None
scalar_fields: dict = None
receiver_session_id: str = ""
receiver_pool_ptr: int = 0
receiver_slot_offset: int = 0
sender_instance: int = -1
receiver_instance: int = -1
prealloc_slot_id: int | None = None

def __post_init__(self):
    if self.manifest is None:
        self.manifest = {}
    if self.scalar_fields is None:
        self.scalar_fields = {}
```
The `__post_init__` method is used here to initialize `manifest` and `scalar_fields` to empty dictionaries if they are `None`. A more idiomatic way to handle mutable default arguments in dataclasses is to use `field(default_factory=dict)`. This also simplifies other parts of the code; for example, line 748 could become `scalar_fields = dict(p2p.scalar_fields)` without the conditional check.
Suggested change:
```python
manifest: dict = field(default_factory=dict)
scalar_fields: dict = field(default_factory=dict)
receiver_session_id: str = ""
receiver_pool_ptr: int = 0
receiver_slot_offset: int = 0
sender_instance: int = -1
receiver_instance: int = -1
prealloc_slot_id: int | None = None
```
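For reference, a small self-contained sketch of the follow-on simplification mentioned above (the dataclass name `P2PTransferSpec` is hypothetical; only the two dict fields matter here):

```python
from dataclasses import dataclass, field

@dataclass
class P2PTransferSpec:  # hypothetical stand-in for the dataclass above
    manifest: dict = field(default_factory=dict)
    scalar_fields: dict = field(default_factory=dict)

p2p = P2PTransferSpec()
# With default_factory, the call site no longer needs a None check:
scalar_fields = dict(p2p.scalar_fields)
print(scalar_fields)  # {}
```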
| "pool_size": msg.get("pool_size", 0), | ||
| } | ||
| prealloc = msg.get("preallocated_slots", []) | ||
| info["free_preallocated_slots"] = list(prealloc) if prealloc else [] |
There was a problem hiding this comment.
```python
    encode_transfer_msg(alloc_msg)
)

def _transfer_return_to_client(self, request_id: str, result_frames: list) -> None:
```
The function `_transfer_return_to_client` appears to be dead code. It is only called from `_handle_transfer_done` if `result_frames` is present in the message from the decoder. However, `encode_transfer_msg` in `protocol.py` explicitly removes `result_frames` before serialization, so this path is never taken. The main result path for decoders seems to be `_handle_decoder_result_frames` for non-transfer messages.
```python
manifest: dict = None
session_id: str = ""
pool_ptr: int = 0
slot_offset: int = 0

def __post_init__(self):
    if self.manifest is None:
        self.manifest = {}
```
The `__post_init__` method is used to initialize `manifest` to an empty dictionary if it is `None`. A more idiomatic way to handle mutable default arguments in dataclasses is to use `field(default_factory=dict)`. This makes the `__post_init__` method unnecessary.
Suggested change:
```python
manifest: dict = field(default_factory=dict)
session_id: str = ""
pool_ptr: int = 0
slot_offset: int = 0
```
```python
manifest: dict = None
slot_offset: int = 0
scalar_fields: dict = None

def __post_init__(self):
    if self.manifest is None:
        self.manifest = {}
    if self.scalar_fields is None:
        self.scalar_fields = {}
```
The `__post_init__` method is used here to initialize `manifest` and `scalar_fields` to empty dictionaries if they are `None`. A more idiomatic way to handle mutable default arguments in dataclasses is to use `field(default_factory=dict)`. This makes the `__post_init__` method unnecessary.
Suggested change:
```python
manifest: dict = field(default_factory=dict)
slot_offset: int = 0
scalar_fields: dict = field(default_factory=dict)
```
Force-pushed from 8cd36d9 to 85bdc28
/tag-and-rerun-ci

/rerun-failed-ci
| `--disagg-role` | What it runs |
|-----------------|--------------|
| `monolithic` | (Default) Standard single-server mode |
| `encoder` | Encoder role instance (text/image encoding) |
Better to specify the other details here as well, e.g., `TimestepPrepationStage` is also included in this role.
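For example, the encoder row could be made more specific along these lines (a sketch only; the exact stage list is taken from the comment above and may not match the code):

| `--disagg-role` | What it runs |
|-----------------|--------------|
| `encoder` | Encoder role instance (text/image encoding; also runs timestep preparation, e.g. `TimestepPrepationStage`) |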
```
@@ -0,0 +1,235 @@
# Disaggregated Diffusion Pipeline
```
Consider moving this to docs/diffusion; the current folder is deprecated.
```python
    return order

@staticmethod
def _next_power_of_2(n: int) -> int:
```
Could we cache this with `functools.lru_cache`?
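A minimal sketch of that suggestion (the class name and function body are placeholders, not the code in this PR); `lru_cache` wraps the function, with `staticmethod` applied on the outside so attribute access still works:

```python
from functools import lru_cache

class _Example:
    @staticmethod
    @lru_cache(maxsize=None)
    def _next_power_of_2(n: int) -> int:
        """Smallest power of two >= n (for n >= 1)."""
        return 1 << (n - 1).bit_length()

print(_Example._next_power_of_2(5))   # 8
print(_Example._next_power_of_2(16))  # 16
```

Since the helper is a pure function of a small integer there is no risk of pinning `self` in the cache, which is the usual caveat when applying `lru_cache` to methods.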
```python
ib_device = getattr(sa, "disagg_ib_device", None)
engine = create_transfer_engine(
    hostname=hostname,
    gpu_id=self.gpu_id,
```
Are we expecting something like a physical device index (`local_rank`) here? `gpu_id` is the rank number within the process.
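To illustrate the distinction the comment is drawing, a self-contained toy (the 2-node x 4-GPU topology and the helper are assumptions, not code from this PR):

```python
# Toy example: a global/process rank is not the same as the local CUDA device index.
# On a 2-node x 4-GPU job, global ranks 0..7 map onto local device indices 0..3.
def local_device_index(global_rank: int, gpus_per_node: int = 4) -> int:
    return global_rank % gpus_per_node

for rank in range(8):
    print(f"rank {rank} -> cuda:{local_device_index(rank)}")
```

Whichever of the two `create_transfer_engine` actually expects determines whether `self.gpu_id` or the local device index is the right value to pass here.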
```python
transfer_state: _TransferRequestState | None = None


class DiffusionServer:
```
Should we rename this file to something like orchestrator too?
```python
pass

# Fast path: use pre-allocated slot if available and large enough
peer_info = self._denoiser_peers.get(denoiser_idx, {})
```
- `_denoiser_peers` follows the order of registration in `_handle_transfer_register`
- `_denoiser_pushes` follows the order of the passed `--denoiser-urls` (from `parse_url_string` -> the endpoints list)

As a result, when we do something like:

```python
peer_info = self._denoiser_peers.get(denoiser_idx, {})
self._denoiser_pushes[denoiser_idx].send_multipart(encode_transfer_msg(alloc_msg))
```

we are mixing two instances: sending the control msg to A while using the session id / pool_ptr from B as the RDMA address. This could be fixed by forcing `_denoiser_pushes` to follow the order in which the `--denoiser-urls` are passed.
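A minimal, self-contained sketch of one way to keep the two structures consistent (the setup code and the registration message layout are assumptions): key both `_denoiser_peers` and `_denoiser_pushes` by the index of each URL in `--denoiser-urls`, so index i always refers to the same instance regardless of registration order.

```python
# Hypothetical sketch: key both maps by the position of the URL in --denoiser-urls.
denoiser_urls = ["tcp://host-a:5001", "tcp://host-b:5001"]   # assumed CLI order
url_to_idx = {url: i for i, url in enumerate(denoiser_urls)}

_denoiser_pushes = {i: f"push-socket({url})" for i, url in enumerate(denoiser_urls)}
_denoiser_peers: dict[int, dict] = {}

def handle_transfer_register(msg: dict) -> None:
    """Register a denoiser under the index derived from its URL, not arrival order."""
    idx = url_to_idx[msg["url"]]
    _denoiser_peers[idx] = {"session_id": msg["session_id"], "pool_ptr": msg["pool_ptr"]}

# Registrations may arrive in any order; index i still refers to one instance.
handle_transfer_register({"url": "tcp://host-b:5001", "session_id": "B", "pool_ptr": 0xB})
handle_transfer_register({"url": "tcp://host-a:5001", "session_id": "A", "pool_ptr": 0xA})
for i in _denoiser_pushes:
    print(i, _denoiser_pushes[i], _denoiser_peers[i]["session_id"])
```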
Also, nit: we should extract the logic after these fast paths into dedicated helper functions.
Force-pushed from 04592b6 to 01df50d
/tag-and-rerun-ci

/rerun-failed-ci
- Add ready event to DiffusionServer so launch waits for sockets to bind before starting the HTTP server (fixes race condition)
- Handle encoder/denoiser error results that arrive as non-transfer messages instead of silently dropping them (fixes silent timeout)
- Log response body on HTTP errors and dump role log tails in tearDownClass for CI debugging

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
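A minimal, self-contained sketch of the ready-event pattern described in the first bullet (class and attribute names are placeholders, not the actual DiffusionServer API):

```python
import threading

class ToyServer:
    def __init__(self):
        self.ready = threading.Event()

    def run(self):
        # ... bind control/transfer sockets here ...
        self.ready.set()          # signal that sockets are bound
        # ... enter the serve loop ...

def launch():
    server = ToyServer()
    threading.Thread(target=server.run, daemon=True).start()
    # Launch waits for the bind before starting the HTTP front end,
    # avoiding the race where requests arrive before sockets exist.
    server.ready.wait(timeout=30)

launch()
```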
Fixes CI lint failure flagged by ruff.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ults, and setattr routing
Motivation
The performance gains are uncertain.
Modifications
Accuracy Tests
Speed Tests and Profiling
Checklist
Review and Merge Process
/tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci