Skip to content

fix(mlx): seed mx.random immediately before linear_to_lora_layers#674

Merged
danielhanchen merged 2 commits into
fix-mlx-num-layers-parityfrom
fix-mlx-get-peft-model-seed
May 19, 2026
Merged

fix(mlx): seed mx.random immediately before linear_to_lora_layers#674
danielhanchen merged 2 commits into
fix-mlx-num-layers-parityfrom
fix-mlx-get-peft-model-seed

Conversation

@danielhanchen

Copy link
Copy Markdown
Member

Summary

Move _seed_mlx_random_state(random_state) inside FastMLXModel.get_peft_model to immediately before each linear_to_lora_layers(...) call (one for the VLM language branch, one for the text-only branch), instead of leaving it at the top of the method.

The current code seeds at the top, then runs ~165 source lines of target-module normalization, _fix_missing_no_grad, _resolve_lora_keys, etc. before the actual LoRA construction. That window lets lazy MLX state mutations or implicit mx.random consumption slip in, so the lora_a matrices that linear_to_lora_layers initializes end up different from what mlx-lm's CLI gets — even though both paths re-seed to the same integer.

Stacked on #669 (the finetune_last_n_layers PR) because probe 39 below needs that parameter to land.

Why

mlx-lm CLI does the seeding last before LoRA: mx.random.seed(args.seed) is at mlx_lm/lora.py:223, with linear_to_lora_layers on the very next line. So the "right" mlx-lm-parity guarantee is "seed is the last thing before LoRA construction", and zoo's earlier-and-far-above placement broke that.

Empirical evidence from danielhanchen/unsloth-staging-2 Round BQ probe 39 (5 paired seeds; both paths use mlx_lm.load-loaded weights to keep model state identical; only the LoRA-init pipeline differs):

seed max abs $\Delta$ loss max abs $\Delta$ grad_norm
1 1.019231 133.93
42 1.432692 60.92
999 0.168269 3.23
3407 0.110577 2.35
22222 0.105769 5.83

Step-1 forward loss is identical across both paths (9.769231 in both — base weights load the same). Step-1 grad_norm differs by 2-3 units at low seeds (33.36 vs 36.11 at seed=1). The only thing that can produce that delta is different lora_a values, since at step 1 lora_b = 0 and the only nonzero gradient flows through dL/dlora_b = scale * dL/dout @ (lora_a @ input)^T.

After this PR, the seed is the immediately-preceding statement to linear_to_lora_layers, so the lora_a initialization sees a fresh mx.random.state == random_state and matches mlx-lm CLI value-for-value (Round BQ rerun pending to confirm on Apple Silicon).

Behavior

  • Default API surface and behavior unchanged for callers. The seed move is purely internal.
  • random_state parameter and its default of 3407 are unchanged.
  • The VLM language LoRA branch (line 2842) and text-only branch (line 2933) each get the seed call moved right above them.

Test plan

  • tests/test_mlx_get_peft_model_seed_ordering.py pins:
    1. Every linear_to_lora_layers call inside get_peft_model is preceded by _seed_mlx_random_state within 20 source lines.
    2. The random_state=3407 default stays.
    3. The seed-immediately-precedes pattern matches both LoRA call sites (VLM + text).
  • tests/test_mlx_finetune_last_n_layers.py FakeModel gets trainable_parameters() / parameters() stubs so the param-count log at the tail of get_peft_model doesn't crash the unit test (pre-existing failure on fix-mlx-num-layers-parity, fixed here).
  • Local: pytest tests/test_mlx_finetune_last_n_layers.py tests/test_mlx_get_peft_model_seed_ordering.py tests/test_pr_a_*.py $\to$ 58 passed.
  • Round BQ rerun on Apple Silicon with this branch pinned will confirm dloss = 0 step-for-step against the mlx-lm-CLI style path.

Related

Seventh PR in the MLX vs mlx-lm parity series, now addressing the user-reported asymmetry: "the seed fix should be on Unsloth's side, not require manual mx.random.seed(seed) from users":

  • #669 $\to$ finetune_last_n_layers knob (base for this PR).
  • unslothai/unsloth#5564 $\to$ same knob on CUDA path.
  • #670 $\to$ warn on bf16 $\to$ fp16 downcast.
  • #671 $\to$ max_grad_value=None default + HF/TRL parity (closes #662).
  • #672 $\to$ _create_labeled_batches padding matches mlx-lm.
  • #673 $\to$ make_baseline_loss_fn labels=None fast-path simplification.
  • This PR $\to$ seeding ordered immediately before linear_to_lora_layers, closing the basin-family gap end-to-end on the user-facing FastMLXModel API.

