fix(mlx): seed mx.random immediately before linear_to_lora_layers (re-PR of #674)#678

Merged: danielhanchen merged 1 commit into main from fix-mlx-seed-ordering-rebased on May 19, 2026

@danielhanchen

Member

Summary

PR #674 was accidentally merged into the stale fix-mlx-num-layers-parity branch (which had already been squashed-and-superseded into main as #669), so the seed-ordering fix never reached main. This PR re-applies that fix cleanly against the current main.

What this fixes

FastMLXModel.get_peft_model previously called _seed_mlx_random_state(random_state) near the top of the method, more than 100 source lines above the actual linear_to_lora_layers call. Between the two sit target_modules normalization, _fix_missing_no_grad, _resolve_lora_keys, and (on the VLM branch) model-tree walks. Any lazily evaluated MLX state mutation or implicit mx.random consumption in that window can advance mx.random.state before lora_a is initialized via mx.random.uniform. The resulting lora_a matrices then differ from those of the mlx-lm CLI, even though both paths seed with the same integer.

This PR moves _seed_mlx_random_state(random_state) to immediately before each linear_to_lora_layers(...) call (both the VLM language path and the text path). This matches mlx-lm's own ordering in mlx_lm/tuner/lora.py, where seeding is the last step before linear_to_lora_layers.
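The ordering problem can be illustrated without MLX. The following is a minimal sketch that uses Python's standard `random` module as a stand-in for `mx.random`; the helper names (`init_lora_a`, `early_seed`, `late_seed`, `reference`) are hypothetical and do not appear in the PR. It only shows the general principle that any draw from a global RNG between the seed call and the initialization shifts the sampled values.

```python
import random


def init_lora_a(rows, cols):
    """Stand-in for LoRA A-matrix init: draws rows * cols uniform values."""
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def reference(seed):
    """Seed immediately before init, with nothing in between."""
    random.seed(seed)
    return init_lora_a(2, 3)


def early_seed(seed):
    """Seed first, then let an unrelated draw happen before init."""
    random.seed(seed)
    random.random()  # simulates unrelated global-RNG consumption in the gap
    return init_lora_a(2, 3)


def late_seed(seed):
    """The same unrelated draw happens, but before the seed call."""
    random.random()  # consumed before seeding, so it cannot affect init
    random.seed(seed)
    return init_lora_a(2, 3)


print("late seed matches reference: ", late_seed(3407) == reference(3407))
print("early seed matches reference:", early_seed(3407) == reference(3407))
```

In this toy setup, seeding last reproduces the reference values, while the early-seeded variant samples from a shifted position in the stream. Whether a given MLX helper actually consumes `mx.random` is a separate question that the sketch does not answer.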

Verification

Probe 39 on danielhanchen/unsloth-staging-2 (5 seeds × 30 steps, gemma-3-270m-it, FastMLXModel vs mlx-lm CLI through the same manual training loop):

| seed  | max \|dloss\| | max \|dgrad_norm\| |
|-------|---------------|--------------------|
| 1     | 0.0           | 0.0                |
| 42    | 0.0           | 0.0                |
| 999   | 0.0           | 0.0                |
| 3407  | 0.0           | 0.0                |
| 22222 | 0.0           | 0.0                |

Bit-identical losses AND gradient norms across all 30 steps × 5 seeds. Before the fix, lora_a values diverged from step 1.
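For readers who want to reproduce the table's metric, here is a minimal sketch of how a per-seed `max |dloss|` could be computed from two loss traces. The function name `max_abs_delta` and the sample numbers are hypothetical; they are not taken from probe 39, whose implementation is not shown in this PR.

```python
def max_abs_delta(trace_a, trace_b):
    """Largest absolute per-step difference between two equally long traces."""
    if len(trace_a) != len(trace_b):
        raise ValueError("traces must cover the same number of steps")
    # default=0.0 makes two empty traces compare as identical.
    return max((abs(a - b) for a, b in zip(trace_a, trace_b)), default=0.0)


# Made-up loss values for illustration only.
unsloth_losses = [2.50, 2.25, 2.00]
mlx_lm_losses = [2.50, 2.25, 2.00]

print(max_abs_delta(unsloth_losses, mlx_lm_losses))
```

A result of exactly 0.0 means the traces agree on every step at the precision of the values compared; the same function applies unchanged to gradient-norm traces.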

Test plan

  • pytest tests/test_mlx_get_peft_model_seed_ordering.py → 3/3 pass (source-pinning tripwires).
  • Numerical canary on Apple Silicon via Round BR probe 39.
  • Reviewer to confirm CI green.

Stranded branch

For history: the original PR #674 lives in fix-mlx-num-layers-parity at commit 74902b2a. That branch is now obsolete and can be deleted.

Re-merge of PR #674 — the original was accidentally merged into the
stale fix-mlx-num-layers-parity branch (after #669 had already squashed
into main), leaving this fix stranded.

FastMLXModel.get_peft_model previously called
`_seed_mlx_random_state(random_state)` near the top of the method,
~100+ source lines above the actual `linear_to_lora_layers` call.
In between sit target-module normalization, `_fix_missing_no_grad`,
`_resolve_lora_keys`, and (on the VLM branch) the model-tree walk.

Empirically this leaves a window in which lazy MLX state mutations or
implicit `mx.random` consumption can slip in, so the lora_a matrices
initialized inside `linear_to_lora_layers` end up DIFFERENT from
mlx-lm CLI's, which seeds at `mlx_lm/tuner/lora.py` (def train)
immediately before `linear_to_lora_layers`.

Verified by probe 39 on danielhanchen/unsloth-staging-2:
  seeds 1, 42, 999, 3407, 22222: max |dloss| = max |dgrad_norm| = 0.0
across all 30 steps × 5 seeds (vs non-zero deltas before).

This fix moves `_seed_mlx_random_state(random_state)` to immediately
before each `linear_to_lora_layers(...)` call -- both VLM language
path and text path. API surface unchanged.

Test `tests/test_mlx_get_peft_model_seed_ordering.py` pins:
  1. Every linear_to_lora_layers call inside get_peft_model is
     preceded by `_seed_mlx_random_state` within 20 lines.
  2. The `random_state` API parameter still exists with default 3407.
  3. The tight pairing matches BOTH the VLM and text LoRA call sites
     (regex tripwire that only allows comment lines between).

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b515a77803


Comment thread unsloth_zoo/mlx/loader.py
# Match mlx_lm/tuner/lora.py (def train) -- seed
# mx.random immediately before LoRA init; lazy MLX
# state advances otherwise leak into lora_a sampling.
_seed_mlx_random_state(random_state)


P2: Seed VLM-only LoRA branches before wrapping

When a VLM call sets finetune_language_layers=False but enables train_vision and/or train_projector, execution skips this newly relocated seed. _lora_walk_module then wraps those branches with LoRALinear.from_base, so their LoRA matrices are initialized from whatever global MLX RNG state happens to exist, and random_state is ignored. The removed top-level seed used to cover these supported VLM-only fine-tuning modes. Add an immediate seed before the vision/projector _lora_walk_module calls as well, or keep a fallback seed for paths that do not call linear_to_lora_layers.
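The "fallback seed" option from this comment could look roughly like the sketch below. It is not the code in unsloth_zoo/mlx/loader.py: Python's standard `random` module stands in for `mx.random`, and `seed_rng`, `wrap_branch`, and `apply_lora` are hypothetical names chosen for illustration.

```python
import random


def seed_rng(seed):
    """Stand-in for _seed_mlx_random_state (hypothetical equivalent)."""
    random.seed(seed)


def wrap_branch():
    """Stand-in for wrapping one branch with LoRA: samples 4 init values."""
    return [random.uniform(-1.0, 1.0) for _ in range(4)]


def apply_lora(random_state, language=True, vision=False, projector=False):
    """Wrap the enabled branches, seeding right before the first one wrapped."""
    seeded = False
    wrapped = {}
    branches = (("language", language), ("vision", vision), ("projector", projector))
    for name, enabled in branches:
        if not enabled:
            continue
        if not seeded:
            # Fallback: whichever branch runs first gets the seed, so
            # vision-only or projector-only modes still honor random_state.
            seed_rng(random_state)
            seeded = True
        wrapped[name] = wrap_branch()
    return wrapped


random.seed(1)
first = apply_lora(3407, language=False, vision=True)
random.seed(2)
random.random()  # perturb the global RNG before the second call
second = apply_lora(3407, language=False, vision=True)
print(first == second)
```

This variant seeds once per call rather than before every branch, which avoids giving several branches identical initial values from the same seed. The trade-off is that it does not guard against RNG consumption between branches.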

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request refactors the MLX random state seeding logic in get_peft_model by moving the _seed_mlx_random_state call from a global position to immediately before each linear_to_lora_layers invocation. This change prevents lazy MLX state leakage, ensuring deterministic LoRA initialization. New tests were added to verify this seeding order and maintain API surface consistency. Feedback suggests ensuring that all paths, including VLM-specific branches, are covered by this seeding logic and improving the robustness of the test assertions by comparing the seed-call count against the total number of LoRA layer initializations.

I am having trouble creating individual review comments. Click here to see my feedback.

unsloth_zoo/mlx/loader.py (2744-2746)

high

Removing the global _seed_mlx_random_state(random_state) call at the start of get_peft_model ensures that the language LoRA initialization is clean, but it leaves other paths—specifically train_vision and train_projector in the VLM branch—without a deterministic seed if finetune_language_layers is False. Since those call sites are not included in this diff, please ensure they are also updated to include a preceding seed call to maintain full reproducibility across all PEFT configurations.

tests/test_mlx_get_peft_model_seed_ordering.py (83-90)

medium

The assertion assert len(matches) >= 2 is weak because it doesn't verify that all calls to linear_to_lora_layers are correctly seeded. If a new call were added without a seed, this test would still pass as long as the existing two remain. It would be more robust to compare the number of matches against the total number of linear_to_lora_layers calls found in the source.

    matches = pattern.findall(src)
    lines = src.splitlines()
    # Count non-comment lines that invoke linear_to_lora_layers.
    call_count = len([line for line in lines if "linear_to_lora_layers(" in line and not line.strip().startswith("#")])
    assert len(matches) == call_count, (
        f"Expected {call_count} seed+LoRA-call tight pairings in "
        f"get_peft_model, but found {len(matches)}. Every "
        "linear_to_lora_layers invocation must be immediately preceded by a _seed_mlx_random_state call."
    )

References
  1. Verify the fix using tests that check for line structure preservation.
  2. It is acceptable to use fragile string-matching for code patching if it is consistent with the existing codebase's architecture.

@danielhanchen danielhanchen merged commit 3116c54 into main May 19, 2026
1 of 14 checks passed
@danielhanchen danielhanchen deleted the fix-mlx-seed-ordering-rebased branch May 19, 2026 12:57