fix(mlx): seed mx.random immediately before linear_to_lora_layers #674
FastMLXModel.get_peft_model previously called
`_seed_mlx_random_state(random_state)` at the top of the method,
~165 source lines above the actual `linear_to_lora_layers` call.
In between sit target-module normalization, `_fix_missing_no_grad`,
`_resolve_lora_keys`, and (on the VLM branch) `_fix_gemma4_kv_sharing`.
Empirically this leaves a window in which lazy MLX state mutations
or implicit `mx.random` consumption can slip in, so the lora_a
matrices initialized inside `linear_to_lora_layers` end up
DIFFERENT from what `mlx-lm`'s CLI produces with
`mx.random.seed(args.seed)` immediately before
`linear_to_lora_layers` at `mlx_lm/lora.py:223`.
Round BQ probe 39 on danielhanchen/unsloth-staging-2 measured
the divergence on 5 paired seeds (gemma-3-270m-it,
finetune_last_n_layers=16, same `mlx_lm.load` model load for
both paths, same manual @mx.compile training loop):
seed= 1: max |dloss|=1.019231 max |dgrad_norm|=133.93
seed= 42: max |dloss|=1.432692 max |dgrad_norm|= 60.92
seed= 999: max |dloss|=0.168269 max |dgrad_norm|= 3.23
seed= 3407: max |dloss|=0.110577 max |dgrad_norm|= 2.35
seed=22222: max |dloss|=0.105769 max |dgrad_norm|= 5.83
Step-1 forward loss is identical (base weights load the same), so
the divergence is exclusively in the lora_a init step.
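A toy sketch of why an identical step-1 loss is still compatible with diverging gradients. It is not the probe code: the layout `y = x @ W + scale * (x @ lora_a) @ lora_b`, the squared-error loss, the tiny shapes, and the use of Python's `random` in place of `mx.random` are all assumptions made for illustration.

```python
import random


def matmul(a, b):
    """Plain-Python matrix product for small nested lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(m):
    return [list(row) for row in zip(*m)]


def lora_step1(seed, scale=2.0):
    """Return (loss, grad_norm of lora_b) for one step with lora_b = 0."""
    x = [[1.0, 2.0, 3.0]]                           # one input row, in_dim = 3
    w = [[0.5, -0.5], [0.25, 0.0], [0.0, 1.0]]      # frozen base weight, 3 x 2
    target = [[0.0, 0.0]]

    rng = random.Random(seed)                       # stand-in for the seeded init
    lora_a = [[rng.uniform(-1.0, 1.0) for _ in range(2)] for _ in range(3)]
    lora_b = [[0.0, 0.0], [0.0, 0.0]]               # zero-initialized, rank 2

    xa = matmul(x, lora_a)
    base = matmul(x, w)
    delta = matmul(xa, lora_b)                      # all zeros while lora_b = 0
    y = [[base[0][j] + scale * delta[0][j] for j in range(2)]]

    loss = 0.5 * sum((y[0][j] - target[0][j]) ** 2 for j in range(2))
    dy = [[y[0][j] - target[0][j] for j in range(2)]]

    # dL/dlora_b = scale * (x @ lora_a)^T @ dL/dy, so it depends on lora_a.
    grad_b = [[scale * v for v in row] for row in matmul(transpose(xa), dy)]
    grad_norm = sum(v * v for row in grad_b for v in row) ** 0.5
    return loss, grad_norm


if __name__ == "__main__":
    for seed in (1, 42):
        print(seed, lora_step1(seed))
```

With `lora_b = 0` the LoRA branch contributes nothing to the forward pass, so the loss is the same for every `lora_a`, while the gradient reaching `lora_b` scales with `x @ lora_a` and therefore changes with the init.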
This fix moves `_seed_mlx_random_state(random_state)` to
immediately before each `linear_to_lora_layers(...)` call -- on
both the VLM language path and the text path. Default API surface
and behavior unchanged for callers; the seed move is internal.
Tests:
- tests/test_mlx_get_peft_model_seed_ordering.py pins:
1. Every linear_to_lora_layers call inside get_peft_model is
preceded by `_seed_mlx_random_state` within 20 lines.
2. The `random_state` API parameter still exists with default 3407.
3. The seed-immediately-precedes pattern matches BOTH the VLM and
text LoRA call sites.
- tests/test_mlx_finetune_last_n_layers.py FakeModel gains
`trainable_parameters()` / `parameters()` stubs so the existing
param-count log line at the tail of get_peft_model doesn't crash
during the unit test (pre-existing failure, fixed here).
Round BQ rerun on Apple Silicon with this branch pinned will
confirm dloss = 0 step-for-step against probe 31 (mlx-lm CLI
style path).
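The ordering effect described in this message can be illustrated with a minimal sketch. It uses Python's stdlib `random` as a stand-in for the global `mx.random` state, and `init_lora_a` and `setup_with_hidden_draw` are hypothetical helpers. Whether the real setup code actually consumes MLX random state is what the probe measures; this sketch only shows why ordering would matter if it does.

```python
import random


def init_lora_a(rows, cols):
    """Stand-in for a LoRA init that samples from the global RNG."""
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def setup_with_hidden_draw():
    """Hypothetical intermediate work that consumes one global random value."""
    random.random()


def early_seed(seed):
    random.seed(seed)            # seed at the top of the method
    setup_with_hidden_draw()     # global RNG state advances here
    return init_lora_a(2, 3)


def late_seed(seed):
    setup_with_hidden_draw()
    random.seed(seed)            # seed immediately before the init
    return init_lora_a(2, 3)


def reference(seed):
    """CLI-style ordering: seed, then init, with nothing in between."""
    random.seed(seed)
    return init_lora_a(2, 3)


if __name__ == "__main__":
    print("late matches reference: ", late_seed(3407) == reference(3407))
    print("early matches reference:", early_seed(3407) == reference(3407))
```

Seeding last makes the init independent of whatever ran before it, which is the property the fix relies on.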
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 0124424888
    # Reseed mx.random immediately before LoRA construction to
    # match mlx-lm CLI's lora.py:223 ordering exactly. See
    # rationale on the comment block at the top of this method.
    _seed_mlx_random_state(random_state)
Preserve seeding for vision-only VLM LoRA
Because the only remaining _seed_mlx_random_state(random_state) in the VLM path is inside the language-LoRA branch, calls such as get_peft_model(vlm, finetune_language_layers=False, train_vision=True) or projector-only training now reach _lora_walk_module, whose LoRALinear.from_base(...) initializes LoRA weights, without ever applying the caller's random_state. The previous top-level seed covered these supported VLM modes, so their adapter initialization becomes dependent on the ambient MLX RNG state instead of the requested seed.
Round BR verification on Apple Silicon
Re-ran the parity probe matrix on
Probe 39 (strict per-step diagnostic)
Bit-identical losses AND gradient norms across all 30 steps x 5 seeds. Before this PR, lora_a values diverged from step 1 because the
Side observation (not addressed by this PR)
Probes 34 / 36 (
danielhanchen left a comment
Thank you for the PR! The goal of this PR is to make FastMLXModel.get_peft_model's lora_a init match mlx_lm/tuner/lora.py:223 value-for-value by seeding mx.random immediately above each linear_to_lora_layers call, eliminating a window in which MLX's lazy evaluation can advance the global state between the previous early seed call and the actual LoRA construction. As a summary, this PR deletes the seed call at the top of get_peft_model and inserts _seed_mlx_random_state(random_state) immediately above both the VLM-language and text-only linear_to_lora_layers invocations. Verified empirically (probe 39 on danielhanchen/unsloth-staging-2) to produce dloss = 0 step-for-step against the mlx-lm CLI path.
Two independent Opus reviewers were run in parallel on this PR.
| Reviewers | Severity | Finding |
|---|---|---|
| 2/2 | Med | The fix's parity claim only holds for the language-LoRA via linear_to_lora_layers. Vision-tower / multimodal-projector LoRA inits (via _lora_walk_module -> LoRALinear.from_base) run after linear_to_lora_layers consumes RNG and are NOT seeded immediately, so VLM users with train_vision=True still have non-deterministic-relative-to-mlx-lm-CLI vision lora_a values. mlx-lm CLI itself does not LoRA vision, so "parity" is technically undefined there, but the scope should be documented. |
| 1/2 | Med | Non-default init_lora_weights="gaussian" / False paths run _apply_mlx_lora_initialization after linear_to_lora_layers with no re-seed, so those modes inherit whatever state the LoRA construction left behind. Worth a TODO or a comment. |
| 1/2 | Med | Stacked-PR merge order matters: this branch contains #669's commit. If #674 merges before #669, the deletion of the top-of-method seed lands without the finetune_last_n_layers param it's stacked on. Confirm merge order #669 → #674. |
| 1/2 | Med | Fix depends on linear_to_lora_layers consuming mx.random.uniform. If mlx-lm migrates to explicit-state RNG (mx.random.key(...)), the global seed becomes a no-op silently. Suggest a numerical canary test that pins mx.random.uniform(shape=(2,)) after _seed_mlx_random_state(3407). |
| 2/2 | Low | The seed-ordering tripwire test uses regex pinning on _seed_mlx_random_state(random_state) literal — a benign refactor (seed=random_state kwarg form, helper wrapper) breaks it without an actual ordering regression. |
| 2/2 | Nit | mlx_lm/lora.py:223 line-number reference will go stale fast; cite function name (def train) instead. |
Overall: APPROVE_WITH_NITS. The seed-ordering fix is correctly placed for the explicit scope of the PR (language LoRA via linear_to_lora_layers); the [Med] findings are pre-existing scope-of-fix limits that this PR doesn't claim to address but should document.
See inline comments for details and suggested fixes.
    # Match mlx_lm/lora.py:223 -- seed mx.random immediately
    # before LoRA init; lazy MLX state advances otherwise
    # leak into lora_a sampling.
    _seed_mlx_random_state(random_state)
[2/2 reviewers] Med — scope-of-fix documentation. This tight seed→LoRA pairing only covers the language-LoRA via linear_to_lora_layers. Further down (VLM branch), _lora_walk_module is called for vision_tower, vision_model, vision_encoder and projector/connector — each of those constructs LoRALinear.from_base which consumes mx.random.uniform for its own lora_a init with NO re-seed. So a user training a VLM with train_vision=True and random_state=3407 still ends up with non-deterministic-vs-mlx-lm-CLI vision lora_a values. mlx-lm CLI doesn't LoRA vision, so "parity" is undefined — but the comment should make the scope explicit:
Suggested change:
    - _seed_mlx_random_state(random_state)
    + # Match mlx_lm/tuner/lora.py (def train) -- seed
    + # mx.random immediately before LoRA init so lazy MLX
    + # state advances do not leak into lora_a sampling.
    + # Scope: language LoRA only. Vision tower and
    + # projector LoRA inits via _lora_walk_module below
    + # are NOT covered (mlx-lm CLI does not LoRA vision).
    + _seed_mlx_random_state(random_state)
    # Match mlx_lm/lora.py:223 -- seed mx.random immediately
    # before LoRA init; lazy MLX state advances otherwise
    # leak into lora_a sampling.
    _seed_mlx_random_state(random_state)
[2/2 reviewers] Nit: the comment # Match mlx_lm/lora.py:223 is brittle to upstream renumber. Cite the function instead:
Suggested change:
    - _seed_mlx_random_state(random_state)
    + # Match mlx_lm/tuner/lora.py (def train) -- seed
    + # mx.random immediately before LoRA init; lazy MLX state
    + # advances otherwise leak into lora_a sampling.
    + _seed_mlx_random_state(random_state)
    # Match: _seed_mlx_random_state(random_state), then ONLY whitespace + a
    # comment block + the linear_to_lora_layers call. Comments are fine.
    # Anything else (an `mx.` op, a module walk, an assignment) is not.
    pattern = re.compile(
[2/2 reviewers] Low: this regex requires the literal token _seed_mlx_random_state(random_state) to live on a line directly above linear_to_lora_layers(. Benign refactors that don't change ordering (_seed_mlx_random_state(int(random_state)), kwarg form _seed_mlx_random_state(seed=random_state), or wrapping in a one-liner helper _seed_then_lora(...)) would break the test without an actual ordering regression. Also, len(matches) >= 2 could pass with one good pair and one bad — pin each call site explicitly.
Suggested change:
    - pattern = re.compile(
    + # Collect each linear_to_lora_layers call site, then verify each one
    + # is preceded within ~5 lines by a seed call. This pin is structural,
    + # not regex-literal, so benign refactors of the seed expression are OK.
    + src = _get_peft_model_source()
    + lines = src.splitlines()
    + call_sites = [
    +     i for i, line in enumerate(lines)
    +     if "linear_to_lora_layers(" in line and not line.strip().startswith("#")
    + ]
    + assert len(call_sites) >= 2, (
    +     f"expected at least two linear_to_lora_layers call sites; got {len(call_sites)}"
    + )
    + for idx in call_sites:
    +     window = "\n".join(lines[max(0, idx - 5):idx])
    +     assert "_seed_mlx_random_state" in window, (
    +         f"linear_to_lora_layers at relative line {idx+1} is not preceded by "
    +         f"_seed_mlx_random_state within 5 lines. Window:\n{window}"
    +     )
            f"to keep mlx-lm CLI parity."
        )
[1/2 reviewers] Med: this test file currently has no NUMERICAL canary. If mlx-lm changes their LoRA init to use explicit-state RNG (mx.random.key(...)) instead of the global state, the global-seed-call-then-LoRA pattern silently becomes a no-op and the parity claim regresses. The source-pinning tests would still pass. Add a tiny canary:
Suggested change:
    + def test_seed_actually_resets_mx_random_state():
    +     """Numerical canary: confirm `_seed_mlx_random_state(seed)` causes
    +     the next `mx.random.uniform(...)` to return a deterministic value.
    +     If mlx-lm migrates to explicit RNG state, this fires before the
    +     source-pin tests do."""
    +     import mlx.core as mx
    +     from unsloth_zoo.mlx.loader import _seed_mlx_random_state
    +     _seed_mlx_random_state(3407)
    +     a = float(mx.random.uniform(shape=(1,)).item())
    +     _seed_mlx_random_state(3407)
    +     b = float(mx.random.uniform(shape=(1,)).item())
    +     assert a == b, (
    +         "Re-seeding mx.random produced different uniform draws -- the "
    +         "global-state seeding contract may have broken. Investigate before "
    +         "trusting the seed-ordering parity claim in get_peft_model."
    +     )
…ai#678)
Re-merge of PR unslothai#674. The original was accidentally merged into the stale fix-mlx-num-layers-parity branch (after unslothai#669 had already squashed into main), leaving this fix stranded.
FastMLXModel.get_peft_model previously called `_seed_mlx_random_state(random_state)` near the top of the method, ~100+ source lines above the actual `linear_to_lora_layers` call. In between sit target-module normalization, `_fix_missing_no_grad`, `_resolve_lora_keys`, and (on the VLM branch) the model-tree walk. Empirically this leaves a window in which lazy MLX state mutations or implicit `mx.random` consumption can slip in, so the lora_a matrices initialized inside `linear_to_lora_layers` end up DIFFERENT from mlx-lm CLI's, which seeds at `mlx_lm/tuner/lora.py` (def train) immediately before `linear_to_lora_layers`.
Verified by probe 39 on danielhanchen/unsloth-staging-2:
seed= 1, 42, 999, 3407, 22222: max |dloss| = max |dgrad_norm| = 0.0 across all 30 steps × 5 seeds (vs non-zero deltas before).
This fix moves `_seed_mlx_random_state(random_state)` to immediately before each `linear_to_lora_layers(...)` call -- both VLM language path and text path. API surface unchanged.
Test `tests/test_mlx_get_peft_model_seed_ordering.py` pins:
1. Every linear_to_lora_layers call inside get_peft_model is preceded by `_seed_mlx_random_state` within 20 lines.
2. The `random_state` API parameter still exists with default 3407.
3. The tight pairing matches BOTH the VLM and text LoRA call sites (regex tripwire that only allows comment lines between).
Summary
Move `_seed_mlx_random_state(random_state)` inside `FastMLXModel.get_peft_model` to immediately before each `linear_to_lora_layers(...)` call (one for the VLM language branch, one for the text-only branch), instead of leaving it at the top of the method.
The current code seeds at the top, then runs ~165 source lines of target-module normalization, `_fix_missing_no_grad`, `_resolve_lora_keys`, etc. before the actual LoRA construction. That window lets lazy MLX state mutations or implicit `mx.random` consumption slip in, so the lora_a matrices that `linear_to_lora_layers` initializes end up different from what `mlx-lm`'s CLI gets, even though both paths re-seed to the same integer.
Stacked on #669 (the `finetune_last_n_layers` PR) because probe 39 below needs that parameter to land.
Why
`mlx-lm` CLI does the seeding last before LoRA: `mx.random.seed(args.seed)` is at `mlx_lm/lora.py:223`, with `linear_to_lora_layers` on the very next line. So the "right" mlx-lm-parity guarantee is "seed is the last thing before LoRA construction", and zoo's earlier-and-far-above placement broke that.
Empirical evidence from `danielhanchen/unsloth-staging-2`
Round BQ probe 39 (5 paired seeds; both paths use `mlx_lm.load`-loaded weights to keep model state identical; only the LoRA-init pipeline differs):
Step-1 forward loss is identical across both paths (`9.769231` in both; base weights load the same). Step-1 grad_norm differs by 2-3 units at low seeds (33.36 vs 36.11 at seed=1). The only thing that can produce that delta is different lora_a values, since at step 1 `lora_b = 0` and the only nonzero gradient flows through `dL/dlora_b = scale * dL/dout @ (lora_a @ input)^T`.
After this PR, the seed is the immediately-preceding statement to `linear_to_lora_layers`, so the lora_a initialization sees a fresh `mx.random.state == random_state` and matches `mlx-lm` CLI value-for-value (Round BQ rerun pending to confirm on Apple Silicon).
Behavior
`random_state` parameter and its default of `3407` are unchanged.
Test plan
- `tests/test_mlx_get_peft_model_seed_ordering.py` pins:
1. Every `linear_to_lora_layers` call inside `get_peft_model` is preceded by `_seed_mlx_random_state` within 20 source lines.
2. `random_state=3407` default stays.
- `tests/test_mlx_finetune_last_n_layers.py` `FakeModel` gets `trainable_parameters()` / `parameters()` stubs so the param-count log at the tail of `get_peft_model` doesn't crash the unit test (pre-existing failure on `fix-mlx-num-layers-parity`, fixed here).
- `pytest tests/test_mlx_finetune_last_n_layers.py tests/test_mlx_get_peft_model_seed_ordering.py tests/test_pr_a_*.py`
- Round BQ rerun on Apple Silicon to confirm `dloss = 0` step-for-step against the mlx-lm-CLI style path.
Related
Seventh PR in the MLX vs `mlx-lm` parity series, now addressing the user-reported asymmetry: "the seed fix should be on Unsloth's side, not require manual `mx.random.seed(seed)` from users":
- #669 `finetune_last_n_layers` knob (base for this PR).
- unslothai/unsloth#5564
- #670
- #671 `max_grad_value=None` default + HF/TRL parity (closes #662).
- #672 `_create_labeled_batches` padding matches `mlx-lm`.
- #673 `make_baseline_loss_fn` labels=None fast-path simplification.
- This PR: seed immediately before `linear_to_lora_layers`, closing the basin-family gap end-to-end on the user-facing `FastMLXModel` API.
mlx-lmparity series, now addressing the user-reported asymmetry: "the seed fix should be on Unsloth's side, not require manualmx.random.seed(seed)from users":#669finetune_last_n_layersknob (base for this PR).unslothai/unsloth#5564#670#671max_grad_value=Nonedefault + HF/TRL parity (closes#662).#672_create_labeled_batchespadding matchesmlx-lm.#673make_baseline_loss_fnlabels=None fast-path simplification.linear_to_lora_layers, closing the basin-family gap end-to-end on the user-facingFastMLXModelAPI.