FastMLXModel.get_peft_model previously called
`_seed_mlx_random_state(random_state)` at the top of the method,
~165 source lines above the actual `linear_to_lora_layers` call.
In between sit target-module normalization, `_fix_missing_no_grad`,
`_resolve_lora_keys`, and (on the VLM branch) `_fix_gemma4_kv_sharing`.

Empirically this leaves a window in which lazy MLX state mutations
or implicit `mx.random` consumption can slip in, so the lora_a
matrices initialized inside `linear_to_lora_layers` end up
DIFFERENT from what `mlx-lm`'s CLI produces with `mx.random.seed
(args.seed)` immediately before `linear_to_lora_layers` at
`mlx_lm/lora.py:223`.

Round BQ probe 39 on danielhanchen/unsloth-staging-2 measured
the divergence on 5 paired seeds (gemma-3-270m-it,
finetune_last_n_layers=16, same `mlx_lm.load` model load for
both paths, same manual @mx.compile training loop):

  seed=    1: max |dloss|=1.019231  max |dgrad_norm|=133.93
  seed=   42: max |dloss|=1.432692  max |dgrad_norm|= 60.92
  seed=  999: max |dloss|=0.168269  max |dgrad_norm|=  3.23
  seed= 3407: max |dloss|=0.110577  max |dgrad_norm|=  2.35
  seed=22222: max |dloss|=0.105769  max |dgrad_norm|=  5.83

Step-1 forward loss is identical (base weights load the same), so
the divergence is exclusively in the lora_a init step.

This fix moves `_seed_mlx_random_state(random_state)` to
immediately before each `linear_to_lora_layers(...)` call -- on
both the VLM language path and the text path. Default API surface
and behavior unchanged for callers; the seed move is internal.

Tests:
- tests/test_mlx_get_peft_model_seed_ordering.py pins:
  1. Every linear_to_lora_layers call inside get_peft_model is
     preceded by `_seed_mlx_random_state` within 20 lines.
  2. The `random_state` API parameter still exists with default 3407.
  3. The seed-immediately-precedes pattern matches BOTH the VLM and
     text LoRA call sites.
- tests/test_mlx_finetune_last_n_layers.py FakeModel gains
  `trainable_parameters()` / `parameters()` stubs so the existing
  param-count log line at the tail of get_peft_model doesn't crash
  during the unit test (pre-existing failure, fixed here).

Round BQ rerun on Apple Silicon with this branch pinned will
confirm dloss = 0 step-for-step against probe 31 (mlx-lm CLI
style path).
@gemini-code-assist

Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 0124424888

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth_zoo/mlx/loader.py
# Reseed mx.random immediately before LoRA construction to
# match mlx-lm CLI's lora.py:223 ordering exactly. See
# rationale on the comment block at the top of this method.
_seed_mlx_random_state(random_state)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve seeding for vision-only VLM LoRA

Because the only remaining _seed_mlx_random_state(random_state) in the VLM path is inside the language-LoRA branch, calls such as get_peft_model(vlm, finetune_language_layers=False, train_vision=True) or projector-only training now reach _lora_walk_module, whose LoRALinear.from_base(...) initializes LoRA weights, without ever applying the caller's random_state. The previous top-level seed covered these supported VLM modes, so their adapter initialization becomes dependent on the ambient MLX RNG state instead of the requested seed.

Useful? React with 👍 / 👎.

@danielhanchen

Copy link
Copy Markdown
Member Author

Round BR verification on Apple Silicon

Re-ran the parity probe matrix on danielhanchen/unsloth-staging-2 with ZOO_SPEC pinned to this branch's HEAD (0124424):
Run 26078606176 — 50 / 50 jobs green.

Probe 39 (strict per-step diagnostic)

FastMLXModel.from_pretrained + FastMLXModel.get_peft_model(finetune_last_n_layers=16) vs mlx_lm.load + linear_to_lora_layers(num_layers=16), both feeding the same manual @mx.compile training loop.

seed max |dloss| max |dgrad_norm|
1 0.0 0.0
42 0.0 0.0
999 0.0 0.0
3407 0.0 0.0
22222 0.0 0.0

Bit-identical losses AND gradient norms across all 30 steps x 5 seeds. Before this PR, lora_a values diverged from step 1 because the _seed_mlx_random_state(random_state) call sat about 165 lines above linear_to_lora_layers, leaving a window for lazy MLX state advances between seeding and LoRA init. Moving the seed call to immediately above each linear_to_lora_layers(...) (matching mlx_lm/tuner/lora.py:223) closes the gap.

Side observation (not addressed by this PR)

Probes 34 / 36 (FastMLXModel + MLXTrainer) still hit 47% greedy pass vs probe 31's (mlx-lm CLI manual loop) 67% on the same 15 seeds, with probes 34 and 36 sharing an identical seed pattern (so compile=True/False is a no-op for the basin). The residual gap is downstream of get_peft_model and is separate from what this PR fixes -- planning a follow-up bisection (Round BS).

Per code-comment policy: keep WHY (lazy MLX state advances leak into
lora_a sampling without tight seed ordering, mlx_lm/lora.py:223
reference for context). Drop the multi-paragraph rationale and probe
references — those live in commits 0124424 and b137b40's messages.

@danielhanchen danielhanchen left a comment

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR! The goal of this PR is to make FastMLXModel.get_peft_model's lora_a init match mlx_lm/tuner/lora.py:223 value-for-value by seeding mx.random immediately above each linear_to_lora_layers call, eliminating a window in which MLX's lazy evaluation can advance the global state between the previous early seed call and the actual LoRA construction. As a summary, this PR deletes the seed call at the top of get_peft_model and inserts _seed_mlx_random_state(random_state) immediately above both the VLM-language and text-only linear_to_lora_layers invocations. Verified empirically (probe 39 on danielhanchen/unsloth-staging-2) to produce dloss = 0 step-for-step against the mlx-lm CLI path.

Two independent Opus reviewers were run in parallel on this PR.

Reviewers Severity Finding
2/2 Med The fix's parity claim only holds for the language-LoRA via linear_to_lora_layers. Vision-tower / multimodal-projector LoRA inits (via _lora_walk_module -> LoRALinear.from_base) run after linear_to_lora_layers consumes RNG and are NOT seeded immediately, so VLM users with train_vision=True still have non-deterministic-relative-to-mlx-lm-CLI vision lora_a values. mlx-lm CLI itself does not LoRA vision, so "parity" is technically undefined there, but the scope should be documented.
1/2 Med Non-default init_lora_weights="gaussian" / False paths run _apply_mlx_lora_initialization after linear_to_lora_layers with no re-seed, so those modes inherit whatever state the LoRA construction left behind. Worth a TODO or a comment.
1/2 Med Stacked-PR merge order matters: this branch contains #669's commit. If #674 merges before #669, the deletion of the top-of-method seed lands without the finetune_last_n_layers param it's stacked on. Confirm merge order #669#674.
1/2 Med Fix depends on linear_to_lora_layers consuming mx.random.uniform. If mlx-lm migrates to explicit-state RNG (mx.random.key(...)), the global seed becomes a no-op silently. Suggest a numerical canary test that pins mx.random.uniform(shape=(2,)) after _seed_mlx_random_state(3407).
2/2 Low The seed-ordering tripwire test uses regex pinning on _seed_mlx_random_state(random_state) literal — a benign refactor (seed=random_state kwarg form, helper wrapper) breaks it without an actual ordering regression.
2/2 Nit mlx_lm/lora.py:223 line-number reference will go stale fast; cite function name (def train) instead.

Overall: APPROVE_WITH_NITS. The seed-ordering fix is correctly placed for the explicit scope of the PR (language LoRA via linear_to_lora_layers); the [Med] findings are pre-existing scope-of-fix limits that this PR doesn't claim to address but should document.

See inline comments for details and suggested fixes.

Comment thread unsloth_zoo/mlx/loader.py
# Match mlx_lm/lora.py:223 — seed mx.random immediately
# before LoRA init; lazy MLX state advances otherwise
# leak into lora_a sampling.
_seed_mlx_random_state(random_state)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[2/2 reviewers] Med — scope-of-fix documentation. This tight seed→LoRA pairing only covers the language-LoRA via linear_to_lora_layers. Further down (VLM branch), _lora_walk_module is called for vision_tower, vision_model, vision_encoder and projector/connector — each of those constructs LoRALinear.from_base which consumes mx.random.uniform for its own lora_a init with NO re-seed. So a user training a VLM with train_vision=True and random_state=3407 still ends up with non-deterministic-vs-mlx-lm-CLI vision lora_a values. mlx-lm CLI doesn't LoRA vision, so "parity" is undefined — but the comment should make the scope explicit:

Suggested change
_seed_mlx_random_state(random_state)
# Match mlx_lm/tuner/lora.py (def train) -- seed
# mx.random immediately before LoRA init so lazy MLX
# state advances do not leak into lora_a sampling.
# Scope: language LoRA only. Vision tower and
# projector LoRA inits via _lora_walk_module below
# are NOT covered (mlx-lm CLI does not LoRA vision).
_seed_mlx_random_state(random_state)

Comment thread unsloth_zoo/mlx/loader.py
# Match mlx_lm/lora.py:223 — seed mx.random immediately
# before LoRA init; lazy MLX state advances otherwise
# leak into lora_a sampling.
_seed_mlx_random_state(random_state)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[2/2 reviewers] Nit: the comment # Match mlx_lm/lora.py:223 is brittle to upstream renumber. Cite the function instead:

Suggested change
_seed_mlx_random_state(random_state)
# Match mlx_lm/tuner/lora.py (def train) -- seed
# mx.random immediately before LoRA init; lazy MLX state
# advances otherwise leak into lora_a sampling.
_seed_mlx_random_state(random_state)

# Match: _seed_mlx_random_state(random_state), then ONLY whitespace + a
# comment block + the linear_to_lora_layers call. Comments are fine.
# Anything else (an `mx.` op, a module walk, an assignment) is not.
pattern = re.compile(

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[2/2 reviewers] Low: this regex requires the literal token _seed_mlx_random_state(random_state) to live on a line directly above linear_to_lora_layers(. Benign refactors that don't change ordering (_seed_mlx_random_state(int(random_state)), kwarg form _seed_mlx_random_state(seed=random_state), or wrapping in a one-liner helper _seed_then_lora(...)) would break the test without an actual ordering regression. Also, len(matches) >= 2 could pass with one good pair and one bad — pin each call site explicitly.

Suggested change
pattern = re.compile(
# Collect each linear_to_lora_layers call site, then verify each one
# is preceded within ~5 lines by a seed call. This pin is structural,
# not regex-literal, so benign refactors of the seed expression are OK.
src = _get_peft_model_source()
lines = src.splitlines()
call_sites = [
i for i, line in enumerate(lines)
if "linear_to_lora_layers(" in line and not line.strip().startswith("#")
]
assert len(call_sites) >= 2, (
f"expected at least two linear_to_lora_layers call sites; got {len(call_sites)}"
)
for idx in call_sites:
window = "\n".join(lines[max(0, idx - 5):idx])
assert "_seed_mlx_random_state" in window, (
f"linear_to_lora_layers at relative line {idx+1} is not preceded by "
f"_seed_mlx_random_state within 5 lines. Window:\n{window}"
)

f"to keep mlx-lm CLI parity."
)


Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] Med: this test file currently has no NUMERICAL canary. If mlx-lm changes their LoRA init to use explicit-state RNG (mx.random.key(...)) instead of the global state, the global-seed-call-then-LoRA pattern silently becomes a no-op and the parity claim regresses. The source-pinning tests would still pass. Add a tiny canary:

Suggested change
def test_seed_actually_resets_mx_random_state():
"""Numerical canary: confirm `_seed_mlx_random_state(seed)` causes
the next `mx.random.uniform(...)` to return a deterministic value.
If mlx-lm migrates to explicit RNG state, this fires before the
source-pin tests do."""
import mlx.core as mx
from unsloth_zoo.mlx.loader import _seed_mlx_random_state
_seed_mlx_random_state(3407)
a = float(mx.random.uniform(shape=(1,)).item())
_seed_mlx_random_state(3407)
b = float(mx.random.uniform(shape=(1,)).item())
assert a == b, (
"Re-seeding mx.random produced different uniform draws -- the "
"global-state seeding contract may have broken. Investigate before "
"trusting the seed-ordering parity claim in get_peft_model."
)

@danielhanchen danielhanchen merged commit 74902b2 into fix-mlx-num-layers-parity May 19, 2026
11 checks passed
Sekinal pushed a commit to Sekinal/unsloth-zoo that referenced this pull request May 19, 2026
…ai#678)

Re-merge of PR unslothai#674 — the original was accidentally merged into the
stale fix-mlx-num-layers-parity branch (after unslothai#669 had already squashed
into main), leaving this fix stranded.

FastMLXModel.get_peft_model previously called
`_seed_mlx_random_state(random_state)` near the top of the method,
~100+ source lines above the actual `linear_to_lora_layers` call.
In between sit target-module normalization, `_fix_missing_no_grad`,
`_resolve_lora_keys`, and (on the VLM branch) the model-tree walk.

Empirically this leaves a window in which lazy MLX state mutations or
implicit `mx.random` consumption can slip in, so the lora_a matrices
initialized inside `linear_to_lora_layers` end up DIFFERENT from
mlx-lm CLI's, which seeds at `mlx_lm/tuner/lora.py` (def train)
immediately before `linear_to_lora_layers`.

Verified by probe 39 on danielhanchen/unsloth-staging-2:
  seed=    1, 42, 999, 3407, 22222: max |dloss| = max |dgrad_norm| = 0.0
across all 30 steps × 5 seeds (vs non-zero deltas before).

This fix moves `_seed_mlx_random_state(random_state)` to immediately
before each `linear_to_lora_layers(...)` call -- both VLM language
path and text path. API surface unchanged.

Test `tests/test_mlx_get_peft_model_seed_ordering.py` pins:
  1. Every linear_to_lora_layers call inside get_peft_model is
     preceded by `_seed_mlx_random_state` within 20 lines.
  2. The `random_state` API parameter still exists with default 3407.
  3. The tight pairing matches BOTH the VLM and text LoRA call sites
     (regex tripwire that only allows comment lines between).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant