Skip to content

Restore config use_cache in for_inference after gradient checkpointing prep#6137

Merged
danielhanchen merged 2 commits into
mainfrom
restore-use-cache-for-inference
Jun 10, 2026
Merged

Restore config use_cache in for_inference after gradient checkpointing prep#6137
danielhanchen merged 2 commits into
mainfrom
restore-use-cache-for-inference

Conversation

@danielhanchen

Copy link
Copy Markdown
Member

Companion to unslothai/unsloth-zoo#715, which sets use_cache=False on the model config and nested sub-configs (VLM text_config etc.) during training preparation when gradient checkpointing is enabled, and records the original values on the model.

This wires the counterpart into unsloth:

  • FastLlamaModel.for_inference and FastBaseModel.for_inference call unsloth_zoo.training_utils.restore_use_cache, putting the recorded original values back so inference sees use_cache=True again (or whatever the model shipped with). Values that were None or False before training are never recorded, so restore cannot invent True on configs that never had it.
  • Both for_training implementations re-disable via disable_use_cache when a record exists and gradient checkpointing is requested, so train -> infer -> train round trips keep the config consistent without re-recording.
  • Both calls import lazily inside the function and swallow ImportError, so new unsloth with old unsloth_zoo (no helpers) and old unsloth with new unsloth_zoo both degrade to current behavior. The MLX no-op shims in init.py are untouched, keeping Apple Silicon isolated.

Verified on a B200 with zoo at unsloth-zoo#715 head merged with main:

  • Qwen2.5-0.5B (kbit path): use_cache False after get_peft_model, True after for_inference, generate works, False again after for_training, True after a second for_inference.
  • gemma-3-4b-it (vision path): nested text_config.use_cache False after load, True after for_inference, False after for_training.
  • Old-zoo guard: deleting the helpers from unsloth_zoo.training_utils and calling for_inference / for_training raises nothing.

The zoo side ships its own CPU tests (tests/test_training_utils_use_cache.py, 14 cases) covering the record/restore cycle, falsy preservation and double-prepare semantics.

…g prep

unsloth_zoo's prepare_model_for_training now sets use_cache=False on the
model config and nested sub-configs when gradient checkpointing is on
(unsloth-zoo PR 715) and records the original values. Wire the
counterpart into both for_inference implementations so the original
values come back for inference, and re-disable in for_training when a
record exists so resumed training keeps the config consistent. Both
calls import lazily and tolerate older unsloth_zoo without the helpers,
so version skew in either direction is a no-op. The MLX shims in
__init__.py are deliberately untouched.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces changes to unsloth/models/llama.py and unsloth/models/vision.py to restore or disable use_cache settings when transitioning between inference and training modes using helper functions from unsloth_zoo.training_utils. The review feedback recommends improving the exception handling around these optional imports. Specifically, instead of wrapping both the import and the function call in a broad try-except ImportError block—which can mask internal errors during execution—the reviewer suggests catching ModuleNotFoundError specifically for the unsloth_zoo module and conditionally calling the functions only if they were successfully imported.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread unsloth/models/llama.py
Comment on lines +3473 to +3477
try:
from unsloth_zoo.training_utils import restore_use_cache
restore_use_cache(model)
except ImportError:
pass

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Swallowing ImportError across both the import statement and the function call can mask legitimate ImportErrors raised inside the restore_use_cache function during its execution. To prevent this, only wrap the import statement in the try-except block, and then conditionally call the function if it was successfully imported. Additionally, when catching an import error for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.

Suggested change
try:
from unsloth_zoo.training_utils import restore_use_cache
restore_use_cache(model)
except ImportError:
pass
try:
from unsloth_zoo.training_utils import restore_use_cache
except ModuleNotFoundError as e:
if e.name == "unsloth_zoo":
restore_use_cache = None
else:
raise
if restore_use_cache is not None:
restore_use_cache(model)
References
  1. When catching an ImportError for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.

Comment thread unsloth/models/llama.py
Comment on lines +3529 to +3533
use_gradient_checkpointing
and getattr(model, "_unsloth_use_cache_originals", None) is not None
):
try:
from unsloth_zoo.training_utils import disable_use_cache

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Swallowing ImportError across both the import statement and the function call can mask legitimate ImportErrors raised inside the disable_use_cache function during its execution. To prevent this, only wrap the import statement in the try-except block, and then conditionally call the function if it was successfully imported. Additionally, when catching an import error for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.

Suggested change
use_gradient_checkpointing
and getattr(model, "_unsloth_use_cache_originals", None) is not None
):
try:
from unsloth_zoo.training_utils import disable_use_cache
try:
from unsloth_zoo.training_utils import disable_use_cache
except ModuleNotFoundError as e:
if e.name == "unsloth_zoo":
disable_use_cache = None
else:
raise
if disable_use_cache is not None:
disable_use_cache(model)
References
  1. When catching an ImportError for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.

Comment thread unsloth/models/vision.py
Comment on lines +1711 to +1715
try:
from unsloth_zoo.training_utils import restore_use_cache
restore_use_cache(model)
except ImportError:
pass

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Swallowing ImportError across both the import statement and the function call can mask legitimate ImportErrors raised inside the restore_use_cache function during its execution. To prevent this, only wrap the import statement in the try-except block, and then conditionally call the function if it was successfully imported. Additionally, when catching an import error for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.

Suggested change
try:
from unsloth_zoo.training_utils import restore_use_cache
restore_use_cache(model)
except ImportError:
pass
try:
from unsloth_zoo.training_utils import restore_use_cache
except ModuleNotFoundError as e:
if e.name == "unsloth_zoo":
restore_use_cache = None
else:
raise
if restore_use_cache is not None:
restore_use_cache(model)
References
  1. When catching an ImportError for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.

Comment thread unsloth/models/vision.py
Comment on lines +1778 to +1782
use_gradient_checkpointing
and getattr(model, "_unsloth_use_cache_originals", None) is not None
):
try:
from unsloth_zoo.training_utils import disable_use_cache

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Swallowing ImportError across both the import statement and the function call can mask legitimate ImportErrors raised inside the disable_use_cache function during its execution. To prevent this, only wrap the import statement in the try-except block, and then conditionally call the function if it was successfully imported. Additionally, when catching an import error for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.

Suggested change
use_gradient_checkpointing
and getattr(model, "_unsloth_use_cache_originals", None) is not None
):
try:
from unsloth_zoo.training_utils import disable_use_cache
try:
from unsloth_zoo.training_utils import disable_use_cache
except ModuleNotFoundError as e:
if e.name == "unsloth_zoo":
disable_use_cache = None
else:
raise
if disable_use_cache is not None:
disable_use_cache(model)
References
  1. When catching an ImportError for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.

@danielhanchen danielhanchen merged commit 656a087 into main Jun 10, 2026
42 of 44 checks passed
@danielhanchen danielhanchen deleted the restore-use-cache-for-inference branch June 10, 2026 12:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant