Skip to content

Fix grad-accum accepts_loss_kwargs detection for vision wrappers#5036

Merged
danielhanchen merged 4 commits into
mainfrom
fix/grad-accum-loss-kwargs-unified
Apr 15, 2026
Merged

Fix grad-accum accepts_loss_kwargs detection for vision wrappers#5036
danielhanchen merged 4 commits into
mainfrom
fix/grad-accum-loss-kwargs-unified

Conversation

@danielhanchen

Copy link
Copy Markdown
Member

Summary

Replaces the Trainer.__init__ source-string rewrite for accepts_loss_kwargs detection with an instance-level shadow applied on the loaded model. This is version agnostic, never silently no-ops, and unifies the coverage of #4998 (stock forward fallback via inner .model) and #5030 (Unsloth-compiled forward via signature inspection).

Also backports the transformers 5.6.0.dev0 GradientAccumulationPlugin(num_steps=1) pin to 5.0-5.5 by clamping accelerator.gradient_accumulation_steps = 1 post-init. No-op on transformers 4.x and 5.6+.

Closes #4982. Supersedes #5030.

Why the previous approach was fragile

patch_gradient_accumulation_fix has been rewriting Trainer.__init__ source via str.replace / re.sub. I checked every transformers release from 4.47.0 through 5.5.4 and every TRL release from 0.22.2 through 1.1.0:

  • transformers 4.48 through 4.52 name the parameter model (not unwrapped_model), so both known patch variants silently no-op on those versions.
  • PR Fix num_items_in_batch GA for Gemma4 #4998 catches Gemma3nForConditionalGeneration full-finetune via its inner .model.accepts_loss_kwargs = False fallback, but misses it under PEFT because peft_model.model returns LoraModel (which has no attribute), not the inner Gemma3nModel.
  • PR revert ga fix #5030 (the revert) is correct only when Unsloth's compile pipeline installed the fused-CE forward, and reintroduces [Question] High Gradient Norm and Loss During Initial Training with Gemma-4-E2B #4982 under trust_remote_code=True, UNSLOTH_COMPILE_DISABLE=1, or a compile-cache write failure (reproduces on Colab's ephemeral filesystem).

What this PR does

