Enable Pipeline Parallelism support for Piecewise CUDA Graph #14515 #14547
baonudesifeizhai wants to merge 28 commits into sgl-project:main
Conversation
This reverts commit 6f48bbd.
Summary of Changes

Hello @baonudesifeizhai, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request integrates Pipeline Parallelism (PP) support into the existing piecewise CUDA graph functionality. Previously, these two performance optimizations were mutually exclusive. The changes involve removing the explicit disabling of piecewise CUDA graphs when PP is active, and implementing the necessary mechanisms for graph capture and replay.

Highlights
Code Review
This pull request enables pipeline parallelism support for piecewise CUDA graphs, which is a significant enhancement. The changes correctly handle pp_proxy_tensors during graph capture and replay by using a pre-allocated buffer, and include a nice performance optimization by caching the model signature check. However, I've identified a critical issue that will cause a TypeError at runtime, and a minor inconsistency in the handling of mrope_positions that could affect maintainability and resource usage. Please see the detailed comments for suggestions on how to address these points. Overall, great work on tackling this complex feature.
```python
self.mrope_positions = torch.zeros(
    (3, self.max_num_tokens), dtype=torch.int64
)
```
The initialization of `self.mrope_positions` has been moved out of the `if self.is_multimodal:` block. While this might be intentional if mrope is used by non-multimodal models, the comment on line 234 now becomes misleading, as it states `mrope_positions` is only for multimodal models. Additionally, other parts of the code still check `self.is_multimodal` before using `mrope_positions` (e.g., lines 354, 439). This creates an inconsistency.

To improve clarity and correctness, please either:

- Move the initialization back inside the `if self.is_multimodal:` block if it's only for multimodal models.
- If it's used more broadly, update the comment on line 234 and change the checks from `self.is_multimodal` to a more appropriate condition (e.g., `self.model_runner.model_is_mrope`).
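If the second option is taken, the gated allocation might look roughly like this minimal sketch, where plain lists stand in for `torch.zeros` and `model_is_mrope` is the hypothetical condition named above:

```python
class RunnerSketch:
    def __init__(self, model_is_mrope: bool, max_num_tokens: int):
        # Allocate the (3, max_num_tokens) positions buffer only when the
        # model actually uses mrope, so other models pay no memory cost.
        self.mrope_positions = None
        if model_is_mrope:
            self.mrope_positions = [[0] * max_num_tokens for _ in range(3)]


print(RunnerSketch(False, 8).mrope_positions)                # → None
print(RunnerSketch(True, 8).mrope_positions is not None)     # → True
```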
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
ispobock left a comment:
Could you add a unit test?

`python3 test_piecewise_cuda_graph.py TestPiecewiseCudaGraphWithPP.test_gsm8k_accuracy` passed.

Could you fix the lint issue?

Already fixed!

/tag-and-rerun-ci

@baonudesifeizhai Could you resolve the conflict?

`python -m pytest test/srt/test_piecewise_cuda_graph_2_gpu.py::TestPiecewiseCudaGraphWithPP -v` passed.
Oasis-Git left a comment:
In general the functionality is good. However, before merging we may need some code modifications.
```python
    "Disable piecewise CUDA graph because piecewise_cuda_graph does not support PP",
)
return False
# PP support is now enabled for piecewise CUDA graph
```
```python
self.mrope_positions = torch.zeros(
    (3, self.max_num_tokens), dtype=torch.int64
)
```
Does PP need mrope or not? If not, allocate the buffer only when `is_multimodal` is True.
```python
def replay_prepare(
    self,
    forward_batch: ForwardBatch,
    pp_proxy_tensors: Optional[PPProxyTensors] = None,
```
Since you do not use `pp_proxy_tensors`, please remove it.
```python
with enable_piecewise_cuda_graph(), disable_ca_comm(self.model_runner.tp_group):
    self.model_runner.attn_backend.init_forward_metadata(forward_batch)
    static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
    # Extract pp_proxy_tensors from kwargs if present (avoid in-place modification)
```
With the original `replay_prepare` there is no need to split `pp_proxy_tensors` out of `kwargs`.
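The point above — keeping `replay_prepare`'s original signature so the caller never has to pop `pp_proxy_tensors` out of `kwargs` — can be illustrated with a toy sketch (all names and return values are hypothetical simplifications, not the PR's actual code):

```python
def replay_prepare(forward_batch, pp_proxy_tensors=None):
    # Accepts the proxy tensors directly, so callers need no special casing.
    return {"batch": forward_batch, "pp": pp_proxy_tensors}


def replay(forward_batch, **kwargs):
    # No kwargs.pop("pp_proxy_tensors") needed: forward everything unchanged.
    return replay_prepare(forward_batch, **kwargs)


print(replay("fb", pp_proxy_tensors="proxies"))
# → {'batch': 'fb', 'pp': 'proxies'}
```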
```python
# Note: piecewise captures with bs=1, but we need buffer for PP proxy tensors
# The buffer size is 1 since we capture with batch_size=1
self.pp_proxy_tensors_buffer = {
    "hidden_states": torch.zeros(
```
Why do you need `self.pp_proxy_tensors_buffer`, and why is the shape `[1, hidden_dim]` instead of `[num_tokens, hidden_dim]` here?
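The shape concern can be made concrete with a small stand-in (plain lists instead of CUDA tensors): a captured graph replays from fixed addresses, so a static buffer must be sized for the largest token count that will ever be copied into it, not for batch size 1. All names here are illustrative, not taken from the PR:

```python
MAX_NUM_TOKENS = 4
HIDDEN_DIM = 3

# Allocate once at maximum size, mimicking
# torch.zeros(max_num_tokens, hidden_dim) for a static PP proxy buffer.
pp_proxy_buffer = {
    "hidden_states": [[0.0] * HIDDEN_DIM for _ in range(MAX_NUM_TOKENS)]
}


def fill_buffer(incoming):
    # Copy incoming proxy rows into the static buffer in place, like
    # Tensor.copy_(); graph replay then reads from these fixed rows.
    for i, row in enumerate(incoming):
        pp_proxy_buffer["hidden_states"][i][:] = row
    return len(incoming)


n = fill_buffer([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(n, pp_proxy_buffer["hidden_states"][1])  # → 2 [4.0, 5.0, 6.0]
```

With a `[1, hidden_dim]` buffer the second row above would have nowhere to go, which is exactly the reviewer's worry.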
/tag-and-rerun-ci

Works normally after fixing conflicts.

Motivation
Previously, piecewise CUDA graph was explicitly disabled when pipeline parallelism (PP) was enabled. This PR enables PP support for piecewise CUDA graph, allowing users to benefit from both optimizations simultaneously.
#14515
Modifications
- `ModelRunner.can_run_piecewise_cuda_graph()`
- `PiecewiseCudaGraphRunner.__init__()`
- `PiecewiseCudaGraphRunner.capture_one_batch_size()` for graph capture
- `PiecewiseCudaGraphRunner.replay()` to properly pass tensors from pre-allocated buffers during replay
- `inspect.signature()` calls in the hot path

Accuracy Tests
https://paste.ubuntu.com/p/KWKwszGRvd/
res:
Benchmarking and Profiling
Checklist