[AMD] Serialize cross-ProcessGroup collectives for dp_attention#11184
hubertlu-tw wants to merge 3 commits into sgl-project:main
Conversation
Summary of Changes (Gemini Code Assist): This pull request resolves critical stream-capture issues encountered with ROCm 7.0 when performing collective operations in dp_attention.
Code Review
This pull request addresses a stream capture issue on ROCm 7.0 for dp_attention by serializing cross-process-group collectives. The fix involves using an asynchronous reduce_scatter followed by a wait() before all_gather. My review identified a critical bug that would cause a NameError on non-HIP systems, an unused import, and an inconsistent use of a communication wrapper that could bypass optimizations. I have provided suggestions to resolve these issues.
```python
if _is_hip:
    _USE_ROCM7 = get_rocm_version()[0] >= 7
```
The variable _USE_ROCM7 is only defined within the if _is_hip: block. This will cause a NameError on non-HIP platforms where _is_hip is False, as _USE_ROCM7 will be referenced later in _dp_gather_via_all_gather without being defined. To fix this, _USE_ROCM7 should be defined regardless of the platform. A cleaner way to write this would be to combine the check for _is_hip into the assignment.
```python
_USE_ROCM7 = _is_hip and get_rocm_version()[0] >= 7
```

```python
torch.distributed.all_gather_into_tensor(
    global_tokens,
    scattered_local_tokens,
    group=get_tp_group().device_group,
)
```
For consistency and to leverage potential optimizations (like pynccl), it's better to use the GroupCoordinator wrapper get_tp_group().all_gather_into_tensor(...) here, similar to the else branch. The current implementation calls torch.distributed.all_gather_into_tensor directly, which bypasses the logic in the wrapper. The wrapper is designed to be graph-safe, so it should be suitable for this context.
```python
get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens)
```
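The design point behind that suggestion can be sketched abstractly: a coordinator wrapper can dispatch to an optimized communicator when one is registered, while a direct `torch.distributed` call always takes the plain path. All names below are illustrative stand-ins, not sglang's actual `GroupCoordinator` implementation:

```python
# Schematic only -- not sglang's GroupCoordinator. It shows why routing a
# collective through the wrapper matters: the wrapper can pick an
# optimized backend (e.g. pynccl) when one is available, whereas calling
# torch.distributed directly always bypasses that dispatch.

class GroupCoordinatorSketch:
    def __init__(self, fast_comm=None):
        # fast_comm stands in for an optional optimized communicator.
        self.fast_comm = fast_comm

    def all_gather_into_tensor(self, output, input):
        if self.fast_comm is not None:
            return f"fast:{self.fast_comm}"     # optimized, graph-safe path
        return "fallback:torch.distributed"     # plain collective path


coord = GroupCoordinatorSketch(fast_comm="pynccl")
print(coord.all_gather_into_tensor(None, None))  # -> fast:pynccl
```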
@ch-wan could you please help review the PR? Thanks!
HaiShaw left a comment:
Also - any performance indicator?
```python
if _is_hip:
    _USE_ROCM7 = get_rocm_version()[0] >= 7
```
…pact + sglang precedent analysis

- Correct PyTorch #176251 status: merged twice and reverted twice (latest revert 2026-03-31), so the watchdog workaround is currently NOT in main. An empirical check confirms no public torch wheel (upstream nightly, AMD gfx950-dcgpu nightly, rocm/pytorch images) ships RocmWatchdogEventQueryContextGuard. Strike out the "upgrade to a nightly" path (Fallback 1) and add Fallback 2 (cherry-pick the #176942 patch into a local torch 2.9.1 rebuild) as the only software-side option for staying on ROCm 7.2.0.
- Refine §3 v0.1.11 vs v0.1.12 narrative: both releases ship the fused allreduce+rmsnorm+quant kernel; the actual change is v0.1.12 introducing dynamic in-graph output-buffer registration (is_broadcast_reg_outptr -> get_output_buffer_RD), which is what doubles the per-AR host-side bookkeeping inside the capture window.
- Expand §4 multi-PG impact analysis: enumerate the actual NCCL-bearing PGs under --tp 4 --ep 2 (TP, MOE_EP, MOE_TP, WORLD), explain per-capture cost compounding, and note that other models on the same nightly suite are statistically lucky rather than immune.
- Add sgl-project#10434 / sgl-project#11184 precedent: the rocm-7.0.0-alpha cross-PG capture issue had a one-line algorithmic workaround (DpPaddingMode.MAX_LEN -> SUM_LEN). aiter#2857 has no equivalent algorithmic out and depends on the runtime fix in ROCm 7.2.1+.
Co-authored-by: @kkHuang-amd
Motivation
We previously landed a workaround in #10434 that switched _dp_gather_via_all_gather to _dp_gather_via_all_reduce to avoid an RCCL/HIP failure when DP was enabled on dsv3.
The root cause is stricter stream-capture checks in ROCm 7.0: chaining reduce_scatter_tensor on the attention TP process group immediately followed by all_gather_into_tensor on the TP process group inside a captured region can lead to
hipErrorCapturedEvent ("operation not permitted on an event last recorded in a capturing stream"). ROCm 7 updated event/callback behavior during capture (e.g., hipEventQuery, hipStreamAddCallback) to match CUDA, so any out-of-order polling or cross-stream/event use triggers an error (https://rocm.docs.amd.com/projects/HIP/en/docs-develop/hip-7-changes.html#stream-capture-updates). NCCL/RCCL collectives are capturable, but they require stable streams and explicit ordering; PyTorch also documents that when using multiple process groups, outstanding async ops on one PG must be synchronized before issuing collectives on another.
This PR restores the all-gather path and makes it graph-safe by launching reduce_scatter_tensor(..., async_op=True) on the attention process group (from get_attention_tp_group), then calling work.wait() before invoking all_gather_into_tensor on the TP process group (from get_tp_group), both on the same capturing stream. With TORCH_NCCL_BLOCKING_WAIT=1, the wait avoids host-side polling of captured events. This preserves the original RS→AG algorithm while complying with ROCm 7's capture rules.

Modifications
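The change boils down to one ordering invariant, which can be illustrated without GPUs. The sketch below is a pure-Python stand-in (none of these classes are torch.distributed APIs): the Work handle from the async reduce_scatter on the attention-TP group is waited on before any collective is issued on the TP group.

```python
# Illustrative stand-ins only -- not torch.distributed APIs. The point is
# the ordering: async collective on PG A, wait(), then collective on PG B,
# which is what keeps the sequence legal under ROCm 7's capture rules.

class FakeWork:
    """Mimics the handle returned by an async_op=True collective."""
    def __init__(self, log, pg_name):
        self._log, self._pg_name = log, pg_name

    def wait(self):
        # With TORCH_NCCL_BLOCKING_WAIT=1 the real wait blocks instead of
        # polling captured events on the host.
        self._log.append(f"wait({self._pg_name})")


class FakeProcessGroup:
    """Mimics a process group that records the collectives issued on it."""
    def __init__(self, log, name):
        self._log, self._name = log, name

    def reduce_scatter_async(self):
        self._log.append(f"reduce_scatter_tensor({self._name}, async_op=True)")
        return FakeWork(self._log, self._name)

    def all_gather(self):
        self._log.append(f"all_gather_into_tensor({self._name})")


def dp_gather_graph_safe(attn_tp_group, tp_group):
    # RS on the attention-TP group, AG on the TP group: two different PGs,
    # so the outstanding async work must be serialized in between.
    work = attn_tp_group.reduce_scatter_async()
    work.wait()  # serialize before switching process groups
    tp_group.all_gather()


log = []
dp_gather_graph_safe(FakeProcessGroup(log, "attn_tp"), FakeProcessGroup(log, "tp"))
print(log)
```

Running it prints the three steps in the only order the capture rules permit: the reduce_scatter, its wait, then the all_gather.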
Accuracy Tests
Benchmarking and Profiling
Checklist
CC: @HaiShaw