
[3/n] DP Enhancement: Padding tokens to max length when workload is balanced#8278

Closed
ch-wan wants to merge 4 commits into gh/ch-wam/3/base from gh/ch-wam/3/head

Conversation


@ch-wan ch-wan commented Jul 23, 2025

Stack from ghstack (oldest at bottom):

Background

Integrating CUDA Graph with DP attention is tricky because different DP ranks have different workloads. Simply gathering tensors from all DP ranks would incur dynamic shapes. There are two design choices to make DP attention compatible with CUDA Graph:

  • Padding tokens to max(len) and using all-gather to gather tensors. This design pads the input tensor of FFN to max_len * dp_size and incurs redundant computation (see examples in #4390 and the response here).
  • Padding tokens to sum(len) and using all-reduce to gather tensors. This design pads tokens to sum(len), which increases computation in the attention layers (#6092 #7951). In addition, the CUDA graph batch size is sum(len) rather than the actual running batch size on each DP worker. Users need to increase the CUDA graph batch size when the DP size is increased, which causes a lot of confusion and misuse (#5527 #5557 #6092).

This PR offers two gather modes. When CUDA graph is disabled, we select the option that minimizes communication. When CUDA graph is in use, we use the first option.
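The mode selection described above can be sketched in plain Python. This is an illustrative sketch only; `pick_gather_mode` and its signature are made up here and are not the PR's actual API.

```python
def pick_gather_mode(global_num_tokens):
    # Per-worker communication volume (in tokens):
    #   all-gather: pads every rank to max(len) -> max_len * dp_size
    #   all-reduce: pads to sum(len), touched twice -> sum_len * 2
    max_len = max(global_num_tokens)
    sum_len = sum(global_num_tokens)
    dp_size = len(global_num_tokens)
    if sum_len * 2 > max_len * dp_size:
        return "all_gather"  # cheaper when workloads are balanced
    return "all_reduce"      # cheaper when one rank dominates
```

For a balanced batch like `[8, 8, 8, 8]` the all-gather volume (32) beats the all-reduce volume (64), so all-gather is selected; for a skewed batch like `[29, 1, 1, 1]` the inequality flips.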

Future Work

Under max_len padding mode, the input hidden_state of FFN is padded to max_len * dp_size, which may incur redundant computation. In the future, we can use masked (grouped) GEMM kernels (e.g., DeepGEMM) to optimize it.

[ghstack-poisoned]

@gemini-code-assist gemini-code-assist Bot left a comment


Summary of Changes

Hello @ch-wan, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces significant optimizations for distributed attention and CUDA graph execution by implementing dynamic token gathering strategies and comprehensive token padding. The core goal is to achieve a more balanced workload across GPUs and improve the efficiency of distributed inference by standardizing input tensor shapes and optimizing communication patterns.

Highlights

  • Dynamic Distributed Attention Gathering: Implemented a new mechanism to dynamically choose between all_gather and all_reduce strategies for distributed attention (DP attention). This selection is based on the distribution of token lengths within a batch, aiming to optimize communication efficiency.
  • Token Padding for Workload Balancing: Introduced padding of token sequences to a maximum length within a batch. This ensures uniform input tensor shapes, which is crucial for balanced workload distribution across GPUs and efficient CUDA graph capture, especially in distributed environments.
  • CUDA Graph Optimization for Padded Inputs: Modified the CUDA graph capture and replay logic to correctly handle and leverage these padded token lengths. This includes adjusting batch size calculations to consider the maximum token length and ensuring proper memory allocation for the padded tensors.
  • New Distributed Communication Primitive: Added a reduce_scatter_tensor method to the distributed state, which is a key primitive enabling the new all_gather based distributed attention strategy.
  • Post-Processing for Padded Outputs: Implemented a post-forward pass step to automatically trim outputs from the padded tensors back to their original, unpadded sizes, ensuring that subsequent processing receives data in the expected format.
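The pad-then-trim round trip from the last two highlights can be sketched with plain lists (illustrative stand-ins for the tensor operations; the function names are not the PR's):

```python
def pad_batch(seqs, value=0):
    # Pad every rank's tokens to the batch max so shapes are uniform
    # (what makes CUDA graph capture possible); remember original lengths.
    max_len = max(len(s) for s in seqs)
    padded = [s + [value] * (max_len - len(s)) for s in seqs]
    return padded, [len(s) for s in seqs]

def trim_batch(padded, orig_lens):
    # Post-forward step: cut outputs back to their unpadded sizes.
    return [s[:n] for s, n in zip(padded, orig_lens)]
```

After the forward pass, `trim_batch` restores exactly the per-rank shapes that downstream code expects.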


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request refactors the token distribution strategy for data parallel attention. Instead of balancing the number of tokens across ranks, it pads all ranks to have the same number of tokens as the rank with the maximum number of tokens. This change aims to create a more balanced workload, which can improve performance, particularly when using CUDA graphs.

The implementation introduces a DPGatherMode to dynamically choose between all_reduce and all_gather for token synchronization. It also adds significant padding and un-padding logic within the ForwardBatch class to handle the new workload distribution. The changes are extensive, affecting CUDA graph capturing, replay logic, and several model execution components.

The code is generally well-structured, but there are a few areas for improvement regarding correctness and maintainability, which I've detailed in the comments below.

Comment on lines +303 to +308
```python
def _dp_gather(
    global_tokens: torch.Tensor,
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,
    is_partial: bool,
):
```


medium

The new function _dp_gather is missing a return type hint. For consistency and code clarity, it's good practice to add one. This function appears to have side effects and doesn't return a value, so -> None would be appropriate. The same applies to _dp_gather_via_all_gather defined earlier in this file.

Suggested change

```diff
 def _dp_gather(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,
     is_partial: bool,
-):
+) -> None:
```

Comment on lines +700 to +746
```python
def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):

    bs = self.batch_size

    if self.spec_info is not None:
        if self.forward_mode.is_decode():  # draft
            num_tokens = self.hidden_states_backup.shape[0]
            self.positions = self.positions[:num_tokens]
            self.seq_lens = self.seq_lens[:bs]
            self.req_pool_indices = self.req_pool_indices[:bs]
            if self.seq_lens_cpu is not None:
                self.seq_lens_cpu = self.seq_lens_cpu[:bs]
            logits_output.next_token_logits = logits_output.next_token_logits[
                :num_tokens
            ]
            logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
        elif self.forward_mode.is_target_verify():  # verify
            num_tokens = bs * self.spec_info.draft_token_num
            logits_output.next_token_logits = logits_output.next_token_logits[
                :num_tokens
            ]
            logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
        elif self.forward_mode.is_draft_extend():  # draft extend
            self.spec_info.accept_length = self.spec_info.accept_length[:bs]
            logits_output.next_token_logits = logits_output.next_token_logits[:bs]
            logits_output.hidden_states = logits_output.hidden_states[:bs]
        elif self.forward_mode.is_extend() or self.forward_mode.is_idle():
            logits_output.next_token_logits = logits_output.next_token_logits[:bs]
            logits_output.hidden_states = logits_output.hidden_states[:bs]

        if hasattr(self, "hidden_states_backup"):
            self.spec_info.hidden_states = self.hidden_states_backup
        if hasattr(self, "output_cache_loc_backup"):
            self.out_cache_loc = self.output_cache_loc_backup

    elif self.forward_mode.is_decode() or self.forward_mode.is_idle():
        logits_output.next_token_logits = logits_output.next_token_logits[:bs]
        if logits_output.hidden_states is not None:
            logits_output.hidden_states = logits_output.hidden_states[:bs]
    elif self.forward_mode.is_extend():
        num_tokens = self.seq_lens_sum
        logits_output.next_token_logits = logits_output.next_token_logits[
            :num_tokens
        ]
        if logits_output.hidden_states is not None:
            logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
```



medium

The post_forward_mlp_sync_batch function restores state that was modified in prepare_mlp_sync_batch by checking for attributes like hidden_states_backup using hasattr. This pattern can be fragile. If an error occurs after prepare_mlp_sync_batch but before post_forward_mlp_sync_batch, the ForwardBatch object could be left in an inconsistent state.

Consider making this state management more robust. For example, prepare_mlp_sync_batch could return a context manager or a callable that restores the state. If that's too much of a refactor, at least adding a comment to prepare_mlp_sync_batch explaining which attributes are being modified and that post_forward_mlp_sync_batch must be called to restore them would improve maintainability.

For example, in prepare_mlp_sync_batch:

```python
# Back up attributes that will be modified in-place.
# These must be restored by calling post_forward_mlp_sync_batch.
self.output_cache_loc_backup = self.out_cache_loc
self.hidden_states_backup = spec_info.hidden_states
```
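The context-manager approach suggested above can be sketched generically. Everything below (`restored_after_forward`, the `Batch` stand-in) is hypothetical illustration, not the PR's code:

```python
from contextlib import contextmanager

@contextmanager
def restored_after_forward(batch, attrs):
    # Snapshot the named attributes before the padded forward pass and
    # restore them afterwards, even if the forward pass raises.
    backup = {a: getattr(batch, a) for a in attrs}
    try:
        yield batch
    finally:
        for a, v in backup.items():
            setattr(batch, a, v)

class Batch:  # minimal stand-in for ForwardBatch
    pass

b = Batch()
b.out_cache_loc = [1, 2, 3]
try:
    with restored_after_forward(b, ["out_cache_loc"]):
        b.out_cache_loc = [1, 2, 3, 0, 0]  # padded in-place
        raise RuntimeError("forward failed")
except RuntimeError:
    pass
# b.out_cache_loc is restored to [1, 2, 3] despite the error
```

The `finally` block guarantees restoration on both the normal and the error path, which is exactly the inconsistent-state hazard the `hasattr`-based pattern leaves open.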

@ch-wan ch-wan changed the title Padding tokens to max length for balanced workload [3/n] DP Enhancement: Padding tokens to max length when workload is balanced Jul 23, 2025
@ch-wan ch-wan linked an issue Jul 23, 2025 that may be closed by this pull request
@ch-wan ch-wan mentioned this pull request Jul 23, 2025
ch-wan added 2 commits July 23, 2025 10:23
[ghstack-poisoned]
[ghstack-poisoned]
@ch-wan ch-wan linked an issue Jul 23, 2025 that may be closed by this pull request

@ispobock ispobock left a comment


LGTM

```python
def get_dp_gather_mode(cls, global_num_tokens: List[int]) -> DPGatherMode:
    max_len = max(global_num_tokens)
    sum_len = sum(global_num_tokens)
    if sum_len * 2 > max_len * get_attention_dp_size():
```
Collaborator


Just curious, how did we get this formula?

Collaborator Author


The per-worker communication volumes of the AllGather mode and the AllReduce mode are max_len * dp_size and sum_len * 2, respectively.
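Plugging example numbers into those two volumes shows how the comparison flips (the batches below are made-up illustrations):

```python
dp_size = 4

balanced = [7, 7, 8, 8]
assert max(balanced) * dp_size == 32  # AllGather volume per worker
assert sum(balanced) * 2 == 60        # AllReduce volume: AllGather wins

skewed = [29, 1, 1, 1]
assert max(skewed) * dp_size == 116   # AllGather volume per worker
assert sum(skewed) * 2 == 64          # AllReduce volume: AllReduce wins
```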

```python
if self.require_mlp_tp_gather:
    total_batch_size = (
        sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
    )
    max_num_tokens = max(forward_batch.global_num_tokens_cpu)
```
Collaborator


nit: using forward_batch.global_num_tokens_cpu.max().item() may be more efficient?

Collaborator Author


Good catch. I'm going to revise this.

Collaborator Author


lol, global_num_tokens_cpu is a list of int

```python
def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0):
    if value == 0:
        return torch.cat(
```
Collaborator


nit: why not use the torch pad function?

Collaborator Author

@ch-wan ch-wan Jul 24, 2025


Are you asking about this one (https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.pad.html)? It is for image padding. Unfortunately, I'm not sure how to use it here.
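For what it's worth, `torch.nn.functional.pad` is not limited to images; with constant mode it pads 1-D tensors too. The trailing-pad logic can be shown with a pure-Python stand-in (the function name and the torch one-liner in the comment are illustrative, not the PR's code):

```python
def pad_to_size(xs, size, value=0):
    # Pure-Python stand-in for trailing-pad on a 1-D tensor.
    # The torch equivalent would be roughly:
    #   F.pad(tensor, (0, size - tensor.shape[0]), value=value)
    # where (0, n) means "pad 0 elements on the left, n on the right".
    return xs + [value] * (size - len(xs))
```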

ch-wan added a commit that referenced this pull request Jul 25, 2025
[ghstack-poisoned]
@ch-wan

ch-wan commented Jul 25, 2025

/gemini review


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a significant enhancement to Data Parallelism (DP) in the attention mechanism, making it compatible with CUDA Graph by adding two gather modes. When CUDA Graph is disabled, an optimal communication strategy is chosen, and when it's enabled, an all-gather approach with padding to max(len) is used. The changes are extensive, touching data structures, CUDA graph capture logic, and the forward pass execution flow. The implementation centralizes padding and un-padding logic, which improves code structure. My review focuses on correctness and maintainability. I've identified a few minor issues with type hints and conditional checks that should be addressed for better code quality.

Comment on lines +554 to 555
```python
if hidden_states.shape[0] > 0:
    # router_logits: (num_tokens, n_experts)
```


high

The condition if is_non_idle_and_non_empty(forward_mode, hidden_states): is replaced with if hidden_states.shape[0] > 0:. It is important to ensure that this change does not introduce any regressions or unexpected behavior in cases where is_non_idle_and_non_empty previously provided additional filtering.

```diff
 ) -> torch.Tensor:
     forward_mode = forward_batch.forward_mode
-    if is_non_idle_and_non_empty(forward_mode, hidden_states):
+    if hidden_states.shape[0] > 0:
```


high

The condition if is_non_idle_and_non_empty(forward_mode, hidden_states): is replaced with if hidden_states.shape[0] > 0:. It is important to ensure that this change does not introduce any regressions or unexpected behavior in cases where is_non_idle_and_non_empty previously provided additional filtering.

```python
    self,
    output: torch.Tensor,
    input: torch.Tensor,
) -> None:
```


medium

The function reduce_scatter_tensor is type-hinted to return None, but it returns the output tensor on line 509. This should be corrected to -> torch.Tensor for type consistency and better code clarity.

Suggested change

```diff
-) -> None:
+) -> torch.Tensor:
```

```python
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,
    is_partial: bool,
):
```


medium

The function _dp_gather_via_all_gather is missing a return type hint. Since it doesn't return a value, it should be annotated with -> None for better code clarity and to aid static analysis.

Suggested change

```diff
-):
+) -> None:
```

```python
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,
    is_partial: bool,
):
```


medium

The function _dp_gather is missing a return type hint. It should be annotated with -> None as it does not return any value. This improves code readability and helps static analysis tools.

Suggested change

```diff
-):
+) -> None:
```


```python
def prepare_mlp_sync_batch(self, model_runner: ModelRunner):

    from sglang.srt.speculative.eagle_utils import EagleDraftInput
```


medium

To avoid potential circular dependencies, it's a common practice to perform local imports within functions. However, for better code readability and maintainability, consider moving this import to the top of the file if it doesn't create a circular dependency. If it does, adding a comment explaining why the local import is necessary would be helpful for future maintainers.

@ch-wan ch-wan closed this Jul 25, 2025
@ch-wan ch-wan deleted the gh/ch-wam/3/head branch July 25, 2025 04:50


Development

Successfully merging this pull request may close these issues.

  • [Bug] Tensor shape is wrong when cudagraph+enable_dp_attention
  • Cuda graph supported bs in DP attention