Adds a single helper apply_accepts_loss_kwargs_fix(model) in unsloth/models/_utils.py and calls it unconditionally from both loaders (unsloth/models/llama.py and unsloth/models/vision.py) after the model is loaded. The helper:

  1. Checks type(model).forward.__code__.co_filename against unsloth_compiled_cache/. If the compiled forward is in use, shadows accepts_loss_kwargs = True on every wrapper level. Unsloth's compiled forward calls unsloth_fixed_cross_entropy which accepts num_items_in_batch, so Trainer must not divide again.
  2. Otherwise walks .base_model / .model looking for the first class that declares accepts_loss_kwargs on its own class dict (uses type(m).__dict__ to avoid PEFT's __getattr__ forwarding returning false positives), and shadows that value. This catches Gemma3nForConditionalGeneration, PaliGemma, Qwen-VL family, and every other vision wrapper where only the inner .model carries the flag.
  3. Otherwise leaves HF's default signature inspection untouched. Text LMs (Llama, Mistral, Qwen3, etc.) fall here and HF correctly infers True from **kwargs in forward.

The removed source-string __init__ rewrite is no longer needed. The training_step rewrite and the compute_loss fix are unchanged.

Scenario matrix (Gemma3nForConditionalGeneration full-FT)

Scenario Expected Stock HF #4998 #5030 This PR
Stock forward (trust_remote_code or compile skipped) False True, bug False, correct True, bug False, correct
Unsloth-compiled forward True True, correct False (still works via dual-mode CE) True, correct True, correct

Verification

End-to-end on transformers 4.57.6 + TRL 0.25.1 + accelerate 1.13.0 + H100, unsloth/gemma-3-270m-it + LoRA:

model type: Gemma3ForCausalLM
instance 'accepts_loss_kwargs' = True
type.forward co_filename: unsloth_compiled_cache/unsloth_compiled_module_gemma3.py

=== Trainer state AFTER init ===
trainer.model_accepts_loss_kwargs = True
trainer.accelerator.gradient_accumulation_steps = 1

=== Walk chain ===
  PeftModelForCausalLM: instance_attr=False   (forwards via __getattr__)
  base_model->LoraModel: instance_attr=False  (forwards via __getattr__)
  base_model->Gemma3ForCausalLM: instance_attr=True, val=True
  model->Gemma3TextModel: instance_attr=True, val=True

=== Training step ===
{'loss': 4.9626, 'grad_norm': 14915175424.0, 'learning_rate': 0.0, 'epoch': 0.25}

Source inspection coverage:

  • transformers 4.47.0 through 5.5.4: new helper works on all (no source rewrite dependency).
  • TRL 0.22.2 through 1.1.0: SFTTrainer inherits from Trainer / BaseTrainer(Trainer) / _BaseTrainer(Trainer) and never overrides model_accepts_loss_kwargs. All RL / preference trainers (GRPO, DPO, KTO, CPO, ORPO, RLOO, BCO, SDFT, async-GRPO) explicitly set self.model_accepts_loss_kwargs = False post super-init, so this fix is a no-op for them (no regression surface).

Test plan

  • Import check: from unsloth.models._utils import apply_accepts_loss_kwargs_fix
  • End-to-end SFT step on Gemma-3 270m + LoRA matches pre-PR loss and grad-norm
  • Scenario matrix (mock Gemma3n, Gemma3, Llama, compiled) produces correct values
  • Run against Gemma-3n-E2B + LoRA with trust_remote_code=True on Colab to confirm [Question] High Gradient Norm and Loss During Initial Training with Gemma-4-E2B #4982 no longer reproduces
  • Run against Gemma-3-4B vision + LoRA to confirm no regression on the compiled path

Replace the source-string rewrite of Trainer.__init__ with an instance-level
accepts_loss_kwargs shadow applied on the loaded model. Covers:

  1. Unsloth-compiled forward -> True, so HF Trainer does not double-scale
     on top of unsloth_fixed_cross_entropy's num_items_in_batch division.
  2. Stock forward on a conditional-generation wrapper (Gemma3n, Gemma3
     pre-4.57, Qwen-VL family, etc.) where the outer class has no
     accepts_loss_kwargs but the inner .model declares False -> False.
     This is the case that reproduces issue #4982 under trust_remote_code
     or UNSLOTH_COMPILE_DISABLE, where the previous fix's outer-attr
     check walked past the inner model and fell through to signature
     inspection.
  3. Text LMs without any explicit accepts_loss_kwargs -> leave HF default.

The previous .replace()-based patch silently no-ops on transformers 4.48
through 4.52 (variable named model, not unwrapped_model) and is fragile
against any upstream reformat. The new helper walks the PEFT / HF wrapper
chain, finds the first class that declares accepts_loss_kwargs on its own
class dict (type(m).__dict__, not hasattr, to avoid PEFT __getattr__
forwarding), and setattr-shadows that value at every wrapper level so
HF Trainer's hasattr(unwrapped_model, ...) check picks it up at whichever
level accelerate.unwrap_model returns.

Also adds an unconditional post-init clamp of
accelerator.gradient_accumulation_steps = 1 to work around the
transformers 5.0 through 5.5 GradientAccumulationPlugin regression that
makes accelerator.backward divide loss by GA on top of training_step's
own /GA division. Fixed upstream in 5.6.0.dev0; no-op on 4.x and 5.6+.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the gradient accumulation fix to be more robust and version-agnostic. Instead of patching the Trainer.__init__ source code via string replacement, it now sets the accepts_loss_kwargs attribute directly on model instances and their wrapper chains. It also includes a fix for a gradient scaling mismatch in Transformers versions 5.0 through 5.5 by clamping the accelerator's accumulation steps. The review feedback suggests replacing several silent exception handlers with logged warnings to improve the maintainability and debuggability of the patching logic.

I am having trouble creating individual review comments. Click here to see my feedback.

unsloth/models/_utils.py (2123-2134)

medium

These broad, silent exception handlers (except Exception: pass) can hide important issues during runtime. It's better to log these exceptions, even at a debug or warning level, to aid in future debugging. This aligns with the general rule to avoid silent exception handling.

            try:
                accelerator = getattr(self, "accelerator", None)
                if accelerator is not None and getattr(accelerator, "gradient_accumulation_steps", 1) > 1:
                    accelerator.gradient_accumulation_steps = 1
                    gs = getattr(accelerator, "gradient_state", None)
                    if gs is not None and hasattr(gs, "plugin_kwargs"):
                        try:
                            gs.plugin_kwargs["num_steps"] = 1
                        except Exception as e:
                            logger.warning_once(f"Unsloth: Failed to set plugin_kwargs['num_steps'] for accelerator: {e}")
            except Exception as e:
                logger.warning_once(f"Unsloth: Failed to clamp accelerator.gradient_accumulation_steps: {e}")
References
  1. Avoid using broad, silent exception handlers like except Exception: pass. Instead, log the exception, even if at a debug level, to aid in future debugging.

unsloth/models/_utils.py (2218-2221)

medium

This try...except block silently ignores any exceptions during setattr. While this might be intended to handle objects that don't allow setting new attributes, it's better to log the exception at a debug or warning level. This helps in debugging unexpected behavior where the attribute isn't set on a particular wrapper.

        try:
            setattr(m, "accepts_loss_kwargs", value)
        except Exception as e:
            logger.warning_once(f"Unsloth: Could not shadow accepts_loss_kwargs on {type(m).__name__}: {e}")
References
  1. Avoid using broad, silent exception handlers like except Exception: pass. Instead, log the exception, even if at a debug level, to aid in future debugging.

Two review findings from 3/20 reviewers:

1. [3 of 20 reviewers] apply_accepts_loss_kwargs_fix was called from the
   loaders before get_peft_model wraps the base model, so on transformers
   4.48-4.52 (which does hasattr on the outer model) the instance shadow
   on the base model was lost after PEFT wrapping. Fix: also call it from
   the wrapped Trainer.__init__ so it runs on whatever model the user
   actually hands to Trainer, which is always the final wrapped form.

2. [1 of 20 reviewers] _forward_is_unsloth_compiled hard-coded the
   substrings "unsloth_compiled" / "unsloth_cache" in the co_filename
   check, which misclassifies compiled forwards when
   UNSLOTH_COMPILE_LOCATION is set to a custom directory. Fix: new
   _unsloth_compile_cache_leaves helper that reads the env var and
   matches the basename against path components, honoring both the
   default and any user override.

Verified locally:
- PEFT-after-load simulation: HF's hasattr(peft, "accepts_loss_kwargs")
  now returns True after our init wrapper runs, and value resolves to
  False on Gemma3n-style inner wrappers.
- Custom UNSLOTH_COMPILE_LOCATION simulation: compiled detection returns
  True for /tmp/my_custom_cache/compiled.py when the env var is set.
- End-to-end Gemma-3 270m + LoRA SFT unchanged: loss 4.9626, grad-norm
  matches prior run, all 4 wrapper levels now carry the shadowed attr.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 1cf231c562

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth/models/_utils.py
Comment on lines +2095 to +2099
model = kwargs.get("model")
if model is None and len(args) > 0:
model = args[0]
if model is not None:
try:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Handle model_init when applying loss-kwargs fix

This wrapper only calls apply_accepts_loss_kwargs_fix when a model object is present in constructor args, so Trainer(model_init=...) skips the fix entirely because the model is instantiated inside _original_trainer_init. In that path, model_accepts_loss_kwargs is still inferred from the unpatched model and vision wrappers that require accepts_loss_kwargs=False can regress to incorrect loss scaling during gradient accumulation; the fix needs a post-init call against self.model as well.

Useful? React with 👍 / 👎.

@danielhanchen danielhanchen merged commit 1a4ca5e into main Apr 15, 2026
5 checks passed
@danielhanchen danielhanchen deleted the fix/grad-accum-loss-kwargs-unified branch April 15, 2026 13:59
danielhanchen added a commit that referenced this pull request May 9, 2026
Extend the cross-version compat canary to catch ~80% of upstream
drift before a user hits it. Static checks only (GitHub raw fetch +
grep), CPU-only, runs PR-time + daily cron. 906 pass, 73 skipped.

TRL coverage extended:
- TRL_TAGS expanded from 12 to 28 (every stable release >=0.18.2,
  including the broken 0.19.0, plus main). Anchors: 0.22.2 / 0.27.1
  / 1.0.0 marked.
- Fix `__version__` parser to handle the TRL 0.22.x pattern
  (`__version__ = f.read()` from sibling VERSION file).
- Fix `has_def` in _fetch.py to allow indented matches so class
  methods are detected (the original anchored ^def only matched
  module-scope definitions).
- New tests for symbols the audit found we touch but didn't check:
  is_conversational, sft_trainer module + neftune_post_forward_hook,
  dpo_trainer module + MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
  trl.trainer.utils.ConstantLengthDataset (gated),
  trl.models.utils.disable_gradient_checkpointing (gated >=1.0.0),
  trl.import_utils + _*_available cache pattern,
  trl.experimental.openenv.utils generators (one of two names),
  GRPOTrainer required methods (_prepare_inputs,
  _generate_and_score_completions, compute_loss; per-token-logps
  legacy/new dispatch), GRPOTrainer source must contain
  torch.inference_mode + accelerator.unwrap_model fingerprints,
  KTOTrainer.get_batch_logps (now lives at trl.experimental.kto
  on TRL 0.27+ — accept either path),
  SFTTrainer class existence, DPOTrainer methods (informational),
  chat-template propagation (legacy maybe_apply_chat_template OR
  successor apply_chat_template + chat_template_kwargs),
  truncate_with_protected_tokens informational.
- Tighten test_unwrap_model_for_generation_either_path to mirror
  the prod fallback exactly (drop unused trl/extras/profiling.py
  candidate).
- Replace test_trl_generation_vllm_generation_gated symbol set with
  the actual unsloth dependency (VLLMGeneration class + _init_vllm
  / sync_weights / generate methods, not VLLMClient/etc).

PEFT coverage extended (driven by the 8 PR audit unsloth#5015,
#5167, #5036, #4807 + unsloth-zoo#618, #596, #482, #430):
- VARIANT_KWARG_KEYS const (peft 0.18+; injected by zoo#430)
- ParamWrapper class + members (peft 0.18+; needed by zoo#618)
- LoraConfig.target_parameters (peft 0.19+)
- LoraModel._create_and_replace (signature pin for unsloth#4807)
- transformers_weight_conversion module + build_peft_weight_mapping
  (unsloth#5167 wraps this)
- integrations.dequantize_module_weight (3 callsites)
- PeftType.LORA (vllm_utils.py:2520)
- ModulesToSaveWrapper (both peft.utils.* paths)
- PeftModel.from_pretrained method exists
- peft.__version__ parseable

Transformers coverage added (driven by the 16-PR audit):
- New file test_transformers_pinned_symbols.py with 19 test
  categories x 12 transformers tags (4.57.6 floor + 5.0..5.8 + main).
  Anchors: 4.57.6 + 5.5.0.
- Trainer surface (compute_loss num_items_in_batch param,
  training_step grad-accum fingerprints, get_batch_samples
  num_items contract, inner_training_loop _tr_loss inplace v5)
- modeling_utils.checkpoint alias for unsloth-zoo#549
- PushToHubMixin._create_repo presence (unsloth-zoo#393)
- integrations.bitsandbytes module + Linear4bit reference
- quantizers.should_convert_module signature (zoo#491/#488)
- FP8Linear bias/has_bias rename (zoo#572)
- processing_utils.Unpack importable (zoo#583/584)
- gemma3 Gemma3Attention class + gpt_oss GptOssModel class
- auto_factory _LazyAutoMapping private API (unsloth#5155)
- configuration_utils PretrainedConfig/PreTrainedConfig alias
- tokenization_utils_base.apply_chat_template
- modeling_attn_mask_utils symbols
- cache_utils Cache + DynamicCache classes
- training_args.ParallelMode importable

Wire the new transformers job into version-compat-ci.yml (matrix
of 5 PR-time symbol jobs + zoo-imports under spoof + daily fresh-
fetch cron).

Local smoke: 906 pass, 73 skipped (gated optional features) across
vLLM + TRL + PEFT + ST + bnb + transformers suites.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Question] High Gradient Norm and Loss During Initial Training with Gemma-4-E2B

1 participant