
Enable Pipeline Parallelism support for Piecewise CUDA Graph #14515 (#14547)

Open
baonudesifeizhai wants to merge 28 commits into sgl-project:main from baonudesifeizhai:feature/piecewise-cuda-graph-pp-support

Conversation

@baonudesifeizhai
Contributor

@baonudesifeizhai baonudesifeizhai commented Dec 6, 2025

Motivation

Previously, piecewise CUDA graph was explicitly disabled when pipeline parallelism (PP) was enabled. This PR enables PP support for piecewise CUDA graph, allowing users to benefit from both optimizations simultaneously.

#14515

Modifications

  • Removed the PP size check that disabled piecewise CUDA graph in ModelRunner.can_run_piecewise_cuda_graph()
  • Added PP proxy tensors buffer initialization in PiecewiseCudaGraphRunner.__init__()
  • Added PP proxy tensors handling in PiecewiseCudaGraphRunner.capture_one_batch_size() for graph capture
  • Added PP proxy tensors handling in PiecewiseCudaGraphRunner.replay() to properly pass tensors from pre-allocated buffers during replay
  • Cached model signature check to avoid expensive inspect.signature() calls in the hot path
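The buffer-based proxy-tensor handling described above can be sketched roughly as follows. This is a dependency-free illustration, not the actual sglang code: tensors are modeled as plain lists, and the class and method names (`PPProxyBuffers`, `copy_in`, `view`) are hypothetical.

```python
# Illustrative sketch: PP proxy tensors must live in static, pre-allocated
# buffers so that CUDA graph replay always sees the same memory addresses.
# Plain lists stand in for torch tensors to keep the example self-contained.

class PPProxyBuffers:
    """Pre-allocated buffers for the hidden_states/residual tensors that
    pipeline-parallel stages pass between each other."""

    def __init__(self, max_num_tokens, hidden_dim):
        # Allocated once at init; graph capture records these addresses.
        self.buffers = {
            "hidden_states": [[0.0] * hidden_dim for _ in range(max_num_tokens)],
            "residual": [[0.0] * hidden_dim for _ in range(max_num_tokens)],
        }

    def copy_in(self, proxy_tensors):
        # On replay, copy the incoming stage inputs into the static buffers
        # in place (analogous to torch.Tensor.copy_); never rebind them.
        for name, value in proxy_tensors.items():
            buf = self.buffers[name]
            for i, row in enumerate(value):
                buf[i][:] = row

    def view(self, num_tokens):
        # The graph was captured over the full buffer; callers get a
        # sliced view covering only the tokens in the current batch.
        return {name: buf[:num_tokens] for name, buf in self.buffers.items()}


bufs = PPProxyBuffers(max_num_tokens=4, hidden_dim=2)
bufs.copy_in({"hidden_states": [[1.0, 2.0]], "residual": [[0.5, 0.5]]})
stage_input = bufs.view(1)
```

The key design point is that replay copies data into the captured buffers rather than passing fresh tensors, since a CUDA graph replays against fixed device pointers.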

Accuracy Tests

https://paste.ubuntu.com/p/KWKwszGRvd/
Results: (screenshot)

Benchmarking and Profiling

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @baonudesifeizhai, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates Pipeline Parallelism (PP) support into the existing piecewise CUDA graph functionality. Previously, these two performance optimizations were mutually exclusive. The changes involve removing the explicit disabling of piecewise CUDA graphs when PP is active, and implementing the necessary mechanisms within the PiecewiseCudaGraphRunner to correctly manage and pass proxy tensors required for PP during both CUDA graph capture and replay, thereby allowing users to leverage both optimizations simultaneously for improved performance.

Highlights

  • Enabled Pipeline Parallelism (PP) for Piecewise CUDA Graphs: The core change allows these two optimization techniques to work together, which was previously disabled.
  • Removed PP Size Check: The explicit check in ModelRunner.can_run_piecewise_cuda_graph() that prevented piecewise CUDA graphs from running with pp_size > 1 has been removed.
  • Initialized PP Proxy Tensors Buffer: A buffer for pp_proxy_tensors (hidden states and residual) is now initialized in PiecewiseCudaGraphRunner.__init__() when PP is enabled.
  • Handled PP Proxy Tensors during Graph Capture: Logic was added in PiecewiseCudaGraphRunner.capture_one_batch_size() to create and pass pp_proxy_tensors during the CUDA graph capture phase.
  • Managed PP Proxy Tensors during Replay: The PiecewiseCudaGraphRunner.replay() method now properly handles and passes pp_proxy_tensors from pre-allocated buffers to the model's forward pass.
  • Cached Model Signature Check: An expensive inspect.signature() call to check for pp_proxy_tensors parameter in model.forward is now cached to improve performance in hot paths.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request enables pipeline parallelism support for piecewise CUDA graphs, which is a significant enhancement. The changes correctly handle pp_proxy_tensors during graph capture and replay by using a pre-allocated buffer, and include a nice performance optimization by caching the model signature check. However, I've identified a critical issue that will cause a TypeError at runtime, and a minor inconsistency in the handling of mrope_positions that could affect maintainability and resource usage. Please see the detailed comments for suggestions on how to address these points. Overall, great work on tackling this complex feature.

Comment thread: python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py (outdated)
Comment on lines +242 to +244

```python
self.mrope_positions = torch.zeros(
    (3, self.max_num_tokens), dtype=torch.int64
)
```
Contributor


medium

The initialization of self.mrope_positions has been moved out of the if self.is_multimodal: block. While this might be intentional if mrope is used by non-multimodal models, the comment on line 234 now becomes misleading as it states mrope_positions is only for multimodal models. Additionally, other parts of the code still check self.is_multimodal before using mrope_positions (e.g., lines 354, 439). This creates an inconsistency.

To improve clarity and correctness, please either:

  1. Move the initialization back inside the if self.is_multimodal: block if it's only for multimodal models.
  2. If it's used more broadly, update the comment on line 234 and change the checks from self.is_multimodal to a more appropriate condition (e.g., self.model_runner.model_is_mrope).
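Option 2 above could look roughly like this. The attribute name `model_is_mrope` is taken from the comment itself and is an assumption about the runner's API; plain lists stand in for the torch buffer.

```python
# Sketch: gate the mrope_positions buffer on the condition that matches its
# actual use (mrope models), rather than on is_multimodal.

class RunnerInitSketch:
    def __init__(self, model_runner, max_num_tokens):
        self.max_num_tokens = max_num_tokens
        # Stand-in for torch.zeros((3, max_num_tokens), dtype=torch.int64);
        # only allocated for models that actually use mrope positions.
        if getattr(model_runner, "model_is_mrope", False):
            self.mrope_positions = [[0] * max_num_tokens for _ in range(3)]
        else:
            self.mrope_positions = None


class _MRopeRunner:
    model_is_mrope = True  # hypothetical flag for illustration


r = RunnerInitSketch(_MRopeRunner(), 8)
```

Gating on one condition everywhere (allocation and use) avoids the inconsistency where the buffer exists but the `is_multimodal` checks skip it, or vice versa.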

baonudesifeizhai and others added 2 commits December 6, 2025 15:11
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Collaborator

@ispobock ispobock left a comment


Could you add a unit test?

@baonudesifeizhai
Contributor Author

`python3 test_piecewise_cuda_graph.py TestPiecewiseCudaGraphWithPP.test_gsm8k_accuracy` — pass

@ispobock
Collaborator

Could you fix the lint issue?

@baonudesifeizhai
Contributor Author

> Could you fix the lint issue?

Already fixed!

@ispobock
Collaborator

ispobock commented Dec 13, 2025

/tag-and-rerun-ci

@ispobock
Collaborator

@baonudesifeizhai Could you resolve the conflict?

@baonudesifeizhai
Contributor Author

`python -m pytest test/srt/test_piecewise_cuda_graph_2_gpu.py::TestPiecewiseCudaGraphWithPP -v` — passed

Collaborator

@Oasis-Git Oasis-Git left a comment


In general the functionality is good, but before merging we need a few code modifications.

```python
"Disable piecewise CUDA graph because piecewise_cuda_graph does not support PP",
)
return False
# PP support is now enabled for piecewise CUDA graph
```
Collaborator


remove this line

```python
self.mrope_positions = torch.zeros(
    (3, self.max_num_tokens), dtype=torch.int64
)
```
Collaborator


Does PP need mrope? If not, allocate the buffer only when is_multimodal is True.

```python
def replay_prepare(
    self,
    forward_batch: ForwardBatch,
    pp_proxy_tensors: Optional[PPProxyTensors] = None,
```
Collaborator


Since you do not use pp_proxy_tensors here, please remove it.

```python
with enable_piecewise_cuda_graph(), disable_ca_comm(self.model_runner.tp_group):
    self.model_runner.attn_backend.init_forward_metadata(forward_batch)
    static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
    # Extract pp_proxy_tensors from kwargs if present (avoid in-place modification)
```
Collaborator


With the original replay_prepare, there is no need to split pp_proxy_tensors out of kwargs.

```python
# Note: piecewise captures with bs=1, but we need buffer for PP proxy tensors
# The buffer size is 1 since we capture with batch_size=1
self.pp_proxy_tensors_buffer = {
    "hidden_states": torch.zeros(
```
Collaborator


Why do you need self.pp_proxy_tensors_buffer, and why is the shape [1, hidden_dim] instead of [num_tokens, hidden_dim] here?

Contributor Author


changed
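The shape concern discussed in this thread can be sketched as follows. The helper name and sizes are illustrative, and plain lists stand in for torch tensors.

```python
# Sketch: piecewise capture may run with a small token count, but replay can
# see up to max_num_tokens tokens, so the PP proxy buffer must be sized for
# the worst case, not for the capture batch.

def make_pp_proxy_buffer(max_num_tokens, hidden_dim):
    # Stand-in for torch.zeros((max_num_tokens, hidden_dim), ...).
    return [[0.0] * hidden_dim for _ in range(max_num_tokens)]


buf = make_pp_proxy_buffer(max_num_tokens=8, hidden_dim=4)
# buf has shape [max_num_tokens, hidden_dim], so any replay batch fits.
```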

@baonudesifeizhai
Contributor Author

baonudesifeizhai commented Dec 20, 2025

```shell
python -m sglang.launch_server \
    --model-path Qwen/Qwen2.5-7B-Instruct \
    --enable-piecewise-cuda-graph \
    --pipeline-parallel-size 2 \
    --tp 4 \
    --trust-remote-code \
    --port 30000
```

(screenshot)

`python test/srt/test_piecewise_cuda_graph_2_gpu.py TestPiecewiseCudaGraphWithPP` — pass

@Oasis-Git
Collaborator

/tag-and-rerun-ci

@baonudesifeizhai
Contributor Author

Works normally after fixing the conflicts.


4 participants