[3/n] DP Enhancement: Padding tokens to max length when workload is balanced #8278
ch-wan wants to merge 4 commits into gh/ch-wam/3/base
Conversation
Summary of Changes
Hello @ch-wan, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces significant optimizations for distributed attention and CUDA graph execution by implementing dynamic token gathering strategies and comprehensive token padding. The core goal is to achieve a more balanced workload across GPUs and improve the efficiency of distributed inference by standardizing input tensor shapes and optimizing communication patterns.
Highlights
- Dynamic Distributed Attention Gathering: Implemented a new mechanism to dynamically choose between `all_gather` and `all_reduce` strategies for distributed attention (DP attention). This selection is based on the distribution of token lengths within a batch, aiming to optimize communication efficiency.
- Token Padding for Workload Balancing: Introduced padding of token sequences to a maximum length within a batch. This ensures uniform input tensor shapes, which is crucial for balanced workload distribution across GPUs and efficient CUDA graph capture, especially in distributed environments (a conceptual sketch follows this list).
- CUDA Graph Optimization for Padded Inputs: Modified the CUDA graph capture and replay logic to correctly handle and leverage these padded token lengths. This includes adjusting batch size calculations to consider the maximum token length and ensuring proper memory allocation for the padded tensors.
- New Distributed Communication Primitive: Added a `reduce_scatter_tensor` method to the distributed state, which is a key primitive enabling the new `all_gather`-based distributed attention strategy.
- Post-Processing for Padded Outputs: Implemented a post-forward-pass step to automatically trim outputs from the padded tensors back to their original, unpadded sizes, ensuring that subsequent processing receives data in the expected format.
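To make the padded all-gather path above concrete, here is a rough sketch (assumed tensor shapes and helper name, not the PR's actual code) of how each DP rank could pad its tokens to the batch-wide max length, gather them with `all_gather_into_tensor`, and leave trimming to a later step:

```python
import torch
import torch.distributed as dist

def gather_padded_hidden_states(
    local: torch.Tensor, max_len: int, dp_size: int
) -> torch.Tensor:
    """Pad a rank's [num_tokens, hidden] tensor to max_len and gather it across DP ranks."""
    padded = local.new_zeros((max_len, local.shape[1]))
    padded[: local.shape[0]] = local  # padding rows stay zero
    gathered = local.new_empty((max_len * dp_size, local.shape[1]))
    dist.all_gather_into_tensor(gathered, padded)
    # Downstream, a post-forward step trims each rank's slice back to its true length.
    return gathered
```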
Code Review
This pull request refactors the token distribution strategy for data parallel attention. Instead of balancing the number of tokens across ranks, it pads all ranks to have the same number of tokens as the rank with the maximum number of tokens. This change aims to create a more balanced workload, which can improve performance, particularly when using CUDA graphs.
The implementation introduces a DPGatherMode to dynamically choose between all_reduce and all_gather for token synchronization. It also adds significant padding and un-padding logic within the ForwardBatch class to handle the new workload distribution. The changes are extensive, affecting CUDA graph capturing, replay logic, and several model execution components.
The code is generally well-structured, but there are a few areas for improvement regarding correctness and maintainability, which I've detailed in the comments below.
```python
def _dp_gather(
    global_tokens: torch.Tensor,
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,
    is_partial: bool,
):
```
The new function _dp_gather is missing a return type hint. For consistency and code clarity, it's good practice to add one. This function appears to have side effects and doesn't return a value, so -> None would be appropriate. The same applies to _dp_gather_via_all_gather defined earlier in this file.
Suggested change:

```diff
 def _dp_gather(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,
     is_partial: bool,
-):
+) -> None:
```
```python
def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):

    bs = self.batch_size

    if self.spec_info is not None:
        if self.forward_mode.is_decode():  # draft
            num_tokens = self.hidden_states_backup.shape[0]
            self.positions = self.positions[:num_tokens]
            self.seq_lens = self.seq_lens[:bs]
            self.req_pool_indices = self.req_pool_indices[:bs]
            if self.seq_lens_cpu is not None:
                self.seq_lens_cpu = self.seq_lens_cpu[:bs]
            logits_output.next_token_logits = logits_output.next_token_logits[
                :num_tokens
            ]
            logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
        elif self.forward_mode.is_target_verify():  # verify
            num_tokens = bs * self.spec_info.draft_token_num
            logits_output.next_token_logits = logits_output.next_token_logits[
                :num_tokens
            ]
            logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
        elif self.forward_mode.is_draft_extend():  # draft extend
            self.spec_info.accept_length = self.spec_info.accept_length[:bs]
            logits_output.next_token_logits = logits_output.next_token_logits[:bs]
            logits_output.hidden_states = logits_output.hidden_states[:bs]
        elif self.forward_mode.is_extend() or self.forward_mode.is_idle():
            logits_output.next_token_logits = logits_output.next_token_logits[:bs]
            logits_output.hidden_states = logits_output.hidden_states[:bs]

        if hasattr(self, "hidden_states_backup"):
            self.spec_info.hidden_states = self.hidden_states_backup
        if hasattr(self, "output_cache_loc_backup"):
            self.out_cache_loc = self.output_cache_loc_backup

    elif self.forward_mode.is_decode() or self.forward_mode.is_idle():
        logits_output.next_token_logits = logits_output.next_token_logits[:bs]
        if logits_output.hidden_states is not None:
            logits_output.hidden_states = logits_output.hidden_states[:bs]
    elif self.forward_mode.is_extend():
        num_tokens = self.seq_lens_sum
        logits_output.next_token_logits = logits_output.next_token_logits[
            :num_tokens
        ]
        if logits_output.hidden_states is not None:
            logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
```
The post_forward_mlp_sync_batch function restores state that was modified in prepare_mlp_sync_batch by checking for attributes like hidden_states_backup using hasattr. This pattern can be fragile. If an error occurs after prepare_mlp_sync_batch but before post_forward_mlp_sync_batch, the ForwardBatch object could be left in an inconsistent state.
Consider making this state management more robust. For example, prepare_mlp_sync_batch could return a context manager or a callable that restores the state. If that's too much of a refactor, at least adding a comment to prepare_mlp_sync_batch explaining which attributes are being modified and that post_forward_mlp_sync_batch must be called to restore them would improve maintainability.
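As a rough illustration of the context-manager route (a hypothetical `padded_mlp_sync` helper; the attribute names simply mirror the backup fields used in this PR), the backup and restore could be tied together so the restore runs even if the forward pass raises:

```python
from contextlib import contextmanager

@contextmanager
def padded_mlp_sync(forward_batch):
    # Back up the attributes that the padding logic mutates in place.
    out_cache_loc_backup = forward_batch.out_cache_loc
    hidden_states_backup = (
        forward_batch.spec_info.hidden_states
        if forward_batch.spec_info is not None
        else None
    )
    try:
        yield forward_batch
    finally:
        # Restore unconditionally, so an exception cannot leave the batch padded.
        forward_batch.out_cache_loc = out_cache_loc_backup
        if hidden_states_backup is not None:
            forward_batch.spec_info.hidden_states = hidden_states_backup
```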
If that refactor is too heavy, the comment-based option in prepare_mlp_sync_batch could look like:

```python
# Back up attributes that will be modified in-place.
# These must be restored by calling post_forward_mlp_sync_batch.
self.output_cache_loc_backup = self.out_cache_loc
self.hidden_states_backup = spec_info.hidden_states
```

```python
def get_dp_gather_mode(cls, global_num_tokens: List[int]) -> DPGatherMode:
    max_len = max(global_num_tokens)
    sum_len = sum(global_num_tokens)
    if sum_len * 2 > max_len * get_attention_dp_size():
```
Just curious, how did we get this formula?
The communication volumes per worker for AllGather mode and AllReduce mode are `max_len * dp_size` and `sum_len * 2`, respectively.
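A tiny sketch of that trade-off with made-up numbers (the names mirror the diff above; this is illustrative, not the PR's exact code):

```python
from typing import List

def pick_gather_mode(global_num_tokens: List[int], dp_size: int) -> str:
    max_len = max(global_num_tokens)
    sum_len = sum(global_num_tokens)
    # AllGather moves max_len * dp_size elements per worker; AllReduce moves ~2 * sum_len.
    return "all_gather" if sum_len * 2 > max_len * dp_size else "all_reduce"

# Balanced workload: [128, 128, 128, 128] -> 1024 > 512, so all-gather is cheaper.
# Skewed workload:   [512, 8, 8, 8]       -> 1072 < 2048, so all-reduce is cheaper.
print(pick_gather_mode([128, 128, 128, 128], 4))  # all_gather
print(pick_gather_mode([512, 8, 8, 8], 4))        # all_reduce
```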
```diff
 if self.require_mlp_tp_gather:
-    total_batch_size = (
-        sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+    max_num_tokens = max(forward_batch.global_num_tokens_cpu)
```
nit: would using `forward_batch.global_num_tokens_cpu.max().item()` be more efficient?
Good catch. I'm going to revise this.
lol, global_num_tokens_cpu is a list of int
```python
)
def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0):
    if value == 0:
        return torch.cat(
```
nit: why not use the torch pad function?
Are you asking about this one (https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.pad.html)? It looks like it is for image padding; unfortunately, I'm not sure how to use it here.
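For reference, `torch.nn.functional.pad` also works on 1-D tensors (the pad tuple is `(left, right)` for the last dimension); a minimal sketch of how it could replace a `torch.cat`-based helper, not the PR's implementation:

```python
import torch
import torch.nn.functional as F

def pad_tensor_to_size(tensor: torch.Tensor, size: int, *, value: int = 0) -> torch.Tensor:
    # Append (size - current length) constant-valued elements at the end of a 1-D tensor.
    return F.pad(tensor, (0, size - tensor.shape[0]), value=value)

x = torch.tensor([1, 2, 3])
print(pad_tensor_to_size(x, 6))  # tensor([1, 2, 3, 0, 0, 0])
```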
/gemini review
Code Review
This pull request introduces a significant enhancement to Data Parallelism (DP) in the attention mechanism, making it compatible with CUDA Graph by adding two gather modes. When CUDA Graph is disabled, an optimal communication strategy is chosen, and when it's enabled, an all-gather approach with padding to max(len) is used. The changes are extensive, touching data structures, CUDA graph capture logic, and the forward pass execution flow. The implementation centralizes padding and un-padding logic, which improves code structure. My review focuses on correctness and maintainability. I've identified a few minor issues with type hints and conditional checks that should be addressed for better code quality.
```python
if hidden_states.shape[0] > 0:
    # router_logits: (num_tokens, n_experts)
```
The condition `if is_non_idle_and_non_empty(forward_mode, hidden_states):` is removed and replaced with `if hidden_states.shape[0] > 0:`. It is important to ensure that this change does not introduce any regressions or unexpected behavior in cases where `is_non_idle_and_non_empty` previously provided additional filtering.
```diff
 ) -> torch.Tensor:
     forward_mode = forward_batch.forward_mode
-    if is_non_idle_and_non_empty(forward_mode, hidden_states):
+    if hidden_states.shape[0] > 0:
```
The condition `if is_non_idle_and_non_empty(forward_mode, hidden_states):` is removed and replaced with `if hidden_states.shape[0] > 0:`. It is important to ensure that this change does not introduce any regressions or unexpected behavior in cases where `is_non_idle_and_non_empty` previously provided additional filtering.
```python
    self,
    output: torch.Tensor,
    input: torch.Tensor,
) -> None:
```
```python
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,
    is_partial: bool,
):
```
```python
def prepare_mlp_sync_batch(self, model_runner: ModelRunner):

    from sglang.srt.speculative.eagle_utils import EagleDraftInput
```
To avoid potential circular dependencies, it's a common practice to perform local imports within functions. However, for better code readability and maintainability, consider moving this import to the top of the file if it doesn't create a circular dependency. If it does, adding a comment explaining why the local import is necessary would be helpful for future maintainers.
Stack from ghstack (oldest at bottom):
- … dp < tp by using `all_gather_into_tensor` and `reduce_scatter_tensor` #8279

Background
Integrating CUDA Graph with DP attention is tricky because different DP ranks have different workloads. Simply gathering tensors from all DP ranks would incur dynamic shapes. There are two design choices to make DP attention compatible with CUDA Graph:

1. Pad tokens to `max(len)` and use all-gather to gather tensors. This design pads the input tensor of the FFN to `max_len * dp_size` and incurs redundant computation (see examples in #4390 and the response here).
2. Pad tokens to `sum(len)` and use all-reduce to gather tensors. This design pads tokens to `sum(len)`, which increases computation in the attention layers (#6092, #7951). In addition, the CUDA graph batch size is `sum(len)` rather than the actual running batch size on each DP worker. Users need to increase the CUDA graph batch size when the DP size is increased, which causes a lot of confusion and misuse (#5527, #5557, #6092).

This PR offers two gather modes. When CUDA graph is disabled, we select the option that minimizes communication. When CUDA graph is in use, we use the first option.
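An illustrative comparison of the two schemes with made-up per-rank token counts (not numbers from the PR):

```python
lens = [5, 3, 2, 2]          # tokens on each DP rank
dp_size = len(lens)

# Option 1: pad every rank to max(len) and all-gather.
ffn_tokens = max(lens) * dp_size       # 5 * 4 = 20 tokens enter the FFN (8 are padding)

# Option 2: pad every rank to sum(len) and all-reduce.
attn_tokens_per_rank = sum(lens)       # each rank's attention runs over 12 tokens
print(ffn_tokens, attn_tokens_per_rank)  # 20 12
```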
Future Work
Under the `max_len` padding mode, the input hidden_states of the FFN is padded to `max_len * dp_size`, which may incur redundant computation. In the future, we can use masked (grouped) GEMM kernels (e.g., DeepGEMM) to optimize this.