Logging cleanup #715

Merged: danielhanchen merged 3 commits into main from logging-cleanup-2 on Jun 10, 2026

@danielhanchen

Member

Logging cleanup

@danielhanchen danielhanchen mentioned this pull request Jun 3, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request disables the KV cache across all configuration objects, including nested sub-configs, when gradient checkpointing is enabled. The review feedback suggests removing the redundant try-except block around the import of PretrainedConfig, as transformers is a required dependency, and simplifying the configuration traversal.


Comment thread unsloth_zoo/training_utils.py Outdated
Comment on lines +228 to +244
    try:
        from transformers import PretrainedConfig
    except Exception:
        PretrainedConfig = None
    _seen = set()
    _stack = [model.config]
    while _stack:
        _cfg = _stack.pop()
        if _cfg is None or id(_cfg) in _seen:
            continue
        _seen.add(id(_cfg))
        if getattr(_cfg, "use_cache", None):
            _cfg.use_cache = False
        if PretrainedConfig is not None:
            for _sub in vars(_cfg).values():
                if isinstance(_sub, PretrainedConfig):
                    _stack.append(_sub)

Contributor

medium

Since transformers is a required dependency of the project, wrapping its import in a try...except block is redundant. We can import PretrainedConfig directly. Additionally, we can simplify the traversal and use hasattr to safely disable use_cache.

Suggested change

    -try:
    -    from transformers import PretrainedConfig
    -except Exception:
    -    PretrainedConfig = None
    -_seen = set()
    -_stack = [model.config]
    -while _stack:
    -    _cfg = _stack.pop()
    -    if _cfg is None or id(_cfg) in _seen:
    -        continue
    -    _seen.add(id(_cfg))
    -    if getattr(_cfg, "use_cache", None):
    -        _cfg.use_cache = False
    -    if PretrainedConfig is not None:
    -        for _sub in vars(_cfg).values():
    -            if isinstance(_sub, PretrainedConfig):
    -                _stack.append(_sub)
    +from transformers import PretrainedConfig
    +_seen = set()
    +_stack = [model.config]
    +while _stack:
    +    _cfg = _stack.pop()
    +    if _cfg is None or id(_cfg) in _seen:
    +        continue
    +    _seen.add(id(_cfg))
    +    if hasattr(_cfg, "use_cache"):
    +        _cfg.use_cache = False
    +    for _sub in vars(_cfg).values():
    +        if isinstance(_sub, PretrainedConfig):
    +            _stack.append(_sub)

References
  1. Avoid using try...except ImportError for libraries that are required dependencies of the project, as the check is redundant.

Member Author

Switched the import to the PreTrainedConfig-first pattern used elsewhere in the repo in 85a07af, which also removes the redundant guard. Kept the truthy getattr check rather than hasattr on purpose: a hasattr check would overwrite use_cache=None with False, and None means "defer to the model default", so flipping it would be a semantic change. This is pinned by tests/test_training_utils_use_cache.py::test_falsy_use_cache_preserved.
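
For readers skimming the thread, the difference between the two checks can be shown with a small sketch. It uses `types.SimpleNamespace` as a stand-in for a config object, and the helper names `disable_truthy` and `disable_hasattr` are hypothetical, made up for this illustration; neither is the code in unsloth_zoo.

```python
from types import SimpleNamespace


def disable_truthy(cfg):
    """Flip use_cache to False only when it is currently truthy."""
    if getattr(cfg, "use_cache", None):
        cfg.use_cache = False
    return cfg


def disable_hasattr(cfg):
    """Flip use_cache to False whenever the attribute exists at all."""
    if hasattr(cfg, "use_cache"):
        cfg.use_cache = False
    return cfg


# Stand-in configs (assumption): True means "cache on", None means "defer to default".
truthy_results = (
    disable_truthy(SimpleNamespace(use_cache=True)).use_cache,
    disable_truthy(SimpleNamespace(use_cache=None)).use_cache,
)
hasattr_results = (
    disable_hasattr(SimpleNamespace(use_cache=True)).use_cache,
    disable_hasattr(SimpleNamespace(use_cache=None)).use_cache,
)

print(truthy_results)   # (False, None)  -> None is preserved
print(hasattr_results)  # (False, False) -> None is overwritten
```

The truthy check leaves `None` alone, which is the behavior the reply says the test pins.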

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 7b23cbf0f9


Comment thread unsloth_zoo/training_utils.py Outdated
    # KV cache is unused under gradient checkpointing; disable it on every config.
    if use_gradient_checkpointing in (True, "unsloth") and getattr(model, "config", None) is not None:
        try:
            from transformers import PretrainedConfig


P2: Fall back to PreTrainedConfig for nested cache disabling

In Transformers 4.57+/5.x environments supported by this repo, the config class may be exported as PreTrainedConfig rather than PretrainedConfig (the existing helpers already use that fallback). With only this legacy import, the except path sets PretrainedConfig = None, so the loop never descends into composite configs such as text_config; those nested decoder configs can keep use_cache=True under gradient checkpointing, leaving the warning/cache behavior this block is meant to suppress for VLM/composite models.


Member Author

Fixed in 85a07af. The block now tries PreTrainedConfig first and falls back to the legacy PretrainedConfig name, matching hf_utils.py, patching_utils.py and empty_model.py, and the None fallback that silently skipped nested traversal is gone. Added tests/test_training_utils_use_cache.py which covers the nested composite config case (Gemma3Config text_config) on CPU.

Align the gradient checkpointing use_cache block with the repo's
PreTrainedConfig-first import convention (hf_utils, patching_utils,
empty_model) and drop the dead None fallback that skipped nested
config traversal. Add CPU tests pinning the contract: top-level and
nested composite configs flip to False under gradient checkpointing,
None/False values are preserved, non-config attachments are ignored,
and self-referencing config graphs terminate.
@danielhanchen

Member Author

Validation summary for this PR (now includes 85a07af: PreTrainedConfig-first import plus CPU tests).

What was checked

  • Without this PR, nothing in unsloth or unsloth_zoo sets config.use_cache = False for training. The existing code only hides the HF warning via logger filters, and transformers itself (modeling_layers.py GradientCheckpointingLayer) pops use_cache from forward kwargs per layer at runtime without fixing configs. So this block is complementary, not redundant.
  • Composite configs are the real target: on unsloth/gemma-3-4b-it, Gemma3Config has no top level use_cache at all, and text_config.use_cache stays True through load, PEFT and training on main. Only the nested traversal here reaches it.
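
The composite case in the second bullet can be sketched with stand-in objects. `StubConfig` and `disable_everywhere` are hypothetical names for this illustration, and the stub is not the real transformers config class; the walk itself follows the loop quoted earlier in this thread.

```python
class StubConfig:
    """Stand-in for a transformers config object (assumption, not the real class)."""

    def __init__(self, **attrs):
        self.__dict__.update(attrs)


def disable_everywhere(root, config_type=StubConfig):
    """Set truthy use_cache to False on root and on every nested config."""
    seen, stack, flipped = set(), [root], []
    while stack:
        cfg = stack.pop()
        if cfg is None or id(cfg) in seen:
            continue  # skip missing configs and break reference cycles
        seen.add(id(cfg))
        if getattr(cfg, "use_cache", None):
            cfg.use_cache = False
            flipped.append(cfg)
        for sub in vars(cfg).values():
            if isinstance(sub, config_type):
                stack.append(sub)
    return flipped


# Composite layout as described above: no top-level use_cache, True on text_config.
text = StubConfig(use_cache=True)
vision = StubConfig(hidden_size=8)
composite = StubConfig(text_config=text, vision_config=vision)
composite.parent = composite  # self reference, to show the walk terminates

flipped = disable_everywhere(composite)
print(text.use_cache)                   # False
print(hasattr(composite, "use_cache"))  # False: nothing invented at the top level
```

A check that only looks at `model.config.use_cache` would never reach `text` in this layout, which is why the nested traversal matters.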

Loss parity, 61 steps, LoRA 4 bit, bsz 2 x ga 3, seed 3407, B200

| Model | Method | Steps/s | Peak mem | Losses [1,2,3,...,60,61] | Grad norms [1,2,3,...,60,61] |
|---|---|---|---|---|---|
| Qwen2.5-0.5B-Instruct | main | 1.276 | 1.206 GB | [1.597, 1.747, 1.854, ..., 1.503, 1.546] | [2.277, 2.697, 3.715, ..., 1.127, 1.148] |
| Qwen2.5-0.5B-Instruct | PR | 1.278 | 1.206 GB | [1.597, 1.747, 1.855, ..., 1.501, 1.545] | [2.278, 2.697, 3.715, ..., 1.118, 1.153] |
| gemma-3-4b-it | main | 0.540 | 5.656 GB | [1.811, 1.959, 2.176, ..., 1.150, 1.094] | [3.668, 3.745, nan, ..., 0.802, 0.757] |
| gemma-3-4b-it | PR | 0.552 | 5.656 GB | identical to main (max abs diff 0.0) | identical to main |

Gemma losses and grad norms are bit identical across all 61 steps; even the nan grad norm log positions match (steps 3, 7, 10, 13, 18, 39 on both, a pre-existing logging artifact in 4 bit gemma, not from this PR). The Qwen deltas (max 0.004) are within same-install run to run noise (two identical main runs already differ by 0.001 at step 3). Config snapshots confirm use_cache flips to False at get_peft_model (kbit path) for Qwen and at from_pretrained (vision path) for gemma, and stays False after training.
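
For reference, a "max abs diff" comparison of this kind can be sketched as below. The helper `max_abs_diff` is hypothetical, and the inputs are only the five values per row that are visible in the table (steps 1, 2, 3, 60, 61), so the result covers those steps alone and is not the full 61-step figure quoted above.

```python
import math


def max_abs_diff(a, b):
    """Largest elementwise |a_i - b_i|; NaNs at matching positions count as equal."""
    if len(a) != len(b):
        raise ValueError("series must have the same length")
    worst = 0.0
    for x, y in zip(a, b):
        if math.isnan(x) and math.isnan(y):
            continue  # same NaN position on both sides
        if math.isnan(x) or math.isnan(y):
            return math.inf  # NaN on one side only is a real mismatch
        worst = max(worst, abs(x - y))
    return worst


# Visible Qwen losses from the table (steps 1, 2, 3, 60, 61 only).
qwen_main = [1.597, 1.747, 1.854, 1.503, 1.546]
qwen_pr = [1.597, 1.747, 1.855, 1.501, 1.545]

# Visible gemma grad norms on main; the PR row is reported as identical.
gemma_main = [3.668, 3.745, math.nan, 0.802, 0.757]
gemma_pr = list(gemma_main)

qwen_delta = max_abs_diff(qwen_main, qwen_pr)
gemma_delta = max_abs_diff(gemma_main, gemma_pr)
print(round(qwen_delta, 3))  # 0.002
print(gemma_delta)           # 0.0
```

Treating NaNs at matching positions as equal is a choice made for this sketch, so that the nan grad norm entries do not mask an otherwise identical series.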

Tests

tests/test_training_utils_use_cache.py (CPU only, runs in 0.04s): top level and nested composite flips for both True and "unsloth", use_cache=None and False preserved, untouched when checkpointing is off, non config attachments ignored, self referencing config graphs terminate.

Known trade-off

The flipped flag persists after training: for_inference does not restore it, and a saved config will carry use_cache=False. In practice generate() reads generation_config.use_cache (still True, untouched here) and HF Trainer has long done the same top level mutation, so only raw model.forward() calls relying on the config default are affected. If that matters we can restore the prior values in for_inference as a follow-up in unsloth.

from types import SimpleNamespace

import pytest
import torch
Factor the gradient checkpointing use_cache walk into disable_use_cache
and remember which configs were flipped on the model the first time, so
restore_use_cache can put the original values back when switching to
inference. Falsy values (None/False) are never recorded, so restore
cannot invent True on configs that never had it. The record survives
restore so for_training can re-disable without re-recording. Both
helpers are exported for unsloth's for_inference / for_training to call;
older unsloth versions simply never call them, keeping behavior
unchanged.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 517249691d


Comment on lines +281 to +282
    if use_gradient_checkpointing in (True, "unsloth"):
        disable_use_cache(model)


P2: Restore cache when checkpointing is disabled

When the same model is first prepared with gradient checkpointing and later passed through prepare_model_for_training(..., use_gradient_checkpointing=False), the false path disables the module checkpointing flags above but this guard skips both disable_use_cache and restore_use_cache, leaving every previously touched config at use_cache=False. That makes subsequent non-checkpointed training or inference/generation on the same model run without KV cache even though checkpointing was explicitly turned off; call restore_use_cache(model) in the non-checkpointing path to undo the mutation recorded by the earlier prepare call.


@danielhanchen

Member Author

Update: the inference trade-off noted above is now resolved instead of documented.

5172496 factors the walk into disable_use_cache / restore_use_cache (both exported). The first disable records the original truthy values on the model; restore_use_cache puts them back, and the record survives so a later disable flips them again without re-recording. None/False values are never recorded, so a restore cannot invent True on configs that never had it. Test coverage grew to 14 CPU cases including the full disable -> restore -> disable cycle, nested composite restore, and double-prepare semantics.

unslothai/unsloth#6137 wires the unsloth side: for_inference restores, for_training re-disables when a record exists, both behind lazy imports so any version pairing of unsloth and unsloth_zoo degrades to current behavior.

GPU round trip on a B200 (zoo at this PR merged with main, unsloth at #6137): Qwen2.5-0.5B kbit path and gemma-3-4b-it vision path both cycle False -> True -> False -> True across for_inference / for_training, generate works after restore, and the nested text_config follows correctly.

    if record and originals:
        try:
            model._unsloth_use_cache_originals = originals
        except Exception:
@danielhanchen danielhanchen merged commit 73ac733 into main Jun 10, 2026
15 checks passed
@danielhanchen danielhanchen deleted the logging-cleanup-2 branch June 10, 2026 12:44
danielhanchen added a commit that referenced this pull request Jun 11, 2026
…gate (#755)

test_mlx_finetune_last_n_layers was born broken in #669 and stayed
invisible until #739 because no CI job executed it: the version matrix
only collects, the macOS MLX job runs the shim smoke test alone, and
the zoo-specific CPU list does not include it. Add a small hard-gate
step in repo-tests-cpu running it together with
test_training_utils_use_cache (the use_cache disable/restore contract
from #715). Both files are CPU-pure and run in under a second, and the
job already installs the deps they need.