fix(mlx): seed mx.random immediately before linear_to_lora_layers (re-PR of #674)#678
Conversation
Re-merge of PR #674 — the original was accidentally merged into the stale fix-mlx-num-layers-parity branch (after #669 had already squashed into main), leaving this fix stranded. FastMLXModel.get_peft_model previously called `_seed_mlx_random_state(random_state)` near the top of the method, ~100+ source lines above the actual `linear_to_lora_layers` call. In between sit target-module normalization, `_fix_missing_no_grad`, `_resolve_lora_keys`, and (on the VLM branch) the model-tree walk. Empirically this leaves a window in which lazy MLX state mutations or implicit `mx.random` consumption can slip in, so the lora_a matrices initialized inside `linear_to_lora_layers` end up DIFFERENT from mlx-lm CLI's, which seeds at `mlx_lm/tuner/lora.py` (def train) immediately before `linear_to_lora_layers`. Verified by probe 39 on danielhanchen/unsloth-staging-2: seed= 1, 42, 999, 3407, 22222: max |dloss| = max |dgrad_norm| = 0.0 across all 30 steps × 5 seeds (vs non-zero deltas before). This fix moves `_seed_mlx_random_state(random_state)` to immediately before each `linear_to_lora_layers(...)` call -- both VLM language path and text path. API surface unchanged. Test `tests/test_mlx_get_peft_model_seed_ordering.py` pins: 1. Every linear_to_lora_layers call inside get_peft_model is preceded by `_seed_mlx_random_state` within 20 lines. 2. The `random_state` API parameter still exists with default 3407. 3. The tight pairing matches BOTH the VLM and text LoRA call sites (regex tripwire that only allows comment lines between).
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b515a77803
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| # Match mlx_lm/tuner/lora.py (def train) -- seed | ||
| # mx.random immediately before LoRA init; lazy MLX | ||
| # state advances otherwise leak into lora_a sampling. | ||
| _seed_mlx_random_state(random_state) |
There was a problem hiding this comment.
Seed VLM-only LoRA branches before wrapping
When a VLM call sets finetune_language_layers=False but enables train_vision and/or train_projector, execution skips this newly relocated seed and then _lora_walk_module wraps those branches with LoRALinear.from_base, so their LoRA matrices are initialized from whatever global MLX RNG state happens to exist and random_state is ignored. The removed top-level seed used to cover these supported VLM-only fine-tuning modes; add an immediate seed before the vision/projector _lora_walk_module calls as well, or keep a fallback seed for paths that do not call linear_to_lora_layers.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Code Review
This pull request refactors the MLX random state seeding logic in get_peft_model by moving the _seed_mlx_random_state call from a global position to immediately before each linear_to_lora_layers invocation. This change prevents lazy MLX state leakage, ensuring deterministic LoRA initialization. New tests were added to verify this seeding order and maintain API surface consistency. Feedback suggests ensuring that all paths, including VLM-specific branches, are covered by this seeding logic and improving the robustness of the test assertions by comparing the seed-call count against the total number of LoRA layer initializations.
I am having trouble creating individual review comments. Click here to see my feedback.
unsloth_zoo/mlx/loader.py (2744-2746)
Removing the global _seed_mlx_random_state(random_state) call at the start of get_peft_model ensures that the language LoRA initialization is clean, but it leaves other paths—specifically train_vision and train_projector in the VLM branch—without a deterministic seed if finetune_language_layers is False. Since those call sites are not included in this diff, please ensure they are also updated to include a preceding seed call to maintain full reproducibility across all PEFT configurations.
tests/test_mlx_get_peft_model_seed_ordering.py (83-90)
The assertion assert len(matches) >= 2 is weak because it doesn't verify that all calls to linear_to_lora_layers are correctly seeded. If a new call were added without a seed, this test would still pass as long as the existing two remain. It would be more robust to compare the number of matches against the total number of linear_to_lora_layers calls found in the source.
matches = pattern.findall(src)
lines = src.splitlines()
call_count = len([l for l in lines if "linear_to_lora_layers(" in l and not l.strip().startswith("#")])
assert len(matches) == call_count, (
f"Expected {call_count} seed+LoRA-call tight pairings in "
f"get_peft_model, but found {len(matches)}. Every "
"linear_to_lora_layers invocation must be immediately preceded by a _seed_mlx_random_state call.
)References
- Verify the fix using tests that check for line structure preservation.
- It is acceptable to use fragile string-matching for code patching if it is consistent with the existing codebase's architecture.
Summary
PR #674 was accidentally merged into the stale
fix-mlx-num-layers-paritybranch (which had already been squashed-and-superseded into main as #669), so the seed-ordering fix never reachedmain. This PR re-applies that fix cleanly against the currentmain.What this fixes
FastMLXModel.get_peft_modelpreviously called_seed_mlx_random_state(random_state)near the top of the method, ~100+ source lines above the actuallinear_to_lora_layerscall. Between them sittarget_modulesnormalization,_fix_missing_no_grad,_resolve_lora_keys, and (on the VLM branch) model-tree walks. MLX's lazy evaluation can advancemx.random.statebetween the early seed call andlora_ainit viamx.random.uniform, producinglora_amatrices that differ from mlx-lm CLI's despite both paths seeding to the same int.This PR moves
_seed_mlx_random_state(random_state)to immediately before eachlinear_to_lora_layers(...)call (both VLM language path and text path). Matches mlx-lm's own ordering inmlx_lm/tuner/lora.py(the seed is the last thing it does beforelinear_to_lora_layers).Verification
Probe 39 on
danielhanchen/unsloth-staging-2(5 seeds × 30 steps, gemma-3-270m-it, FastMLXModel vs mlx-lm CLI through the same manual training loop):Bit-identical losses AND gradient norms across all 30 steps × 5 seeds. Before the fix, lora_a values diverged from step 1.
Test plan
pytest tests/test_mlx_get_peft_model_seed_ordering.py→ 3/3 pass (source-pinning tripwires).Stranded branch
For history: the original PR #674 lives in
fix-mlx-num-layers-parityat commit74902b2a. That branch is now obsolete and can be deleted.