fix post_residual_addition more generally #17286
Kangyan-Zhou merged 5 commits into sgl-project:main
Conversation
Summary of Changes
Hello @nanjiangwill, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request addresses an issue where post_residual_addition was only handled on the CUDA backend, and extends the same handling to the non-GPU normalization forward passes.
Code Review
This pull request adds support for post_residual_addition to various normalization forward passes for non-GPU backends, ensuring consistent behavior across different hardware platforms. The changes involve adding the same logic block to multiple methods in RMSNorm and GemmaRMSNorm. While this fixes the issue, it introduces significant code duplication. My review includes a suggestion to refactor this duplicated logic to improve code maintainability.
```python
if post_residual_addition is not None:
    residual = residual + post_residual_addition
```
This logic for handling post_residual_addition is now duplicated in at least 10 different methods across RMSNorm and GemmaRMSNorm. This introduces a significant maintainability issue.
To avoid this, I suggest centralizing the logic.
For GemmaRMSNorm, since all its forward_* methods now use the same logic, you could override the main forward method to handle post_residual_addition once. The specific forward_* methods would then not need to handle it. This would look something like:
```python
# In GemmaRMSNorm
def forward(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    if residual is not None and post_residual_addition is not None:
        residual = residual + post_residual_addition
    # Assuming self._forward_method is the platform-specific implementation.
    # Pass post_residual_addition=None as it has been handled.
    return self._forward_method(x, residual, post_residual_addition=None)
```
After this, you can remove the duplicated blocks you've added to GemmaRMSNorm's methods.
For RMSNorm, it's more complex as forward_native has different logic. However, the other forward_* methods in RMSNorm could be refactored to call a common private helper function to reduce duplication.
Consolidating this logic will make the code cleaner and easier to maintain.
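A minimal sketch of that helper-based refactor for RMSNorm might look like the following; the class and helper name here are illustrative placeholders, not the actual sglang code:

```python
from typing import Optional

import torch


class _PostResidualHelperSketch:
    """Illustrative only: where RMSNorm could centralize the repeated block."""

    @staticmethod
    def _apply_post_residual_addition(
        residual: Optional[torch.Tensor],
        post_residual_addition: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        # One copy of the logic that is currently pasted into every forward_* method.
        if residual is not None and post_residual_addition is not None:
            return residual + post_residual_addition
        return residual
```

Each forward_* variant would then begin with `residual = self._apply_post_residual_addition(residual, post_residual_addition)` and otherwise remain unchanged.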
/tag-and-rerun-ci
Motivation
#16561 (fix post_residual_addition for non-GPU backend) only fixed the CUDA backend; this PR extends the same fix to the other backends.
Modifications
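As a rough illustration of the change (the class, method name, and normalization math below are simplified placeholders, not the actual sglang code), the guard quoted in the review is added at the top of each non-GPU forward_* variant:

```python
from typing import Optional, Tuple, Union

import torch


class RMSNormSketch(torch.nn.Module):
    """Simplified stand-in for RMSNorm, only to show where the new guard sits."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward_cpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
        post_residual_addition: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # The block this PR adds to the non-GPU paths; post_residual_addition
        # is only expected to be passed together with a residual.
        if post_residual_addition is not None:
            residual = residual + post_residual_addition
        # Plain RMSNorm math as a placeholder for the backend-specific kernel.
        if residual is not None:
            x = x + residual
            residual = x
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x * self.weight
        return x if residual is None else (x, residual)
```

With this guard in each forward_* method, the non-CUDA backends fold post_residual_addition into the residual the same way the CUDA path has since #16561.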
Accuracy Tests
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci