fix(mlx): expose finetune_last_n_layers for parity with mlx-lm CLI#669
Conversation
mlx-lm's lora CLI defaults CONFIG_DEFAULTS['num_layers']=16 (mlx_lm/lora.py:56) which trains LoRA only on the last 16 transformer blocks. unsloth-zoo's FastMLXModel.get_peft_model applies LoRA to ALL transformer layers (matching HF/PEFT/CUDA semantics on the GPU path). On small models the difference shows up as a basin-selection divergence -- the extra LoRA modules consume mx.random state during init AND change the trainable-parameter set, so two otherwise-identical runs land in different basins of attraction. Empirical, n=15 seeds, gemma-3-270m-it single-row LoRA memorization fixture: mlx-lm CLI's last-16 hits 67%, training all 18 layers hits 47%. The teacher-forced completion loss is 0 in both, so memorization succeeds either way -- the gap is purely on greedy-decode first-token argmax. This commit adds an opt-in `finetune_last_n_layers` parameter (default None = all layers, current behavior unchanged). Pass `finetune_last_n_layers=16` to mirror mlx-lm CLI exactly. Wired into both the VLM and text-only code paths in get_peft_model. The bound is clamped to [1, len(model.model.layers)] so callers can't accidentally request more layers than the model has, or zero layers (which would freeze everything).
There was a problem hiding this comment.
Code Review
This pull request introduces the finetune_last_n_layers parameter to FastMLXModel.get_peft_model, enabling users to restrict LoRA application to the last N transformer blocks. This change aligns the library's behavior with mlx-lm CLI defaults while maintaining the existing all-layers default for backward compatibility. Accompanying tests verify the parameter's functionality and edge-case handling. The review feedback highlights a potential issue where the requested layer count might be ignored if the model's total layer count isn't detected; suggestions were provided to ensure the user's intent is honored in such cases.
| if finetune_last_n_layers is not None and num_layers > 0: | ||
| num_layers = max(1, min(int(finetune_last_n_layers), num_layers)) |
There was a problem hiding this comment.
The current logic skips updating num_layers if the total layer count detection fails (num_layers == 0). In such cases, num_layers remains 0, which mlx_lm.tuner.utils.linear_to_lora_layers interprets as applying LoRA to all layers. If a user explicitly requested a specific number of layers via finetune_last_n_layers, falling back to all layers is likely unexpected.
It is better to honor the user's request even if the total count is unknown, as mlx-lm's internal slicing (layers[-num_layers:]) is safe in Python even if the requested number exceeds the actual list length.
| if finetune_last_n_layers is not None and num_layers > 0: | |
| num_layers = max(1, min(int(finetune_last_n_layers), num_layers)) | |
| if finetune_last_n_layers is not None: | |
| requested = int(finetune_last_n_layers) | |
| num_layers = max(1, min(requested, num_layers) if num_layers > 0 else requested) |
| if finetune_last_n_layers is not None and num_layers > 0: | ||
| num_layers = max(1, min(int(finetune_last_n_layers), num_layers)) |
There was a problem hiding this comment.
Similar to the VLM path, the num_layers > 0 check prevents finetune_last_n_layers from being applied if the model's layer count wasn't successfully detected. This results in a fallback to 'all layers' (since num_layers remains 0), which contradicts the user's intent to limit the fine-tuning scope.
Updating num_layers to the requested value regardless of detection success ensures that mlx-lm attempts to slice the layers as requested.
| if finetune_last_n_layers is not None and num_layers > 0: | |
| num_layers = max(1, min(int(finetune_last_n_layers), num_layers)) | |
| if finetune_last_n_layers is not None: | |
| requested = int(finetune_last_n_layers) | |
| num_layers = max(1, min(requested, num_layers) if num_layers > 0 else requested) |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b137b4058e
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| if finetune_last_n_layers is not None and num_layers > 0: | ||
| num_layers = max(1, min(int(finetune_last_n_layers), num_layers)) |
There was a problem hiding this comment.
Persist the selected MLX LoRA layer count
When finetune_last_n_layers is used, the clamped value is only kept in this local num_layers. The trainer still saves adapter_config.json with "num_layers": len(_get_transformer_layers(self.model)) in unsloth_zoo/mlx/trainer.py:1391-1396, so an adapter trained with e.g. the last 16 layers of an 18-layer model is advertised as covering all 18 layers to mlx-lm's load_adapters. In mlx-lm-compatible reload/resume paths this recreates extra active/trainable LoRA modules that were not part of the training run, so the saved adapter no longer faithfully represents the selected layer scope; please record the selected count on the model or infer it when saving.
Useful? React with 👍 / 👎.
| # mlx_lm.tuner.utils is imported inside the function: | ||
| fake_mod = type(sys)("mlx_lm.tuner.utils") | ||
| fake_mod.linear_to_lora_layers = fake_linear_to_lora_layers | ||
| sys.modules["mlx_lm.tuner.utils"] = fake_mod |
There was a problem hiding this comment.
Restore the fake mlx-lm tuner module after the test
This assigns a synthetic mlx_lm.tuner.utils module directly into sys.modules and never restores the original stub/module. In any pytest run where later tests or code under test import other tuner utilities such as load_adapters, they will receive this fake module that only defines linear_to_lora_layers, making the outcome order-dependent. Use pytest's monkeypatch.setitem(sys.modules, ...) or save and restore the previous value.
Useful? React with 👍 / 👎.
| what we assert, not the side effects on a real architecture). | ||
| """ | ||
| import sys | ||
| import unsloth_zoo.mlx.loader as loader_mod |
|
|
||
| # Stub out the helpers get_peft_model uses internally so the test | ||
| # doesn't need to walk a real model tree. | ||
| import unsloth_zoo.mlx.loader as L |
Empirical bisection (gemma-3-270m-it, single-row LoRA memorization, n=15 seeds)Filing the latest probe numbers here so reviewers can see the full picture this PR addresses vs what it does not.
The
Why this PR is still the right first step: the layer-selection mismatch is the only one of the four that is also a user-facing semantics difference, not just a numerical/perf overhead. mlx-lm CLI users who switch to cf_loss safety net: every config above hits teacher-forced completion loss == 0 in 15/15 seeds. The model memorizes either way; only the first-token greedy argmax distribution differs. CI smoke gating on Will file separate issues / PRs for (2), (3), (4). |
Final empirical summary (Round BO, 75 cells)Five companion PRs landed in this MLX-vs-mlx-lm-CLI parity series:
Across Rounds BG-BO on What this PR specifically addresses: the most user-visible parity surface, where calling |
Per code-comment policy: parameter name is self-documenting and the clamp is obvious from max(1, min(...)). Rationale lives in the commit message of b137b40 and the PR description.
danielhanchen
left a comment
There was a problem hiding this comment.
Thank you for the PR! The goal of this PR is to give MLX callers a one-knob way to match mlx-lm CLI's CONFIG_DEFAULTS['num_layers']=16 semantics (LoRA on the last N transformer blocks). As a summary, this PR adds an optional finetune_last_n_layers keyword to FastMLXModel.get_peft_model and, when set, clamps it via max(1, min(int(N), total)) and passes it as num_layers to linear_to_lora_layers on both the VLM language path and the text path. Default None preserves the current all-layers behavior.
Two independent Opus reviewers were run in parallel on this PR.
| Reviewers | Severity | Finding |
|---|---|---|
| 2/2 | Med | get_peft_model docstring is not updated, so the new parameter is invisible to help(...) and inspect-style discovery. |
| 2/2 | Med | int(finetune_last_n_layers) silently accepts True/False, floats (1.7 -> 1), and numeric strings, masking common user typos. |
| 2/2 | Med | Tests skip the VLM branch entirely — only the text-only call site (line 2905) is exercised. |
| 1/2 | Med | When the model lacks .model.layers (so num_layers == 0), finetune_last_n_layers is silently dropped with no warning. |
| 2/2 | Nit | The clamp logic is duplicated verbatim between the VLM and text branches; extract a _resolve_num_layers(num_layers, finetune_last_n_layers) helper. |
| 2/2 | Nit | The test monkeypatches sys.modules['mlx_lm.tuner.utils'] (and L._fix_missing_no_grad etc.) without a teardown — leaks into later tests if execution order shifts. Prefer pytest's monkeypatch fixture. |
Overall: APPROVE_WITH_NITS.
See inline comments for details and suggested fixes.
| finetune_language_layers=True, | ||
| finetune_attention_modules=True, | ||
| finetune_mlp_modules=True, | ||
| finetune_last_n_layers=None, |
There was a problem hiding this comment.
[2/2 reviewers] Med: the new parameter is added to the signature but not documented anywhere users can find it. The get_peft_model docstring just above this hunk does not mention finetune_last_n_layers, its meaning ("last N transformer blocks"), the clamp range, or the mlx-lm CLI parity intent.
| finetune_last_n_layers=None, | |
| finetune_last_n_layers=None, # mlx-lm CLI parity: LoRA on last N transformer blocks; default None = all layers |
| if hasattr(lm, "model") and hasattr(lm.model, "layers"): | ||
| num_layers = len(lm.model.layers) | ||
| if finetune_last_n_layers is not None and num_layers > 0: | ||
| num_layers = max(1, min(int(finetune_last_n_layers), num_layers)) |
There was a problem hiding this comment.
[2/2 reviewers] Med: int(finetune_last_n_layers) happily accepts True (→1), False (→0), floats (1.7 → 1, truncating silently), and numeric strings — all common typos. A True literal silently becomes "last 1 layer" instead of raising. Add an explicit type/range guard with a clear message; consider extracting the whole clamp into a helper since the same lines appear at line 2906 too.
| num_layers = max(1, min(int(finetune_last_n_layers), num_layers)) | |
| if finetune_last_n_layers is not None and num_layers > 0: | |
| if not isinstance(finetune_last_n_layers, int) or isinstance(finetune_last_n_layers, bool): | |
| raise TypeError( | |
| f"finetune_last_n_layers must be an int, got {type(finetune_last_n_layers).__name__}" | |
| ) | |
| num_layers = max(1, min(finetune_last_n_layers, num_layers)) |
| num_layers = 0 | ||
| if hasattr(lm, "model") and hasattr(lm.model, "layers"): | ||
| num_layers = len(lm.model.layers) | ||
| if finetune_last_n_layers is not None and num_layers > 0: |
There was a problem hiding this comment.
[1/2 reviewers] Med: when hasattr(lm, 'model') is False or lm.model.layers is empty, num_layers stays at 0 and the num_layers > 0 guard silently drops the user's finetune_last_n_layers request. linear_to_lora_layers is then called with num_layers=0, which is a no-op. Surface this rather than silently ignoring the setting.
| if finetune_last_n_layers is not None and num_layers > 0: | |
| if finetune_last_n_layers is not None: | |
| if num_layers > 0: | |
| num_layers = max(1, min(int(finetune_last_n_layers), num_layers)) | |
| else: | |
| import warnings | |
| warnings.warn( | |
| "Unsloth: finetune_last_n_layers requested but the model does not expose .model.layers; ignoring.", | |
| stacklevel=2, | |
| ) |
| if hasattr(model, "model") and hasattr(model.model, "layers"): | ||
| num_layers = len(model.model.layers) | ||
| if finetune_last_n_layers is not None and num_layers > 0: | ||
| num_layers = max(1, min(int(finetune_last_n_layers), num_layers)) |
There was a problem hiding this comment.
[2/2 reviewers] Nit: this 2-line clamp is identical to the VLM branch above (line 2823-2824). Extract a tiny helper so the two paths cannot drift:
| num_layers = max(1, min(int(finetune_last_n_layers), num_layers)) | |
| num_layers = _resolve_finetune_last_n_layers(num_layers, finetune_last_n_layers) |
(with a free function defined once near the top of FastMLXModel:
def _resolve_finetune_last_n_layers(num_layers, n):
if n is None or num_layers <= 0:
return num_layers
return max(1, min(int(n), num_layers))
```)| class FakeLayer: pass | ||
| class FakeInner: | ||
| layers = [FakeLayer() for _ in range(8)] | ||
| class FakeModel: |
There was a problem hiding this comment.
[2/2 reviewers] Med: the FakeModel here sets _is_vlm_model = False, so only the text-only call site (loader.py:2905) is exercised. The new clamp on line 2823 (VLM language path) has zero coverage even though the PR description says "wired into both VLM and text-only code paths". Add a second fixture that flips _is_vlm_model = True with a stubbed language_model so both call sites are pinned.
| class FakeModel: | |
| class FakeModel: | |
| model = FakeInner() | |
| _unsloth_full_finetuning = False | |
| _is_vlm_model = False | |
| def freeze(self): pass | |
| def unfreeze(self, **kwargs): pass | |
| def trainable_parameters(self): return {} | |
| def parameters(self): return {} | |
| class FakeVLMInner: | |
| layers = [FakeLayer() for _ in range(8)] | |
| class FakeVLM: | |
| language_model = type("LM", (), {"model": FakeVLMInner()}) | |
| _unsloth_full_finetuning = False | |
| _is_vlm_model = True | |
| def freeze(self): pass | |
| def unfreeze(self, **kwargs): pass | |
| def trainable_parameters(self): return {} | |
| def parameters(self): return {} |
| # Case 1: default (None) -> all 8 layers | ||
| captured["calls"].clear() | ||
| loader_mod.FastMLXModel.get_peft_model( | ||
| FakeModel(), |
There was a problem hiding this comment.
[2/2 reviewers] Nit: this monkeypatches sys.modules['mlx_lm.tuner.utils'] and four L.* attributes with no teardown. If pytest collection order changes, downstream tests in the same session will see the stubs instead of the real module. Use the monkeypatch fixture so changes are reverted automatically.
| FakeModel(), | |
| def test_get_peft_model_passes_finetune_last_n_layers_through(monkeypatch): | |
| import sys | |
| import unsloth_zoo.mlx.loader as loader_mod | |
| # ... build FakeModel ... | |
| captured = {"calls": []} | |
| def fake_linear_to_lora_layers(model, num_layers, config, use_dora=False): | |
| captured["calls"].append(num_layers) | |
| monkeypatch.setattr(loader_mod, "_fix_missing_no_grad", lambda m: None) | |
| monkeypatch.setattr(loader_mod, "_resolve_lora_keys", lambda m, t: [ | |
| "model.layers.0.self_attn.q_proj", | |
| "model.layers.1.mlp.gate_proj", | |
| ]) | |
| monkeypatch.setattr(loader_mod, "_apply_mlx_lora_initialization", lambda m, init: None) | |
| monkeypatch.setattr(loader_mod, "linear_to_lora_layers", fake_linear_to_lora_layers) | |
| fake_mod = type(sys)("mlx_lm.tuner.utils") | |
| fake_mod.linear_to_lora_layers = fake_linear_to_lora_layers | |
| monkeypatch.setitem(sys.modules, "mlx_lm.tuner.utils", fake_mod) |
…ai#678) Re-merge of PR unslothai#674 — the original was accidentally merged into the stale fix-mlx-num-layers-parity branch (after unslothai#669 had already squashed into main), leaving this fix stranded. FastMLXModel.get_peft_model previously called `_seed_mlx_random_state(random_state)` near the top of the method, ~100+ source lines above the actual `linear_to_lora_layers` call. In between sit target-module normalization, `_fix_missing_no_grad`, `_resolve_lora_keys`, and (on the VLM branch) the model-tree walk. Empirically this leaves a window in which lazy MLX state mutations or implicit `mx.random` consumption can slip in, so the lora_a matrices initialized inside `linear_to_lora_layers` end up DIFFERENT from mlx-lm CLI's, which seeds at `mlx_lm/tuner/lora.py` (def train) immediately before `linear_to_lora_layers`. Verified by probe 39 on danielhanchen/unsloth-staging-2: seed= 1, 42, 999, 3407, 22222: max |dloss| = max |dgrad_norm| = 0.0 across all 30 steps × 5 seeds (vs non-zero deltas before). This fix moves `_seed_mlx_random_state(random_state)` to immediately before each `linear_to_lora_layers(...)` call -- both VLM language path and text path. API surface unchanged. Test `tests/test_mlx_get_peft_model_seed_ordering.py` pins: 1. Every linear_to_lora_layers call inside get_peft_model is preceded by `_seed_mlx_random_state` within 20 lines. 2. The `random_state` API parameter still exists with default 3407. 3. The tight pairing matches BOTH the VLM and text LoRA call sites (regex tripwire that only allows comment lines between).
…nslothai#739) test_get_peft_model_passes_finetune_last_n_layers_through has failed since it was introduced in unslothai#669: the trainable parameter summary that get_peft_model prints (added in unslothai#634) calls model.trainable_parameters() and model.parameters(), which the synthetic FakeModel never stubbed. CI never executed the test body (collect-only plus exclusion list), so the failure stayed hidden. Give the fixture the two methods returning empty trees, matching the fixtures in test_mlx_save_lora_adapters_filter, so the summary computes 0 of 0 params and the num_layers assertions are exercised as intended.
…gate (#755) test_mlx_finetune_last_n_layers was born broken in #669 and stayed invisible until #739 because no CI job executed it: the version matrix only collects, the macOS MLX job runs the shim smoke test alone, and the zoo-specific CPU list does not include it. Add a small hard-gate step in repo-tests-cpu running it together with test_training_utils_use_cache (the use_cache disable/restore contract from #715). Both files are CPU-pure and run in under a second, and the job already installs the deps they need.
Summary
finetune_last_n_layersparameter toFastMLXModel.get_peft_model(defaultNone= all layers, current behavior unchanged).CONFIG_DEFAULTS['num_layers']=16semantics atmlx_lm/lora.py:56).unsloth/unslothexposes the same knob on the CUDA path so a single config value controls layer-selection across CUDA / MLX / mlx-lm CLI.Why
mlx-lm CLI defaults
num_layers=16-> LoRA on the LAST 16 transformer blocks. unsloth-zoo'sget_peft_modelhistorically applied LoRA to ALL transformer layers (matching HF PEFT/CUDA semantics).On small models the difference can show up as a basin-selection divergence: the extra LoRA modules consume
mx.randomstate during init and change the trainable-parameter set, so two otherwise-identical runs land in different basins of attraction.Empirical (n=15 seeds, gemma-3-270m-it single-row LoRA memorization fixture):
0in both — the model memorizes either way; only the first-token argmax distribution differs.This PR keeps the default behavior unchanged (None = all layers) so existing users see no change. Passing
finetune_last_n_layers=16puts the run in the same basin family as mlx-lm CLI for direct comparisons.The value is clamped to
[1, len(model.model.layers)]so callers can't accidentally request more layers than the model has, or zero layers (which would freeze everything).Test plan
tests/test_mlx_finetune_last_n_layers.pycovering:None-> num_layers = totaldanielhanchen/unsloth-staging-2MLX parity probe matrix — probe 31 with num_layers=16 hits 10/15 = 67% matching mlx-lm CLI per-seed).