Fix Gemma 4 fused cross entropy not being applied#575
Merged
Conversation
Gemma 4's ForConditionalGeneration.forward() uses a slightly different loss code structure than other VLMs: flat_logits = shift_logits.view(-1, vocab_size) flat_labels = shift_labels.view(-1).to(shift_logits.device) loss = loss_fct(flat_logits, flat_labels) The compiler's cross_entropy_find_3 regex expects shift_logits/shift_labels variable names and the .view(-1) and .to(device) calls on separate lines. None of the 3 patterns match, so the fused cross entropy replacement is never applied. This causes the logged training loss to be inflated by the gradient_accumulation_steps factor (Issue #1 in gemma-4-details). Add source normalization rules to fixup_fused_lm_head() that rename flat_logits/flat_labels back to shift_logits/shift_labels and split the chained .view(-1).to(...) into two lines. After normalization, pattern 3 matches and unsloth_fused_ce_loss is applied correctly. Tested: [3/3 pattern] Successfully patched fast linear cross entropy for Gemma4ForConditionalGeneration
Contributor
There was a problem hiding this comment.
Code Review
This pull request updates the fixup_fused_lm_head function in unsloth_zoo/compiler.py to include specific fixes for Gemma 4. The modifications normalize variable names from flat_logits and flat_labels to shift_logits and shift_labels, and use regex to split chained .view(-1).to(...) operations into separate lines to ensure compatibility with existing pattern matching logic. I have no feedback to provide.
3 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
fixup_fused_lm_head()so the compiler's pattern 3 regex matches Gemma 4'sForConditionalGeneration.forward()Problem
Gemma 4 uses different variable names and a chained
.to()call in its loss code:The compiler's
cross_entropy_find_3regex expectsshift_logits/shift_labelsand.view(-1)and.to(device)on separate lines. None of the 3 patterns match, sounsloth_fused_ce_lossis never applied.Fix
Normalize Gemma 4's forward source in
fixup_fused_lm_head()before the regex runs:flat_logits = shift_logits.view(-1, ...)->shift_logits = shift_logits.view(-1, ...)flat_labels = shift_labels.view(-1).to(device)into two linesloss = loss_fct(flat_logits, flat_labels)->loss = loss_fct(shift_logits, shift_labels)After normalization, pattern 3 matches and applies the fused cross entropy.
Test plan
apply_fused_lm_head()on Gemma 4 forward source returns modified result[3/3 pattern] Successfully patched fast linear cross entropy for Gemma4ForConditionalGeneration