fix: handle zero-strided tensors in fast_rope_embedding (#3781)#4233
fix: handle zero-strided tensors in fast_rope_embedding (#3781)#4233danielhanchen wants to merge 1 commit into
Conversation
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request resolves an issue where the Highlights
Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request correctly fixes a bug where fast_rope_embedding would fail on zero-strided tensors, which can occur with expanded or broadcast tensors. The fix, which involves cloning the tensor if it's not contiguous or contains a zero in its strides, is applied to both the forward and backward passes and seems robust. My feedback includes a suggestion to refactor the duplicated logic into a helper method to improve code maintainability.
| Q_out = Q.clone() if not Q.is_contiguous() or 0 in Q.stride() else Q | ||
| K_out = K.clone() if not K.is_contiguous() or 0 in K.stride() else K |
There was a problem hiding this comment.
The logic to clone the tensor if it's not contiguous or has a zero stride is duplicated here and in the backward pass for dQ and dK (lines 387-388). To improve maintainability and adhere to the Don't Repeat Yourself (DRY) principle, consider extracting this logic into a helper function or a static method within the Fast_RoPE_Embedding_QK class.
For example:
@staticmethod
def _clone_if_needed(tensor):
# Clone if not contiguous or has zero strides, such as expanded tensors.
if not tensor.is_contiguous() or 0 in tensor.stride():
return tensor.clone()
return tensorYou could then call _clone_if_needed(Q) and _clone_if_needed(K) here, and similarly in the backward pass.
Replacement for #3804 due to Studio rebasing
Summary
Fix #3781: Handle zero-strided tensors in
fast_rope_embeddingforward and backward passes.When gradient tensors (
dQ,dK) have zero strides (e.g., from expanded/broadcast tensors during debugging scenarios like(out[0].sum() + out[1].sum()).backward()), the triton kernel fails because all stride values become zero, causing incorrect memory access patterns.Changes
Code Changes
Test plan
🤖 Generated with Claude Code