Fix grad-accum accepts_loss_kwargs detection for vision wrappers #5036
Replace the source-string rewrite of Trainer.__init__ with an instance-level
accepts_loss_kwargs shadow applied on the loaded model. Covers:
1. Unsloth-compiled forward -> True, so HF Trainer does not double-scale
on top of unsloth_fixed_cross_entropy's num_items_in_batch division.
2. Stock forward on a conditional-generation wrapper (Gemma3n, Gemma3
pre-4.57, Qwen-VL family, etc.) where the outer class has no
accepts_loss_kwargs but the inner .model declares False -> False.
This is the case that reproduces issue #4982 under trust_remote_code
or UNSLOTH_COMPILE_DISABLE, where the previous fix's outer-attr
check walked past the inner model and fell through to signature
inspection.
3. Text LMs without any explicit accepts_loss_kwargs -> leave HF default.
The previous .replace()-based patch silently no-ops on transformers 4.48
through 4.52 (variable named model, not unwrapped_model) and is fragile
against any upstream reformat. The new helper walks the PEFT / HF wrapper
chain, finds the first class that declares accepts_loss_kwargs on its own
class dict (type(m).__dict__, not hasattr, to avoid PEFT __getattr__
forwarding), and setattr-shadows that value at every wrapper level so
HF Trainer's hasattr(unwrapped_model, ...) check picks it up at whichever
level accelerate.unwrap_model returns.
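
The chain walk described above can be sketched roughly as follows. This is a minimal illustration on toy classes, not the helper from this PR: `wrapper_chain`, `shadow_declared_flag`, and the three toy classes are hypothetical names, and the real code operates on PEFT / HF modules rather than plain objects.

```python
def wrapper_chain(model):
    """Yield the model, then each nested .base_model / .model, stopping on cycles."""
    seen = set()
    current = model
    while current is not None and id(current) not in seen:
        seen.add(id(current))
        yield current
        nxt = None
        for name in ("base_model", "model"):
            child = getattr(current, name, None)
            if child is not None and id(child) not in seen:
                nxt = child
                break
        current = nxt


def shadow_declared_flag(model, attr="accepts_loss_kwargs"):
    """Copy the first class-declared value of `attr` onto every wrapper level."""
    chain = list(wrapper_chain(model))
    for m in chain:
        # type(m).__dict__ only sees attributes declared on that exact class,
        # so attribute forwarding through __getattr__ cannot produce a match.
        if attr in type(m).__dict__:
            value = type(m).__dict__[attr]
            break
    else:
        return None  # nothing declared anywhere: leave the default behaviour
    for m in chain:
        setattr(m, attr, value)  # instance-level shadow
    return value


class InnerModel:  # hypothetical stand-in for an inner .model
    accepts_loss_kwargs = False


class ConditionalGenerationWrapper:  # outer class with no flag of its own
    def __init__(self):
        self.model = InnerModel()


class ForwardingWrapper:  # mimics a wrapper that forwards unknown attributes
    def __init__(self, base):
        self.base_model = base

    def __getattr__(self, name):
        return getattr(self.base_model, name)


peft_like = ForwardingWrapper(ConditionalGenerationWrapper())
resolved = shadow_declared_flag(peft_like)
```

After the call, every level of the toy chain carries its own instance attribute, so a `hasattr` check succeeds no matter which level an unwrap step returns.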
Also adds an unconditional post-init clamp of
accelerator.gradient_accumulation_steps = 1 to work around the
transformers 5.0 through 5.5 GradientAccumulationPlugin regression that
makes accelerator.backward divide loss by GA on top of training_step's
own /GA division. Fixed upstream in 5.6.0.dev0; no-op on 4.x and 5.6+.
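
To make the double-scaling concrete, here is a small arithmetic sketch. It is a simplified model of the scaling only, not Trainer or accelerate code; the function name and the numbers are made up for illustration.

```python
def accumulated_loss(token_loss_sums, num_items_in_batch, extra_divisor=1):
    """Sum per-micro-batch losses already normalized by the global token count.

    extra_divisor models a second, unwanted division by the number of
    gradient-accumulation steps.
    """
    return sum((s / num_items_in_batch) / extra_divisor for s in token_loss_sums)


token_loss_sums = [12.0, 8.0, 20.0, 24.0]  # summed token losses per micro-batch
num_items_in_batch = 32                    # tokens across the whole accumulated batch
ga_steps = len(token_loss_sums)

correct = accumulated_loss(token_loss_sums, num_items_in_batch)
double_scaled = accumulated_loss(
    token_loss_sums, num_items_in_batch, extra_divisor=ga_steps
)
print(correct, double_scaled)  # prints: 2.0 0.5
```

With four accumulation steps the extra division shrinks the loss, and therefore the gradients, by a factor of four.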
Code Review
This pull request refactors the gradient accumulation fix to be more robust and version-agnostic. Instead of patching the Trainer.__init__ source code via string replacement, it now sets the accepts_loss_kwargs attribute directly on model instances and their wrapper chains. It also includes a fix for a gradient scaling mismatch in Transformers versions 5.0 through 5.5 by clamping the accelerator's accumulation steps. The review feedback suggests replacing several silent exception handlers with logged warnings to improve the maintainability and debuggability of the patching logic.
unsloth/models/_utils.py (2123-2134)
These broad, silent exception handlers (except Exception: pass) can hide important issues during runtime. It's better to log these exceptions, even at a debug or warning level, to aid in future debugging. This aligns with the general rule to avoid silent exception handling.
try:
    accelerator = getattr(self, "accelerator", None)
    if accelerator is not None and getattr(accelerator, "gradient_accumulation_steps", 1) > 1:
        accelerator.gradient_accumulation_steps = 1
        gs = getattr(accelerator, "gradient_state", None)
        if gs is not None and hasattr(gs, "plugin_kwargs"):
            try:
                gs.plugin_kwargs["num_steps"] = 1
            except Exception as e:
                logger.warning_once(f"Unsloth: Failed to set plugin_kwargs['num_steps'] for accelerator: {e}")
except Exception as e:
    logger.warning_once(f"Unsloth: Failed to clamp accelerator.gradient_accumulation_steps: {e}")
References
- Avoid using broad, silent exception handlers like `except Exception: pass`. Instead, log the exception, even if at a debug level, to aid in future debugging.
unsloth/models/_utils.py (2218-2221)
This try...except block silently ignores any exceptions during setattr. While this might be intended to handle objects that don't allow setting new attributes, it's better to log the exception at a debug or warning level. This helps in debugging unexpected behavior where the attribute isn't set on a particular wrapper.
try:
    setattr(m, "accepts_loss_kwargs", value)
except Exception as e:
    logger.warning_once(f"Unsloth: Could not shadow accepts_loss_kwargs on {type(m).__name__}: {e}")
References
- Avoid using broad, silent exception handlers like `except Exception: pass`. Instead, log the exception, even if at a debug level, to aid in future debugging.
Two review findings from 3/20 reviewers:
1. [3 of 20 reviewers] apply_accepts_loss_kwargs_fix was called from the loaders before get_peft_model wraps the base model, so on transformers 4.48-4.52 (which does hasattr on the outer model) the instance shadow on the base model was lost after PEFT wrapping.
Fix: also call it from the wrapped Trainer.__init__ so it runs on whatever model the user actually hands to Trainer, which is always the final wrapped form.
2. [1 of 20 reviewers] _forward_is_unsloth_compiled hard-coded the substrings "unsloth_compiled" / "unsloth_cache" in the co_filename check, which misclassifies compiled forwards when UNSLOTH_COMPILE_LOCATION is set to a custom directory.
Fix: new _unsloth_compile_cache_leaves helper that reads the env var and matches the basename against path components, honoring both the default and any user override.
Verified locally:
- PEFT-after-load simulation: HF's hasattr(peft, "accepts_loss_kwargs") now returns True after our init wrapper runs, and value resolves to False on Gemma3n-style inner wrappers.
- Custom UNSLOTH_COMPILE_LOCATION simulation: compiled detection returns True for /tmp/my_custom_cache/compiled.py when the env var is set.
- End-to-end Gemma-3 270m + LoRA SFT unchanged: loss 4.9626, grad-norm matches prior run, all 4 wrapper levels now carry the shadowed attr.
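
The second fix could look something like the sketch below. It is an assumption-laden illustration, not the `_unsloth_compile_cache_leaves` helper itself: the function names are hypothetical, and the default leaf `unsloth_compiled_cache` is taken from the directory name mentioned in the PR description.

```python
import os
from pathlib import PurePath

# Assumed default directory name, taken from the PR description.
DEFAULT_CACHE_LEAF = "unsloth_compiled_cache"


def cache_leaves(environ=None):
    """Return directory basenames that mark a compiled-cache location."""
    environ = os.environ if environ is None else environ
    leaves = {DEFAULT_CACHE_LEAF}
    custom = environ.get("UNSLOTH_COMPILE_LOCATION", "")
    if custom:
        # normpath drops a trailing slash so basename is never empty
        leaf = os.path.basename(os.path.normpath(custom))
        if leaf:
            leaves.add(leaf)
    return leaves


def looks_compiled(co_filename, environ=None):
    """True if any path component of co_filename is a known cache leaf."""
    leaves = cache_leaves(environ)
    return any(part in leaves for part in PurePath(co_filename).parts)
```

Matching whole path components rather than substrings keeps a custom directory such as `/tmp/my_custom_cache` detectable while avoiding accidental hits on unrelated file names.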
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 1cf231c562
model = kwargs.get("model")
if model is None and len(args) > 0:
    model = args[0]
if model is not None:
    try:
Handle model_init when applying loss-kwargs fix
This wrapper only calls apply_accepts_loss_kwargs_fix when a model object is present in constructor args, so Trainer(model_init=...) skips the fix entirely because the model is instantiated inside _original_trainer_init. In that path, model_accepts_loss_kwargs is still inferred from the unpatched model and vision wrappers that require accepts_loss_kwargs=False can regress to incorrect loss scaling during gradient accumulation; the fix needs a post-init call against self.model as well.
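
One way to picture the suggested ordering is the sketch below, which uses a toy trainer instead of the real `Trainer`. `wrap_trainer_init`, `ToyTrainer`, and the `apply_fix` callback are hypothetical names. The sketch only shows where the calls happen; it does not refresh any flag a trainer may already have cached during init, and the source does not say how the PR handles that part.

```python
import functools


def wrap_trainer_init(original_init, apply_fix):
    """Wrap an __init__ so apply_fix runs before init and, if needed, after it."""

    @functools.wraps(original_init)
    def patched_init(self, *args, **kwargs):
        model = kwargs.get("model")
        if model is None and len(args) > 0:
            model = args[0]
        if model is not None:
            apply_fix(model)  # model handed in directly
        original_init(self, *args, **kwargs)
        built = getattr(self, "model", None)
        if built is not None and built is not model:
            apply_fix(built)  # model created inside init, e.g. via model_init

    return patched_init


class ToyTrainer:  # stand-in with the same two construction paths
    def __init__(self, model=None, model_init=None):
        self.model = model if model is not None else model_init()


applied = []
ToyTrainer.__init__ = wrap_trainer_init(ToyTrainer.__init__, applied.append)
trainer = ToyTrainer(model_init=object)
```

In the toy run, the model built by `model_init` is the only entry in `applied`, which is the path the review comment says was skipped.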
Extend the cross-version compat canary to catch ~80% of upstream drift before a user hits it. Static checks only (GitHub raw fetch + grep), CPU-only, runs PR-time + daily cron. 906 pass, 73 skipped.
TRL coverage extended:
- TRL_TAGS expanded from 12 to 28 (every stable release >=0.18.2, including the broken 0.19.0, plus main). Anchors: 0.22.2 / 0.27.1 / 1.0.0 marked.
- Fix `__version__` parser to handle the TRL 0.22.x pattern (`__version__ = f.read()` from sibling VERSION file).
- Fix `has_def` in _fetch.py to allow indented matches so class methods are detected (the original anchored ^def only matched module-scope definitions).
- New tests for symbols the audit found we touch but didn't check:
  - is_conversational
  - sft_trainer module + neftune_post_forward_hook
  - dpo_trainer module + MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
  - trl.trainer.utils.ConstantLengthDataset (gated)
  - trl.models.utils.disable_gradient_checkpointing (gated >=1.0.0)
  - trl.import_utils + _*_available cache pattern
  - trl.experimental.openenv.utils generators (one of two names)
  - GRPOTrainer required methods (_prepare_inputs, _generate_and_score_completions, compute_loss; per-token-logps legacy/new dispatch)
  - GRPOTrainer source must contain torch.inference_mode + accelerator.unwrap_model fingerprints
  - KTOTrainer.get_batch_logps (now lives at trl.experimental.kto on TRL 0.27+; accept either path)
  - SFTTrainer class existence
  - DPOTrainer methods (informational)
  - chat-template propagation (legacy maybe_apply_chat_template OR successor apply_chat_template + chat_template_kwargs)
  - truncate_with_protected_tokens informational
- Tighten test_unwrap_model_for_generation_either_path to mirror the prod fallback exactly (drop unused trl/extras/profiling.py candidate).
- Replace test_trl_generation_vllm_generation_gated symbol set with the actual unsloth dependency (VLLMGeneration class + _init_vllm / sync_weights / generate methods, not VLLMClient/etc).
PEFT coverage extended (driven by the 8 PR audit unsloth#5015, #5167, #5036, #4807 + unsloth-zoo#618, #596, #482, #430):
- VARIANT_KWARG_KEYS const (peft 0.18+; injected by zoo#430)
- ParamWrapper class + members (peft 0.18+; needed by zoo#618)
- LoraConfig.target_parameters (peft 0.19+)
- LoraModel._create_and_replace (signature pin for unsloth#4807)
- transformers_weight_conversion module + build_peft_weight_mapping (unsloth#5167 wraps this)
- integrations.dequantize_module_weight (3 callsites)
- PeftType.LORA (vllm_utils.py:2520)
- ModulesToSaveWrapper (both peft.utils.* paths)
- PeftModel.from_pretrained method exists
- peft.__version__ parseable
Transformers coverage added (driven by the 16-PR audit):
- New file test_transformers_pinned_symbols.py with 19 test categories x 12 transformers tags (4.57.6 floor + 5.0..5.8 + main). Anchors: 4.57.6 + 5.5.0.
- Trainer surface (compute_loss num_items_in_batch param, training_step grad-accum fingerprints, get_batch_samples num_items contract, inner_training_loop _tr_loss inplace v5)
- modeling_utils.checkpoint alias for unsloth-zoo#549
- PushToHubMixin._create_repo presence (unsloth-zoo#393)
- integrations.bitsandbytes module + Linear4bit reference
- quantizers.should_convert_module signature (zoo#491/#488)
- FP8Linear bias/has_bias rename (zoo#572)
- processing_utils.Unpack importable (zoo#583/584)
- gemma3 Gemma3Attention class + gpt_oss GptOssModel class
- auto_factory _LazyAutoMapping private API (unsloth#5155)
- configuration_utils PretrainedConfig/PreTrainedConfig alias
- tokenization_utils_base.apply_chat_template
- modeling_attn_mask_utils symbols
- cache_utils Cache + DynamicCache classes
- training_args.ParallelMode importable
Wire the new transformers job into version-compat-ci.yml (matrix of 5 PR-time symbol jobs + zoo-imports under spoof + daily fresh-fetch cron).
Local smoke: 906 pass, 73 skipped (gated optional features) across vLLM + TRL + PEFT + ST + bnb + transformers suites.
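
For the `has_def` fix mentioned in the TRL coverage list, a grep-style check that accepts indented definitions might look like this. The actual `_fetch.py` code is not shown in this conversation, so the signature and regex below are assumptions made for illustration.

```python
import re


def has_def(source, name):
    """True if `source` defines a function or method called `name`.

    Leading spaces or tabs are allowed so that class methods match too.
    """
    pattern = rf"^[ \t]*(?:async[ \t]+)?def[ \t]+{re.escape(name)}[ \t]*\("
    return re.search(pattern, source, flags=re.MULTILINE) is not None


def has_module_scope_def(source, name):
    """Stricter variant: only matches definitions that start at column 0."""
    pattern = rf"^def[ \t]+{re.escape(name)}[ \t]*\("
    return re.search(pattern, source, flags=re.MULTILINE) is not None


SAMPLE = (
    "class GRPOTrainer:\n"
    "    def compute_loss(self, model, inputs):\n"
    "        pass\n"
    "\n"
    "def helper():\n"
    "    pass\n"
)
```

On `SAMPLE`, the strict variant misses `compute_loss` because the method is indented, which is the failure mode the commit message describes.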
Summary
Replaces the `Trainer.__init__` source-string rewrite for `accepts_loss_kwargs` detection with an instance-level shadow applied on the loaded model. This is version-agnostic, never silently no-ops, and unifies the coverage of #4998 (stock forward fallback via inner `.model`) and #5030 (Unsloth-compiled forward via signature inspection).
Also backports the transformers 5.6.0.dev0 `GradientAccumulationPlugin(num_steps=1)` pin to 5.0-5.5 by clamping `accelerator.gradient_accumulation_steps = 1` post-init. No-op on transformers 4.x and 5.6+.
Closes #4982. Supersedes #5030.
Why the previous approach was fragile
`patch_gradient_accumulation_fix` has been rewriting `Trainer.__init__` source via `str.replace` / `re.sub`. I checked every transformers release from 4.47.0 through 5.5.4 and every TRL release from 0.22.2 through 1.1.0:
- transformers 4.48 through 4.52 name the variable `model` (not `unwrapped_model`), so both known patch variants silently no-op on those versions.
- [...] `.model.accepts_loss_kwargs = False` fallback, but misses it under PEFT because `peft_model.model` returns `LoraModel` (which has no attribute), not the inner `Gemma3nModel`.
- [...] `trust_remote_code=True`, `UNSLOTH_COMPILE_DISABLE=1`, or a compile-cache write failure (reproduces on Colab's ephemeral filesystem).
What this PR does
Adds a single helper `apply_accepts_loss_kwargs_fix(model)` in `unsloth/models/_utils.py` and calls it unconditionally from both loaders (`unsloth/models/llama.py` and `unsloth/models/vision.py`) after the model is loaded. The helper:
1. Checks `type(model).forward.__code__.co_filename` against `unsloth_compiled_cache/`. If the compiled forward is in use, shadows `accepts_loss_kwargs = True` on every wrapper level. Unsloth's compiled forward calls `unsloth_fixed_cross_entropy`, which accepts `num_items_in_batch`, so Trainer must not divide again.
2. Otherwise, walks `.base_model` / `.model` looking for the first class that declares `accepts_loss_kwargs` on its own class dict (uses `type(m).__dict__` to avoid PEFT's `__getattr__` forwarding returning false positives), and shadows that value. This catches Gemma3nForConditionalGeneration, PaliGemma, Qwen-VL family, and every other vision wrapper where only the inner `.model` carries the flag.
3. [...] `**kwargs` in `forward`.
The removed source-string `__init__` rewrite is no longer needed. The `training_step` rewrite and the `compute_loss` fix are unchanged.
Scenario matrix (Gemma3nForConditionalGeneration full-FT)
Verification
End-to-end on transformers 4.57.6 + TRL 0.25.1 + accelerate 1.13.0 + H100, `unsloth/gemma-3-270m-it` + LoRA:
Source inspection coverage:
- [...] `Trainer` / `BaseTrainer(Trainer)` / `_BaseTrainer(Trainer)` and never overrides `model_accepts_loss_kwargs`. All RL / preference trainers (GRPO, DPO, KTO, CPO, ORPO, RLOO, BCO, SDFT, async-GRPO) explicitly set `self.model_accepts_loss_kwargs = False` post super-init, so this fix is a no-op for them (no regression surface).
Test plan
- [...] `from unsloth.models._utils import apply_accepts_loss_kwargs_fix`
- [...] `trust_remote_code=True` on Colab to confirm #4982 ([Question] High Gradient Norm and Loss During Initial Training with Gemma-4-E2B) no longer reproduces