Fix Gemma 4 fused cross entropy not being applied #575

Merged: danielhanchen merged 1 commit into main from fix/gemma4-fused-cross-entropy on Apr 2, 2026

@danielhanchen (Member)

Summary

  • Add source normalization in fixup_fused_lm_head() so the compiler's pattern 3 regex matches Gemma 4's ForConditionalGeneration.forward()
  • Without this, the logged training loss is inflated by a factor of roughly gradient_accumulation_steps (e.g. the loss shows ~52 instead of ~13 with GA=4)
  • Training itself is correct (gradients are fine); only the logged metric is wrong
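
The numbers in the example above can be checked with a small arithmetic sketch. The summary does not spell out the exact mechanism, so the snippet assumes, purely for illustration, that the logged value ends up as a sum over micro-batch losses instead of their mean. The loss values are made up.

```python
# Illustrative values only: assume each micro-batch has a mean loss near 13.
micro_batch_losses = [13.1, 12.9, 13.0, 13.0]
gradient_accumulation_steps = len(micro_batch_losses)  # GA = 4

# Expected logged value: the mean over the accumulated micro-batches.
correct_logged = sum(micro_batch_losses) / gradient_accumulation_steps

# Inflated logged value (assumption): the micro-batch losses are summed
# without the 1/GA scaling.
inflated_logged = sum(micro_batch_losses)

print(f"correct:  {correct_logged:.1f}")
print(f"inflated: {inflated_logged:.1f}")
print(f"ratio:    {inflated_logged / correct_logged:.1f}")
```

Under that assumption the ratio between the two values equals gradient_accumulation_steps, which is consistent with the ~52 versus ~13 reading at GA=4.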

Problem

Gemma 4 uses different variable names and a chained .to() call in its loss code:

flat_logits = shift_logits.view(-1, self.config.get_text_config().vocab_size)
flat_labels = shift_labels.view(-1).to(shift_logits.device)
loss = loss_fct(flat_logits, flat_labels)

The compiler's cross_entropy_find_3 regex expects the variable names shift_logits/shift_labels, with the .view(-1) and .to(device) calls on separate lines. Gemma 4's code therefore matches none of the 3 patterns, so unsloth_fused_ce_loss is never applied.
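
The mismatch can be reproduced with a much simplified stand-in for the pattern. SIMPLIFIED_PATTERN below is hypothetical and written only for this illustration; the real cross_entropy_find_3 in unsloth_zoo/compiler.py is longer and may differ in detail.

```python
import re

# Hypothetical, simplified stand-in for cross_entropy_find_3 (an assumption,
# not the actual regex used by the compiler).
SIMPLIFIED_PATTERN = re.compile(
    r"shift_logits = shift_logits\.view\(-1, [^\n]+\)\n"
    r"\s*shift_labels = shift_labels\.view\(-1\)\n"
    r"\s*shift_labels = shift_labels\.to\(shift_logits\.device\)\n"
    r"\s*loss = loss_fct\(shift_logits, shift_labels\)"
)

# The shape the pattern expects: shift_* names, .view and .to on separate lines.
expected_form = (
    "shift_logits = shift_logits.view(-1, self.config.vocab_size)\n"
    "shift_labels = shift_labels.view(-1)\n"
    "shift_labels = shift_labels.to(shift_logits.device)\n"
    "loss = loss_fct(shift_logits, shift_labels)"
)

# Gemma 4's shape: flat_* names and a chained .view(-1).to(...) call.
gemma4_form = (
    "flat_logits = shift_logits.view(-1, self.config.get_text_config().vocab_size)\n"
    "flat_labels = shift_labels.view(-1).to(shift_logits.device)\n"
    "loss = loss_fct(flat_logits, flat_labels)"
)

matches_expected = SIMPLIFIED_PATTERN.search(expected_form) is not None
matches_gemma4 = SIMPLIFIED_PATTERN.search(gemma4_form) is not None
print(matches_expected, matches_gemma4)  # True False
```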

Fix

Normalize Gemma 4's forward source in fixup_fused_lm_head() before the regex runs:

  1. flat_logits = shift_logits.view(-1, ...) -> shift_logits = shift_logits.view(-1, ...)
  2. Split flat_labels = shift_labels.view(-1).to(device) into two lines: shift_labels = shift_labels.view(-1), then shift_labels = shift_labels.to(device)
  3. loss = loss_fct(flat_logits, flat_labels) -> loss = loss_fct(shift_logits, shift_labels)

After normalization, pattern 3 matches and applies the fused cross entropy.
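
The three steps can be sketched as a standalone function. normalize_gemma4_loss_source is a hypothetical name used only here; the actual rules live inside fixup_fused_lm_head() and may be written differently. The sketch uses plain string replacement for the renames and a single regex for the split.

```python
import re

def normalize_gemma4_loss_source(source: str) -> str:
    """Hypothetical helper mirroring the three normalization steps above."""
    # 1. Rename the flattened logits variable back to shift_logits.
    source = source.replace(
        "flat_logits = shift_logits.view(", "shift_logits = shift_logits.view("
    )
    # 2. Split the chained .view(-1).to(...) into two assignments,
    #    keeping the original indentation on both lines.
    source = re.sub(
        r"^([ \t]*)flat_labels = shift_labels\.view\(-1\)\.to\(([^\n]*)\)[ \t]*$",
        r"\1shift_labels = shift_labels.view(-1)\n\1shift_labels = shift_labels.to(\2)",
        source,
        flags=re.MULTILINE,
    )
    # 3. Point the loss call at the renamed variables.
    source = source.replace(
        "loss_fct(flat_logits, flat_labels)", "loss_fct(shift_logits, shift_labels)"
    )
    return source

gemma4_source = (
    "    flat_logits = shift_logits.view(-1, self.config.get_text_config().vocab_size)\n"
    "    flat_labels = shift_labels.view(-1).to(shift_logits.device)\n"
    "    loss = loss_fct(flat_logits, flat_labels)\n"
)
normalized = normalize_gemma4_loss_source(gemma4_source)
print(normalized)
```

Source that already uses the shift_logits/shift_labels names passes through this sketch unchanged, which is the property the "other VLMs still patch correctly" check in the test plan relies on.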

Test plan

  • apply_fused_lm_head() on Gemma 4 forward source returns modified result
  • Log output: [3/3 pattern] Successfully patched fast linear cross entropy for Gemma4ForConditionalGeneration
  • Finetune Gemma 4 E2B-it with GA>1 and confirm loss is no longer inflated
  • Verify other VLMs (Gemma 3, Llama 4, Qwen2.5-VL) still patch correctly

Gemma 4's ForConditionalGeneration.forward() uses a slightly different
loss code structure than other VLMs:

  flat_logits = shift_logits.view(-1, vocab_size)
  flat_labels = shift_labels.view(-1).to(shift_logits.device)
  loss = loss_fct(flat_logits, flat_labels)

The compiler's cross_entropy_find_3 regex expects shift_logits/shift_labels
variable names and the .view(-1) and .to(device) calls on separate lines.
None of the 3 patterns match, so the fused cross entropy replacement is
never applied. This causes the logged training loss to be inflated by
the gradient_accumulation_steps factor (Issue #1 in gemma-4-details).

Add source normalization rules to fixup_fused_lm_head() that rename
flat_logits/flat_labels back to shift_logits/shift_labels and split the
chained .view(-1).to(...) into two lines. After normalization, pattern 3
matches and unsloth_fused_ce_loss is applied correctly.

Tested: [3/3 pattern] Successfully patched fast linear cross entropy
for Gemma4ForConditionalGeneration

@gemini-code-assist (Bot) left a comment

Code Review

This pull request updates the fixup_fused_lm_head function in unsloth_zoo/compiler.py to include specific fixes for Gemma 4. The modifications normalize variable names from flat_logits and flat_labels to shift_logits and shift_labels, and use regex to split chained .view(-1).to(...) operations into separate lines to ensure compatibility with existing pattern matching logic. I have no feedback to provide.

@danielhanchen merged commit c33e848 into main on Apr 2, 2026
3 checks passed