Restore config use_cache in for_inference after gradient checkpointing prep#6137
Conversation
…g prep unsloth_zoo's prepare_model_for_training now sets use_cache=False on the model config and nested sub-configs when gradient checkpointing is on (unsloth-zoo PR 715) and records the original values. Wire the counterpart into both for_inference implementations so the original values come back for inference, and re-disable in for_training when a record exists so resumed training keeps the config consistent. Both calls import lazily and tolerate older unsloth_zoo without the helpers, so version skew in either direction is a no-op. The MLX shims in __init__.py are deliberately untouched.
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Code Review
This pull request introduces changes to unsloth/models/llama.py and unsloth/models/vision.py to restore or disable use_cache settings when transitioning between inference and training modes using helper functions from unsloth_zoo.training_utils. The review feedback recommends improving the exception handling around these optional imports. Specifically, instead of wrapping both the import and the function call in a broad try-except ImportError block—which can mask internal errors during execution—the reviewer suggests catching ModuleNotFoundError specifically for the unsloth_zoo module and conditionally calling the functions only if they were successfully imported.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| try: | ||
| from unsloth_zoo.training_utils import restore_use_cache | ||
| restore_use_cache(model) | ||
| except ImportError: | ||
| pass |
There was a problem hiding this comment.
Swallowing ImportError across both the import statement and the function call can mask legitimate ImportErrors raised inside the restore_use_cache function during its execution. To prevent this, only wrap the import statement in the try-except block, and then conditionally call the function if it was successfully imported. Additionally, when catching an import error for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.
| try: | |
| from unsloth_zoo.training_utils import restore_use_cache | |
| restore_use_cache(model) | |
| except ImportError: | |
| pass | |
| try: | |
| from unsloth_zoo.training_utils import restore_use_cache | |
| except ModuleNotFoundError as e: | |
| if e.name == "unsloth_zoo": | |
| restore_use_cache = None | |
| else: | |
| raise | |
| if restore_use_cache is not None: | |
| restore_use_cache(model) |
References
- When catching an ImportError for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.
| use_gradient_checkpointing | ||
| and getattr(model, "_unsloth_use_cache_originals", None) is not None | ||
| ): | ||
| try: | ||
| from unsloth_zoo.training_utils import disable_use_cache |
There was a problem hiding this comment.
Swallowing ImportError across both the import statement and the function call can mask legitimate ImportErrors raised inside the disable_use_cache function during its execution. To prevent this, only wrap the import statement in the try-except block, and then conditionally call the function if it was successfully imported. Additionally, when catching an import error for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.
| use_gradient_checkpointing | |
| and getattr(model, "_unsloth_use_cache_originals", None) is not None | |
| ): | |
| try: | |
| from unsloth_zoo.training_utils import disable_use_cache | |
| try: | |
| from unsloth_zoo.training_utils import disable_use_cache | |
| except ModuleNotFoundError as e: | |
| if e.name == "unsloth_zoo": | |
| disable_use_cache = None | |
| else: | |
| raise | |
| if disable_use_cache is not None: | |
| disable_use_cache(model) |
References
- When catching an ImportError for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.
| try: | ||
| from unsloth_zoo.training_utils import restore_use_cache | ||
| restore_use_cache(model) | ||
| except ImportError: | ||
| pass |
There was a problem hiding this comment.
Swallowing ImportError across both the import statement and the function call can mask legitimate ImportErrors raised inside the restore_use_cache function during its execution. To prevent this, only wrap the import statement in the try-except block, and then conditionally call the function if it was successfully imported. Additionally, when catching an import error for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.
| try: | |
| from unsloth_zoo.training_utils import restore_use_cache | |
| restore_use_cache(model) | |
| except ImportError: | |
| pass | |
| try: | |
| from unsloth_zoo.training_utils import restore_use_cache | |
| except ModuleNotFoundError as e: | |
| if e.name == "unsloth_zoo": | |
| restore_use_cache = None | |
| else: | |
| raise | |
| if restore_use_cache is not None: | |
| restore_use_cache(model) |
References
- When catching an ImportError for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.
| use_gradient_checkpointing | ||
| and getattr(model, "_unsloth_use_cache_originals", None) is not None | ||
| ): | ||
| try: | ||
| from unsloth_zoo.training_utils import disable_use_cache |
There was a problem hiding this comment.
Swallowing ImportError across both the import statement and the function call can mask legitimate ImportErrors raised inside the disable_use_cache function during its execution. To prevent this, only wrap the import statement in the try-except block, and then conditionally call the function if it was successfully imported. Additionally, when catching an import error for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.
| use_gradient_checkpointing | |
| and getattr(model, "_unsloth_use_cache_originals", None) is not None | |
| ): | |
| try: | |
| from unsloth_zoo.training_utils import disable_use_cache | |
| try: | |
| from unsloth_zoo.training_utils import disable_use_cache | |
| except ModuleNotFoundError as e: | |
| if e.name == "unsloth_zoo": | |
| disable_use_cache = None | |
| else: | |
| raise | |
| if disable_use_cache is not None: | |
| disable_use_cache(model) |
References
- When catching an ImportError for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.
Companion to unslothai/unsloth-zoo#715, which sets use_cache=False on the model config and nested sub-configs (VLM text_config etc.) during training preparation when gradient checkpointing is enabled, and records the original values on the model.
This wires the counterpart into unsloth:
Verified on a B200 with zoo at unsloth-zoo#715 head merged with main:
The zoo side ships its own CPU tests (tests/test_training_utils_use_cache.py, 14 cases) covering the record/restore cycle, falsy preservation and double-prepare semantics.