Fix num_items_in_batch GA for Gemma4#4998
Conversation
There was a problem hiding this comment.
Code Review
This pull request updates the patch_gradient_accumulation_fix in unsloth/models/_utils.py to correctly identify the accepts_loss_kwargs flag in wrapped models, specifically addressing issues with Gemma2. However, the use of direct string replacement for patching the source code is fragile and likely to cause indentation errors. It is recommended to use a regular expression that captures and preserves the existing indentation levels to ensure the patched code remains syntactically valid.
| init_function = init_function.replace( | ||
| "self.model_accepts_loss_kwargs = unwrapped_model.accepts_loss_kwargs\n else:", | ||
| "self.model_accepts_loss_kwargs = unwrapped_model.accepts_loss_kwargs\n" | ||
| ' elif hasattr(getattr(unwrapped_model, "model", None), "accepts_loss_kwargs"):\n' | ||
| " self.model_accepts_loss_kwargs = unwrapped_model.model.accepts_loss_kwargs\n" | ||
| " else:", | ||
| ) |
There was a problem hiding this comment.
The implementation using init_function.replace with hardcoded strings is highly likely to fail and cause an IndentationError.
- Indentation Mismatch:
textwrap.dedentremoves common leading whitespace, but the lineself.model_accepts_loss_kwargs = ...remains indented by 8 spaces (relative to the dedenteddef). Your search string starts at column 0, so it won't match. - Indentation Error: Even if it did match, the replacement string would insert code with 0 or 4 spaces of indentation where 8 spaces are expected, breaking the Python syntax.
- Robustness: Using
re.subis much safer for patching source code as it can handle varying whitespace and capture existing indentation levels.
Additionally, consider if unwrapped_model.model is sufficient for PEFT-wrapped models, where the attribute might be nested deeper (e.g., unwrapped_model.model.model).
| init_function = init_function.replace( | |
| "self.model_accepts_loss_kwargs = unwrapped_model.accepts_loss_kwargs\n else:", | |
| "self.model_accepts_loss_kwargs = unwrapped_model.accepts_loss_kwargs\n" | |
| ' elif hasattr(getattr(unwrapped_model, "model", None), "accepts_loss_kwargs"):\n' | |
| " self.model_accepts_loss_kwargs = unwrapped_model.model.accepts_loss_kwargs\n" | |
| " else:", | |
| ) | |
| init_function = re.sub( | |
| r'(\s+self\.model_accepts_loss_kwargs = unwrapped_model\.accepts_loss_kwargs\n)(\s+)(else:)', | |
| r'\1\2elif hasattr(getattr(unwrapped_model, "model", None), "accepts_loss_kwargs"):\n' | |
| r'\2 self.model_accepts_loss_kwargs = unwrapped_model.model.accepts_loss_kwargs\n' | |
| r'\2\3', | |
| init_function, | |
| ) |
Fixes: #4982
The argument is not in Gemma4ForConditionalGeneration, It is in the
.modelof the same