MLX Update Training #684

Open
mmathew23 wants to merge 62 commits into unslothai:main from mmathew23:explore/mlx

@mmathew23 mmathew23 commented May 20, 2026

Collaborator

Reworks the MLX training path to match the behavior of Unsloth's CUDA/transformers path, with a focus on making VLM fine-tuning work properly on Apple Silicon.

What changed

HF Trainer parity

  • Gradient clipping is now resolved like the CUDA path: max_grad_value takes priority, then per-leaf L2 norm (max_grad_leaf_norm), then global norm (max_grad_norm); an explicit 0 disables clipping.
  • Weight decay follows HF semantics: bias and norm parameters are exempt, AdamW/Muon/Lion use decoupled decay, and SGD couples decay into the gradient before momentum, matching torch.optim.SGD.
  • Data ordering supports torch_randperm with deterministic per-epoch reseeding, so a seeded MLX run sees the same batch order semantics as a CUDA DataLoader run. Batches no longer span epoch boundaries.
  • train_on_responses_only (completion-only training) is implemented for both text and VLM models, mirroring the HF/unsloth API.
  • LoRA+, separate embedding learning rates, and the LR schedules are consistent with the CUDA trainer.
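The difference between coupled and decoupled weight decay mentioned above can be sketched in a few lines. This is a minimal NumPy illustration, not the PR's MLX code: the function names are hypothetical, and the coupled update follows the documented torch.optim.SGD rule (decay is added to the gradient before the momentum buffer is updated).

```python
import numpy as np

def sgd_coupled_step(param, grad, buf, lr, momentum, wd):
    """One SGD step with coupled decay: wd * param joins the gradient before momentum."""
    grad = grad + wd * param
    buf = momentum * buf + grad
    return param - lr * buf, buf

def decoupled_decay(param, lr, wd):
    """Decoupled (AdamW-style) decay: shrink the weights directly, outside the gradient."""
    return param * (1.0 - lr * wd)

param = np.array([1.0, -2.0])
zero_grad = np.zeros_like(param)
buf = np.zeros_like(param)

# Step 1: with an empty momentum buffer both rules shrink the weights identically.
coupled_1, buf = sgd_coupled_step(param, zero_grad, buf, lr=0.1, momentum=0.9, wd=0.01)
decoupled_1 = decoupled_decay(param, lr=0.1, wd=0.01)

# Step 2: the coupled decay is now amplified by momentum, so the two rules diverge.
coupled_2, buf = sgd_coupled_step(coupled_1, zero_grad, buf, lr=0.1, momentum=0.9, wd=0.01)
decoupled_2 = decoupled_decay(decoupled_1, lr=0.1, wd=0.01)
```

With a zero gradient, the two rules agree on the first step and differ from the second step on, because the decay term has entered the momentum buffer.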

VLM training

  • Label masking for image/video/audio special tokens, with padding masked via the attention mask so EOS targets survive when pad_token_id == eos_token_id.
  • Batch collation through the processor (chat template, image resize cap on the larger side, uniform padding), including Qwen-VL/GLM position id handling.
  • CCE loss through VLM backbones, sliding-window attention masks for Gemma 3/4, fp32 norm stabilization, and NF4 dequantization in the loader.
  • Fused MRoPE Metal kernels have no VJP, so training flips them off for the Qwen2-VL, Qwen2.5-VL and Qwen3-VL families and takes the differentiable cos/sin fallback.

Tests

  • The MLX suite grows to 20+ files (trainer internals, grad clip resolution, batching and decay, VLM label masks, save/export edge cases).
  • New Metal-only e2e tests run real LoRA training on Apple Silicon CI: a tiny 4-bit text model (CCE and baseline losses, default and value clipping, end-of-training save) plus deep behavior checks covering resume_from_checkpoint determinism (stop+resume reproduces the fresh run's losses step for step), completion-only epoch step counts, SGD coupled decay, and a real Qwen2-VL LoRA fit.

Behavior change to be aware of

The default gradient clipping changes from an elementwise clamp at 1.0 to a per-leaf L2 norm clip at 1.0. The new default is closer to standard norm-based clipping and is what the e2e runs validate, but it is a deliberate change from the previous MLX default.
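To make the behavior change concrete, here is a small NumPy sketch contrasting the two clipping styles. The helper names and the `eps` term are assumptions for illustration; the PR's MLX implementation may differ in both.

```python
import numpy as np

def clip_by_value(grads, max_value):
    """Elementwise clamp: each entry is limited independently (old default style)."""
    return {k: np.clip(g, -max_value, max_value) for k, g in grads.items()}

def clip_per_leaf_norm(grads, max_norm, eps=1e-6):
    """Per-leaf L2 clip: each tensor is rescaled so its own norm is at most max_norm."""
    out = {}
    for k, g in grads.items():
        norm = float(np.linalg.norm(g))
        scale = min(1.0, max_norm / (norm + eps))
        out[k] = g * scale
    return out

grads = {"a": np.array([3.0, 4.0]), "b": np.array([0.3, 0.4])}
by_value = clip_by_value(grads, 1.0)      # "a" becomes [1, 1]: direction changes
per_leaf = clip_per_leaf_norm(grads, 1.0)  # "a" becomes ~[0.6, 0.8]: direction kept
```

The clamp distorts the direction of a large gradient, while the per-leaf norm clip only rescales it; small gradients such as "b" pass through both unchanged.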

Validation

All on macos-14 (M1, real Metal) plus a Linux CPU sweep of every suite file:

  • Three consecutive fully green runs of the complete suite (VJP regression guards, text LoRA e2e, deep validation, full per-file breadth): run 1, run 2, run 3.
  • The e2e tests caught two real crashes during validation, both now fixed: train_on_responses_only with mlx-lm's TokenizerWrapper (not callable, so the HF masking impl failed) and Qwen2-VL training hitting the fused MRoPE kernel with no VJP.
  • All review threads are resolved, with each fix pinned by a revert-verified regression test.

@mmathew23 mmathew23 requested a review from danielhanchen as a code owner May 20, 2026 22:08

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces several enhancements and fixes for MLX training, focusing on VLM support and parity with HuggingFace's trainer behavior. Key updates include a manual AdamW weight decay implementation that filters out bias and normalization parameters, a diagnostic 'nf4_dense' quantization mode, and logic to maintain normalization parameters in float32. Additionally, it refines VLM collation, fixes a loss masking off-by-one error, and prevents automatic EOS appending in datasets. Feedback from the review identified a bug in the Qwen3-VL LayerNorm parameter check, precision and memory issues in the manual weight decay logic, and a performance regression in the compiler's logit handling.

Comment thread unsloth_zoo/mlx/compile.py Outdated
Comment on lines +2722 to +2725
if "weight" in norm:
    y = y * norm.weight.astype(mx.float32)
if "bias" in norm:
    y = y + norm.bias.astype(mx.float32)

Contributor

high

The check if "weight" in norm: is not a valid way to verify parameter existence on an mlx.nn.Module. This will likely evaluate to False or raise a TypeError, causing the LayerNorm calculation to skip applying the weight and bias tensors. This will result in incorrect numerical output for the Qwen3-VL vision blocks.

Suggested change
- if "weight" in norm:
-     y = y * norm.weight.astype(mx.float32)
- if "bias" in norm:
-     y = y + norm.bias.astype(mx.float32)
+ if hasattr(norm, "weight"):
+     y = y * norm.weight.astype(mx.float32)
+ if hasattr(norm, "bias"):
+     y = y + norm.bias.astype(mx.float32)

Member

Thanks. The current code already does exactly this: _qwen3_torch_like_layer_norm reads getattr(norm, "weight", None) and getattr(norm, "bias", None) and only applies them when present, so the affine scale and bias are honored for the Qwen3-VL vision blocks. Worth noting for the record that mlx.nn.Module subclasses dict, so even "weight" in norm would have reported parameter presence correctly here.

Comment thread unsloth_zoo/mlx/trainer.py Outdated
Comment on lines +495 to +497
lr = optimizer.learning_rate.astype(flat_grad[name].dtype)
scale = mx.array(1.0, dtype=lr.dtype) - lr * mx.array(wd, dtype=lr.dtype)
decayed.append((name, parameter * scale))

Contributor

high

This manual weight decay implementation has two significant issues:

  1. Precision Underflow: Calculating the scale in the parameter's native dtype (e.g., float16 or bfloat16) will cause the weight decay to be ignored. For typical values like lr=2e-4 and wd=0.01, the term lr * wd (2e-6) is smaller than the machine epsilon for float16/bfloat16 relative to 1.0, so 1.0 - 2e-6 rounds back to 1.0.
  2. Unintended Parameter Promotion: If scale is calculated in float32 (to fix the precision issue), the operation parameter * scale will promote the model parameters to float32. Since these parameters (LoRA weights and norms) are explicitly excluded from the restoration logic in _restore_trainable_storage_dtypes, they will remain in float32, doubling their memory footprint for the rest of the training session.

The calculation should be done in float32 and explicitly cast back to the original dtype.

Suggested change
- lr = optimizer.learning_rate.astype(flat_grad[name].dtype)
- scale = mx.array(1.0, dtype=lr.dtype) - lr * mx.array(wd, dtype=lr.dtype)
- decayed.append((name, parameter * scale))
+ lr = optimizer.learning_rate.astype(mx.float32)
+ scale = mx.array(1.0, dtype=mx.float32) - lr * mx.array(wd, dtype=mx.float32)
+ decayed.append((name, (parameter * scale).astype(parameter.dtype)))

Member

Good catch on both points. The current _apply_manual_weight_decay computes the scale in fp32 and casts the result back to the parameter dtype: (parameter.astype(mx.float32) * scale).astype(parameter.dtype), so there is no float16/bfloat16 underflow and no silent promotion to fp32. This matches the suggested fix.
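The rounding issue discussed in this thread is easy to reproduce outside MLX. The sketch below uses NumPy float16 as a stand-in (NumPy has no bfloat16), with the lr and wd values from the review comment; `decay_in_fp32` is a hypothetical name that mirrors the compute-in-fp32, cast-back pattern described in the reply.

```python
import numpy as np

lr, wd = 2e-4, 0.01

# Scale computed in float16: lr * wd is about 2e-6, far below the float16
# spacing near 1.0, so the subtraction rounds back to exactly 1.0.
scale_fp16 = np.float16(1.0) - np.float16(lr) * np.float16(wd)

# Scale computed in float32 keeps the decay term.
scale_fp32 = np.float32(1.0) - np.float32(lr) * np.float32(wd)

def decay_in_fp32(param, lr, wd):
    """Apply decoupled decay in float32, then cast back to the storage dtype."""
    scale = np.float32(1.0) - np.float32(lr) * np.float32(wd)
    # The result is still rounded to the storage dtype on the cast back.
    return (param.astype(np.float32) * scale).astype(param.dtype)

param = np.array([1.0, 512.0], dtype=np.float16)
decayed = decay_in_fp32(param, lr, wd)  # dtype stays float16, no promotion
```

The cast back keeps the parameter in its original dtype, which addresses the memory concern in point 2 of the review comment.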

Comment thread unsloth_zoo/compiler.py Outdated
      logit_softcapping = None if (\\4) == () else (\\4),
  )
- elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and NOT_RETURN_LOGITS:
+ elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None:

Contributor

medium

Removing the NOT_RETURN_LOGITS check here, combined with the deletion of the specialized branch for returned logits, introduces a performance and memory regression. When UNSLOTH_RETURN_LOGITS=1 is active, logits are materialized early; forcing the code into the unsloth_fused_ce_loss branch causes a redundant matmul. Furthermore, ensure that logit scaling (multiply/divide) and softcapping transformations are applied to these materialized logits before they are passed to the loss function to maintain consistency with standard execution paths.

References
  1. When implementing code paths that return materialized logits (e.g., using UNSLOTH_RETURN_LOGITS=1), ensure that logit scaling and softcapping transformations are applied before the loss function.

Member

The NOT_RETURN_LOGITS guard is still in place, so fused CE only runs when we are not returning logits. There is now a dedicated UNSLOTH_RETURN_LOGITS=1 branch that reuses the already-materialized logits and applies the logit scale-multiply, scale-divide and softcap transforms before calling the loss function, so we avoid the duplicate matmul and keep transform parity.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: a93449f518

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment on lines +533 to +538
    "pad_token_id",
    "image_token_id",
    "video_token_id",
    "audio_token_id",
):
    _append_unique_int(ids, getattr(tokenizer, attr, None))

P1: Exclude pad token ID from global VLM ignore list

Adding pad_token_id to _get_vlm_ignore_token_ids and then masking via _mask_label_token_ids masks that token everywhere, not only in padded positions. For tokenizers where pad_token_id == eos_token_id (common in Llama-derived setups), valid in-sequence EOS targets are turned into -100, so EOS supervision is silently removed across training/eval batches. Padding should be masked only through attention_mask == 0, not by globally ignoring the pad ID.

Member

_get_vlm_ignore_token_ids only collects image/video/audio/boi/eoi token ids; it never adds pad_token_id to the global ignore list. Padding is masked separately via attention_mask == 0, so in-sequence EOS targets are preserved even when pad_token_id == eos_token_id.
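A minimal NumPy sketch of the masking rule described in this thread: special multimodal tokens are ignored by id, while padding is ignored by position through the attention mask. The function name and token ids are made up for illustration and are not the PR's helpers.

```python
import numpy as np

IGNORE_INDEX = -100

def mask_labels(input_ids, attention_mask, ignore_token_ids):
    """Build labels where special tokens and padded positions are ignored."""
    labels = input_ids.copy()
    # Mask image/video/audio placeholder tokens wherever they occur.
    labels[np.isin(labels, list(ignore_token_ids))] = IGNORE_INDEX
    # Mask padding by position, not by token id, so a real EOS survives.
    labels[attention_mask == 0] = IGNORE_INDEX
    return labels

eos = pad = 2          # pad_token_id == eos_token_id
image_token = 9        # hypothetical placeholder id
input_ids = np.array([[5, image_token, image_token, 7, eos, pad, pad]])
attention_mask = np.array([[1, 1, 1, 1, 1, 0, 0]])

labels = mask_labels(input_ids, attention_mask, {image_token})
```

The in-sequence EOS at position 4 keeps its label, while the two trailing pad positions are ignored even though they share the same id.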

Comment thread unsloth_zoo/compiler.py Outdated
      logit_softcapping = None if (\\4) == () else (\\4),
  )
- elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and NOT_RETURN_LOGITS:
+ elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None:

P2: Keep return-logits mode out of fused CE loss branch

This condition now routes all labeled ForCausalLMLoss calls through unsloth_fused_ce_loss, including UNSLOTH_RETURN_LOGITS=1. In that mode the compiler replacement already materializes logits up front, so this branch computes a second LM-head projection inside fused CE, increasing memory/compute and risking OOM on large models. The previous NOT_RETURN_LOGITS guard avoided that duplicate matmul.

Member

The duplicate matmul concern no longer applies. The fused CE branch is gated by NOT_RETURN_LOGITS, and the UNSLOTH_RETURN_LOGITS=1 path is handled separately by reusing the materialized logits rather than re-projecting inside fused CE.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: dcd0a9001c

Comment thread unsloth_zoo/mlx/compile.py Outdated
Comment on lines +2722 to +2725
if "weight" in norm:
    y = y * norm.weight.astype(mx.float32)
if "bias" in norm:
    y = y + norm.bias.astype(mx.float32)

P1: Apply LayerNorm affine params in Qwen3 vision block patch

The new _qwen3_torch_like_layer_norm gate checks use membership ("weight" in norm, "bias" in norm) instead of attribute existence. For module-style LayerNorm objects, this test does not reliably indicate parameter presence, so the affine scale/bias can be skipped (or error), making every patched Qwen3-VL vision block run a non-affine normalization that diverges from pretrained/Transformers behavior and can degrade training and loss parity.

Member

The current helper uses getattr(norm, "weight", None) and getattr(norm, "bias", None), so the affine parameters are applied. For reference, mlx.nn.Module subclasses dict, so the membership form also reflects parameter presence exactly: a LayerNorm with affine has both weight and bias, and one without affine has neither. Either way, this path runs the correct affine LayerNorm.

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment on lines +2518 to +2521
if dataset_order == "torch_randperm":
    indices = _torch_randperm_order(len(dataset), int(seed) + epoch)
elif dataset_order in (None, "default"):
    np.random.seed(int(seed) + epoch)

P2: Normalize nullable seeds before VLM epoch reshuffle

When create_vlm_batches rolls to a new epoch, it unconditionally computes int(seed) + epoch for both torch_randperm and default reshuffling. If callers pass seed=None (which is already supported in _torch_randperm_order and by NumPy seeding), multi-epoch materialization will raise TypeError on the first rollover and stop training/eval batch creation.

Member

create_vlm_batches now normalizes the seed up front with base_seed = _normalize_seed(seed) (None defaults to 3407) before doing any base_seed + epoch arithmetic, so a seed=None caller no longer raises on epoch rollover.
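The seed handling described here can be sketched as follows. This is an illustration only: `normalize_seed` and `epoch_order` are hypothetical names, and the permutation comes from NumPy's generator rather than torch.randperm, so it does not reproduce the actual torch ordering. Only the 3407 default and the `base_seed + epoch` arithmetic are taken from the reply.

```python
import numpy as np

DEFAULT_SEED = 3407  # default named in the reply above

def normalize_seed(seed):
    """Map a nullable seed to a concrete int before any arithmetic."""
    return DEFAULT_SEED if seed is None else int(seed)

def epoch_order(num_items, seed, epoch):
    """Deterministic per-epoch permutation derived from base_seed + epoch."""
    base_seed = normalize_seed(seed)
    rng = np.random.default_rng(base_seed + epoch)
    return rng.permutation(num_items).tolist()

# seed=None no longer raises on rollover and behaves like the default seed.
order_none = epoch_order(8, None, epoch=1)
order_default = epoch_order(8, 3407, epoch=1)
```

Normalizing once up front means every later `+ epoch` operates on an int, which is the property the fix relies on.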

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e16efc0f9c

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment on lines +2600 to +2601
if dataset_order == "torch_randperm":
    indices = _torch_randperm_order(len(dataset), int(seed) + epoch)

P2: Normalize nullable seed in streaming torch_randperm path

When dataset_order="torch_randperm", the streaming iterator now does int(seed) + epoch before calling _torch_randperm_order. If callers pass seed=None (which _torch_randperm_order explicitly supports by defaulting to 3407), this raises TypeError on the first epoch and stops batch generation. Please normalize seed the same way as _torch_randperm_order (or pass it through directly) before adding epoch.

Member

Same fix as the VLM path: the streaming and ordered text paths normalize the seed with _normalize_seed before adding epoch, consistent with _torch_randperm_order's own None default, so seed=None is handled.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b0a83b52cf

Comment thread unsloth_zoo/mlx/trainer.py Outdated
    continue
if not self._should_apply_weight_decay(name, parameter):
    continue
lr = optimizer.learning_rate.astype(flat_grad[name].dtype)

P1: Handle scalar learning rates in manual AdamW decay

_apply_manual_adamw_weight_decay assumes optimizer.learning_rate has an .astype(...) method, but in valid configurations it can be a Python float (notably the constant scheduler path without warmup, and the MLX simulation optimizer stubs). In that case this line raises AttributeError before any optimizer step, so AdamW training fails as soon as manual decay runs. Please normalize LR to an mx.array (or branch on scalar) before dtype-casting.

Member

_apply_manual_weight_decay now branches on whether the learning rate exposes .astype: scalar floats are wrapped via mx.array(lr_value, dtype=mx.float32), so the constant-schedule and simulation-optimizer paths no longer hit an AttributeError.

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment on lines +2792 to +2793
if not tokenized:
    return []

P2: Raise on empty ordered token stream instead of returning []

When create_ordered_batches drops all rows via if len(ids) >= 2 (for example very small max_seq_length or single-token rows), it returns an empty batch list. The new ordered-data path in MLXTrainer then indexes batches with batch_idx % len(batches), which crashes at runtime with division-by-zero instead of surfacing a data error. This should raise a clear ValueError here, matching the other dataset-prep guards.

Member

create_ordered_batches now raises a clear ValueError when no trainable sequences survive (need at least two tokens after formatting and truncation), so we surface a data error instead of producing an empty list that later divides by zero.

Comment thread unsloth_zoo/mlx/compile.py Outdated
centered = x_f - mean
var = mx.mean(centered * centered, axis=-1, keepdims=True)
y = centered * mx.rsqrt(var + norm.eps)
if "weight" in norm:

Collaborator

If this is specific to Qwen3, do we need the if check?
Or do some norms in the Qwen3 family have weight/bias while others don't?

Member

Yes, the check is needed, and yes, even within the qwen3 family the norms differ. In mlx-vlm qwen3_vl the vision blocks use nn.LayerNorm (weight and bias both present) while the language stack uses nn.RMSNorm (weight only, no bias). Because mlx.nn.Module subclasses dict, "weight" in norm / "bias" in norm report exactly which parameters exist, so the conditional lets the same helper apply bias only where it exists and skip it for RMSNorm. The current code uses getattr(norm, "weight", None) which is equivalent. Keeping the check also keeps the helper correct if an affine=False norm ever appears.
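For readers following the thread, here is a NumPy sketch of an fp32 layer norm with optional affine parameters, using the getattr pattern from the reply. It is an assumption-laden stand-in, not the PR's `_qwen3_torch_like_layer_norm`: the function name is hypothetical, the norm objects are plain namespaces, and the cast back to the input dtype is an assumption.

```python
import numpy as np
from types import SimpleNamespace

def torch_like_layer_norm(x, norm):
    """Layer norm computed in float32, applying weight/bias only when present."""
    x_f = x.astype(np.float32)
    mean = x_f.mean(axis=-1, keepdims=True)
    centered = x_f - mean
    var = (centered * centered).mean(axis=-1, keepdims=True)
    y = centered / np.sqrt(var + norm.eps)
    weight = getattr(norm, "weight", None)
    bias = getattr(norm, "bias", None)
    if weight is not None:
        y = y * weight.astype(np.float32)
    if bias is not None:
        y = y + bias.astype(np.float32)
    return y.astype(x.dtype)  # assumption: return in the input dtype

x = np.array([[1.0, 2.0, 3.0]], dtype=np.float16)
affine = SimpleNamespace(
    eps=1e-6,
    weight=np.full(3, 2.0, dtype=np.float16),
    bias=np.ones(3, dtype=np.float16),
)
plain = SimpleNamespace(eps=1e-6)  # stands in for an affine=False norm

out_affine = torch_like_layer_norm(x, affine)
out_plain = torch_like_layer_norm(x, plain)
```

The same helper handles both objects: the affine one is scaled and shifted, and the plain one is only normalized.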

Comment thread unsloth_zoo/mlx/loader.py Outdated
flat = mx.concatenate([flat, mx.zeros((pad,), dtype=mx.float32)])
groups = flat.reshape((-1, group_size))
absmax = mx.max(mx.abs(groups), axis=1, keepdims=True)
denom = mx.maximum(absmax, mx.array(1e-12, dtype=mx.float32))

Collaborator

I'm guessing we do this to avoid division by zero, but dividing by 1e-12 might cause the numbers to blow up?
For a similar case, what I did was set the scale to 1.

Member

Agreed, and the current code does exactly what you suggested. The NF4 dequant uses denom = mx.where(absmax > 0, absmax, mx.ones_like(absmax)) (and the same for the nested-scale denom), so zero scales divide by 1 instead of a tiny epsilon. There is no 1e-12 left anywhere in the MLX path, so there is no blow-up risk.
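The safe-denominator pattern quoted in the reply can be illustrated with NumPy. This sketch mirrors the grouping and absmax steps from the snippet under review and swaps `mx.where` for `np.where`; the function name `normalize_groups` is hypothetical.

```python
import numpy as np

def normalize_groups(flat, group_size):
    """Scale each group by its absmax, dividing all-zero groups by 1 instead."""
    pad = (-flat.size) % group_size
    if pad:
        flat = np.concatenate([flat, np.zeros((pad,), dtype=np.float32)])
    groups = flat.reshape((-1, group_size))
    absmax = np.max(np.abs(groups), axis=1, keepdims=True)
    # Zero scales divide by 1 rather than by a tiny epsilon.
    denom = np.where(absmax > 0, absmax, np.ones_like(absmax))
    return groups / denom, absmax

flat = np.array([0.5, -2.0, 0.0, 0.0, 4.0], dtype=np.float32)
normalized, absmax = normalize_groups(flat, group_size=2)
```

The all-zero middle group stays exactly zero and finite, and every other group lands in [-1, 1].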

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 964be34ec2

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment on lines +2830 to +2831
if num_batches is None and seen >= len(tokenized):
    break

P2: Keep torch_randperm active across epoch-based text training

When dataset_order="torch_randperm" is used with epoch-based training (max_steps == 0), this early break materializes only one epoch of randomized order and returns. MLXTrainer then reuses that fixed batch list with batch_idx % len(batches) for subsequent epochs, so later epochs never get a fresh randperm order. This diverges from the expected per-epoch reshuffle behavior of a torch-style random sampler and can bias training/eval results for multi-epoch runs.

Member

The text ordered path now reseeds per epoch: make_order(epoch) calls _torch_randperm_order(len, base_seed + epoch) and is re-invoked at each epoch boundary, with the loop bounded by num_epochs * len(dataset), so multi-epoch runs get a fresh permutation each epoch.

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment on lines +2523 to +2524
if num_batches is None:
    break

P2: Rebuild VLM torch_randperm order for each epoch run

In pre-materialized VLM batching, this branch exits as soon as one dataset pass completes whenever num_batches is None (the normal epoch-based path). With dataset_order="torch_randperm", that means only the first epoch gets randomized; later epochs in MLXTrainer reuse the same stored batches instead of generating a new permutation. This breaks sampler parity for multi-epoch VLM training and can skew convergence behavior.

Member

create_vlm_batches rebuilds the order on each epoch rollover with _torch_randperm_order(len, base_seed + epoch), so later epochs no longer reuse the first epoch permutation.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ca08652226

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment on lines +2825 to +2829
if order_pos >= len(order):
    epoch += 1
    order = make_order(epoch)
    order_pos = 0
batch_items.append(tokenized[order[order_pos]])

P1: Stop batches from spanning torch_randperm epoch boundaries

create_ordered_batches wraps to a new permutation inside the same batch (order_pos >= len(order)), so when num_epochs is used and len(dataset) % batch_size != 0, one batch can contain samples from two different epochs. This changes both sample order and batch count (e.g., 5 samples, batch size 4, 2 epochs yields 3 batches instead of the expected 4), so MLXTrainer under-trains in epoch mode and no longer mirrors CUDA/DataLoader epoch semantics for dataset_order="torch_randperm".

Member

create_ordered_batches now breaks the batch at the epoch boundary: when order_pos >= len(order) it emits the partial batch and restarts at epoch+1 rather than wrapping inside one batch, so no batch spans two epochs and the batch count matches CUDA DataLoader epoch semantics.
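The batch-count arithmetic from the review comment (5 samples, batch size 4, 2 epochs) can be checked with a short sketch. Both functions are hypothetical simplifications that operate on index lists; the real create_ordered_batches works on tokenized rows.

```python
def batches_with_epoch_boundaries(num_items, batch_size, num_epochs, make_order):
    """Emit a partial batch at each epoch end, so no batch spans two epochs."""
    batches = []
    for epoch in range(num_epochs):
        order = make_order(epoch)
        for start in range(0, num_items, batch_size):
            batches.append(order[start:start + batch_size])
    return batches

def batches_wrapping(num_items, batch_size, num_epochs, make_order):
    """Old behavior for contrast: concatenate epochs, then slice into batches."""
    stream = [i for epoch in range(num_epochs) for i in make_order(epoch)]
    return [stream[s:s + batch_size] for s in range(0, len(stream), batch_size)]

def fixed_order(epoch):
    # Placeholder order; a real run would use a seeded per-epoch permutation.
    return list(range(5))

bounded = batches_with_epoch_boundaries(5, 4, 2, fixed_order)
wrapped = batches_wrapping(5, 4, 2, fixed_order)
```

Here `bounded` has 4 batches of sizes 4, 1, 4, 1, while `wrapped` has 3 batches of sizes 4, 4, 2, and its second batch mixes samples from both epochs.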

@mmathew23

Collaborator Author

OK, pretty sure all the review feedback is resolved.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 7e0bee546f

Comment thread unsloth_zoo/mlx/trainer.py Outdated
Comment on lines +627 to +628
if any("norm" in part for part in parts):
    return False

P2: Exclude ln_* normalization weights from AdamW decay

The new HF-style decay filter only checks for path segments containing "norm", so normalization parameters named like ln_1.weight or ln_f.weight still get decayed. This repo already treats ln_* as norm parameters in _ensure_lora_frozen (see _NORM_FRAGMENTS), so models using those names will now receive unintended weight decay on norm scales, diverging from the stated “skip bias and norms” behavior and changing optimization dynamics.

Member

_should_apply_weight_decay uses _part_is_norm, which matches "norm", ln_*, and ln_f, so ln_1.weight / ln_f.weight style norm scales are excluded from decay as intended.
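A rough sketch of a name-based decay filter along the lines described in this thread. The function names are hypothetical, and the exact matching rules (case folding, treating any `ln_` prefix as a norm, exempting a trailing `bias` segment) are assumptions rather than the PR's `_part_is_norm` logic.

```python
def part_is_norm(part):
    """Treat path segments containing 'norm' or starting with 'ln_' as norm layers."""
    part = part.lower()
    return "norm" in part or part.startswith("ln_")

def should_apply_weight_decay(name):
    """HF-style filter: skip biases and normalization parameters."""
    parts = name.split(".")
    if parts[-1] == "bias":
        return False
    return not any(part_is_norm(p) for p in parts)

decayed = [
    n for n in (
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.self_attn.q_proj.bias",
        "model.layers.0.input_layernorm.weight",
        "transformer.h.0.ln_1.weight",
        "transformer.ln_f.weight",
    )
    if should_apply_weight_decay(n)
]
```

Only the projection weight remains in `decayed`; the bias and all three norm-style names are exempt.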

Comment thread unsloth_zoo/mlx/loader.py Outdated
Comment on lines +151 to +152
parts = str(path).lower().split(".")
return any("norm" in part for part in parts[:-1])

P2: Include ln_* params in fp32 norm-parameter preservation

_keep_norm_parameters_float32 claims to keep normalization parameters in fp32, but _is_norm_parameter_path only matches components containing "norm". Any normalization layer named with ln_* (which this codebase already recognizes as norm-like elsewhere) is skipped and left in lower precision, undermining the stabilization this pass is meant to provide for FT/LoRA/QLoRA training.

Member

The loader _is_norm_parameter_path matches "norm", ln_*, and ln_f as well, so GPT-2 / GPT-OSS style ln_* normalization parameters are kept in fp32 by _keep_norm_parameters_float32.

@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

Merges 5 main-side mlx fixes (unslothai#673 zero-token CCE, unslothai#679 + unslothai#692 LoRA save
metadata, unslothai#682 invalid label NaN-poisoning, unslothai#688 tool mask). All 13
conflict regions in unsloth_zoo/mlx/utils.py resolved to keep PR unslothai#684's
behavior where it conflicts on semantics:

  - half-open `<` length mask (PR unslothai#684 fix) wins over main's inclusive `<=`
  - `if labels is None` branch preserved (PR unslothai#684 generality) alongside
    main's `_normalize_cce_label_dtype` dtype widening
  - `_get_image_token_ids` legacy wrapper kept alongside main's new
    `_normalize_cce_label_dtype` / `_normalize_numpy_cce_labels`
  - `_mask_label_token_ids` calls `_normalize_cce_label_dtype` first so
    image masking honors main's uint-widening contract
  - HEAD's `_expand_token_replacements` dropped; main's three-function
    split (`_normalize_numpy_cce_labels` + `_expand_image_token_sequences`
    + `_expand_token_runs`) is canonical; duplicate HEAD wrappers removed
  - `_collate_vlm_prompt_completion_batch` reads back the masked labels
    in int64 so image + attention masking survives without narrowing
  - prompt-completion VLM collator routes through `_apply_vlm_label_masks`
    after dtype normalisation so ignore_token_ids and wide invalid ids
    both reach runtime CCE intact
  - `_to_mx_vlm_batch` uses main's `_normalize_cce_label_dtype` for labels
    while keeping PR unslothai#684's token_type_ids / mm_token_type_ids handling
  - `_unsloth_*` prefix filter preserved so the new collated_position_ids
    flag and main's raw-input-ids carrier both get stripped

152 MLX tests pass post-merge.

@danielhanchen

Member

Status pass on this branch relative to today's merges and the open #738, since they share surfaces:

Recommended order: land #738 first (small, validated), then rebase this onto main and resolve the trainer/utils conflicts preferring the merged fixes.

@danielhanchen

Member

Synced this branch with main via maintainer edit (merge commit e4f4d9f) so it is reviewable against the current tree. Resolution notes:

Verification: AST parse clean on all resolved files, workflow YAML parses, and the CPU suites pass: 16/16 behavioral gates and 38/38 across test_mlx_module_exports, test_mlx_cce_kernel, test_mlx_dequantize_modules, test_mlx_gated_delta.

@mmathew23 please double-check the trainer.py resolution matches your intent, especially the grad-clip comment block and the deduplicated setup section.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e4f4d9f444

Comment thread unsloth_zoo/mlx/trainer.py Outdated
print("Unsloth: Using standard cross-entropy loss.")

# Prepare data — determine total_steps first
self._prepared_batches_include_epochs = False

P1 Badge Preserve the prebuilt epoch flag before step counting

In completion-only text training (train_on_responses_only) with num_train_epochs > 1 and max_steps <= 0, _create_labeled_batches already materializes all epochs and sets trainer._prepared_batches_include_epochs = True, but this reset clears that flag before _prepare_data() returns the prebuilt self._batches. The total-step calculation then treats len(batches) as one epoch and multiplies by num_train_epochs again, so a 3-epoch run executes 9 epochs' worth of micro-batches and over-trains/reuses data unexpectedly.

Member

Confirmed and fixed. _train_inner was zeroing _prepared_batches_include_epochs before _prepare_data returned the prebuilt batches from train_on_responses_only, and since the prebuilt path returns early it never re-derived the flag, so the step counter multiplied the already-materialized epochs again. The reset is now gated on self._batches is None, which preserves the flag for the completion-only path while normal runs still reset and re-derive it. Added a regression test pinning 18 steps for a 3-epoch run instead of 54. Fixed in c8ad7dd.
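
The double counting can be illustrated with a small sketch. The function name `total_micro_batches` is hypothetical, and the figures assume 6 batches per epoch with no gradient accumulation, which is one configuration consistent with the 18 versus 54 steps quoted above.

```python
def total_micro_batches(num_batches: int, num_train_epochs: int,
                        batches_include_epochs: bool) -> int:
    """Count micro-batches for an epoch-based run (max_steps disabled)."""
    if batches_include_epochs:
        # The batch list already holds every epoch; do not multiply again.
        return num_batches
    return num_batches * num_train_epochs


# 3 epochs x 6 batches, already materialized by the completion-only path.
prebuilt = 18
correct = total_micro_batches(prebuilt, 3, batches_include_epochs=True)
clobbered = total_micro_batches(prebuilt, 3, batches_include_epochs=False)
```

With the flag preserved the run executes 18 micro-batches; with the flag reset it executes 54, which is the over-training described in the review comment.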

Comment thread unsloth_zoo/mlx/trainer.py Outdated
if dataset_order == "torch_randperm":
    from .utils import _torch_randperm_order
    # Reseed per epoch (matches `create_ordered_batches`).
    order = _torch_randperm_order(len(items), seed + epoch_idx)

P2 Badge Normalize nullable seed before labeled randperm ordering

When train_on_responses_only is used with dataset_order="torch_randperm" and a caller leaves seed=None, this addition raises TypeError before _torch_randperm_order can apply its own None default. The unlabeled ordered-batch path already normalizes nullable seeds, so completion-only text training is the path that still fails for the same valid seed configuration.

Member

Good catch, fixed. The labeled torch_randperm path did seed + epoch_idx directly, which raised TypeError when seed was None, whereas the unlabeled path already runs _normalize_seed first. The seed is now normalized before the per-epoch add so seed=None falls back to the default and per-epoch reseeding still works. Added tests that build labeled batches with seed=None and assert each epoch is a distinct permutation. Fixed in c8ad7dd.
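
A minimal sketch of normalizing the seed before the per-epoch add follows. It uses Python's `random` module rather than `torch.randperm`, so it does not reproduce the actual batch order; `normalize_seed`, `epoch_order` and the `DEFAULT_SEED` value are hypothetical names and values chosen for illustration.

```python
import random

DEFAULT_SEED = 3407  # assumed fallback, not taken from the PR


def normalize_seed(seed):
    """Map a nullable seed to a concrete integer."""
    return DEFAULT_SEED if seed is None else int(seed)


def epoch_order(num_items: int, seed, epoch_idx: int):
    """Return a deterministic permutation for one epoch."""
    # Normalizing first means `seed + epoch_idx` never sees None.
    rng = random.Random(normalize_seed(seed) + epoch_idx)
    order = list(range(num_items))
    rng.shuffle(order)
    return order
```

Calling `epoch_order(8, None, 0)` no longer raises `TypeError`, and repeating the call with the same arguments returns the same permutation.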

Comment thread unsloth_zoo/mlx/utils.py
Comment on lines +229 to +230
sliding_window_pattern = getattr(stack, "sliding_window_pattern", None)
window_size = getattr(stack, "window_size", None)

P2 Badge Read Gemma sliding-window settings from config

For Gemma-style VLM language stacks that pass token_type_ids, the sliding-window settings live on stack.config (sliding_window_pattern / sliding_window) rather than as direct stack attributes. These getattr calls therefore resolve to None, so every layer is treated as global and uses the full token-type mask instead of the local sliding-window mask, changing Gemma3 attention semantics and training parity whenever this custom hidden-stack path is used.

Member

Good catch on the robustness angle. On real mlx-vlm 0.6.2 both Gemma3Model and Gemma4TextModel copy window_size and sliding_window_pattern onto the module, so today the masks are correct. A config fallback was added anyway (reading sliding_window and sliding_window_pattern off stack.config) so a stack that only stores them there still builds windowed masks rather than treating every layer as global. Fixed in c8ad7dd.
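
The config fallback described here might look like the sketch below. `resolve_sliding_window` is a hypothetical helper, and `SimpleNamespace` objects stand in for the real mlx-vlm modules; only the attribute names come from the thread.

```python
from types import SimpleNamespace


def resolve_sliding_window(stack):
    """Prefer attributes on the module, then fall back to stack.config."""
    config = getattr(stack, "config", None)
    pattern = getattr(stack, "sliding_window_pattern", None)
    if pattern is None:
        pattern = getattr(config, "sliding_window_pattern", None)
    window = getattr(stack, "window_size", None)
    if window is None:
        # The config stores the window under a different name.
        window = getattr(config, "sliding_window", None)
    return pattern, window


module_style = SimpleNamespace(sliding_window_pattern=6, window_size=1024)
config_only = SimpleNamespace(
    config=SimpleNamespace(sliding_window_pattern=6, sliding_window=512)
)
```

Both stand-in stacks resolve to a usable `(pattern, window)` pair, so a stack that only stores the settings on its config still gets windowed masks.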

@danielhanchen

Member

Validated this branch end to end on real Apple Silicon via the staging fork (danielhanchen/unsloth-zoo-staging, branch staging/mlx-pr-684, macos-14). Summary:

Green across the board: https://github.com/danielhanchen/unsloth-zoo-staging/actions/runs/27351187334

Two commits pushed to this branch on top of the main sync:

  1. 732e7f2 adds the Metal test files (tests/test_qwen35_vjp_metal.py, tests/test_mlx_training_e2e_metal.py) and wires them into mlx-ci.yml as separate steps. They skip loudly off-Metal.
  2. fd8e742 fixes a real gap the e2e test caught: training completed without saving adapters even though the config documents save_steps=0 as "only save at end". _train_inner now calls save_model() at the end with the same ValueError guard as periodic checkpointing. Main has the same gap, so this is an inherited fix, not a regression repair.

Staging PR with the full diff and CI history: danielhanchen#28

From my side this branch is now in good shape: synced with main, the qwen3_5 patch wiring preserved, training proven on Metal, and the full suite green.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: fd8e742b3f

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment on lines +2473 to +2477
if image.size[0] > image_size:
    width, height = image.size
    new_width = (width * image_size + width // 2) // width
    new_height = (height * image_size + width // 2) // width
    image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

P2 Badge Downscale tall VLM images before processor

When image_size is an int, this only checks/scales by width, so portrait images such as 512×2048 with a 512 limit are left unchanged, and taller images like 1024×4096 are resized to 512×2048. Those oversized images still reach the VLM processor and can generate far more patches than intended, causing avoidable OOMs or batch shape drift; the resize should account for the limiting dimension rather than only image.size[0].

Member

Confirmed. The check only looked at width, so a tall portrait image like 512x2048 under a 512 cap slipped through unscaled. The resize now scales on the larger of width and height, so both portrait and landscape images that exceed the cap shrink; landscape sizing is unchanged. Fixed in c8ad7dd.
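
The size arithmetic can be sketched without PIL. `capped_size` is a hypothetical helper that keeps the rounding style of the quoted snippet but divides by the larger side; it is an illustration of scaling on the limiting dimension, not the merged code.

```python
def capped_size(width: int, height: int, cap: int):
    """Scale (width, height) so the larger side is at most `cap`."""
    longest = max(width, height)
    if longest <= cap:
        return width, height
    # Round to nearest, mirroring the `+ side // 2` term in the original.
    new_width = (width * cap + longest // 2) // longest
    new_height = (height * cap + longest // 2) // longest
    return max(1, new_width), max(1, new_height)
```

For example, a 512x2048 portrait image under a 512 cap becomes 128x512, while a 2048x512 landscape image becomes 512x128, the same result as a width-only rule would give.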

Comment thread unsloth_zoo/mlx/utils.py Outdated
if labels is None:
    labels = batch_dict["input_ids"].astype(mx.int32)
else:
    labels = labels.astype(mx.int32)

P2 Badge Preserve wide VLM labels until CCE validation

When the collator supplies labels, this cast narrows them to int32 before _vlm_cce_forward sends them to runtime CCE. For the same unsigned/wide invalid IDs this file otherwise preserves via _RAW_INPUT_IDS_FOR_LABELS (e.g. values that wrap to -100 on int32), the invalid target is silently converted into an ignore label instead of reaching CCE's range check, so bad processor outputs can be hidden and train/eval loss becomes incorrect.

Member

Right, the int32 cast here was undoing the wide-id preservation the rest of the file sets up via _RAW_INPUT_IDS_FOR_LABELS and _normalize_cce_label_dtype. A uint32 id near 2**32 wrapped on the narrow and silently turned into an ignore position instead of reaching CCE as an out-of-vocab sentinel. Labels are now widened instead of narrowed and derived from the raw pre-narrow carrier, while padding and ignore tokens are still masked. Fixed in c8ad7dd.
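
The wraparound is easy to reproduce in plain Python. The sketch below uses `struct` to mimic a uint32 to int32 reinterpretation; `narrow_to_int32` and `is_valid_target` are hypothetical helpers, and the vocabulary size is an arbitrary example value.

```python
import struct

IGNORE_INDEX = -100


def narrow_to_int32(value: int) -> int:
    """Reinterpret an unsigned 32-bit value as a signed 32-bit value."""
    return struct.unpack("<i", struct.pack("<I", value))[0]


def is_valid_target(label: int, vocab_size: int) -> bool:
    """A target is acceptable if it is the ignore index or in range."""
    return label == IGNORE_INDEX or 0 <= label < vocab_size


bad_id = 2**32 - 100                # an invalid id emitted as uint32
wrapped = narrow_to_int32(bad_id)   # -100 after the narrowing cast
```

After narrowing, the invalid id is indistinguishable from the ignore index and passes the range check; kept wide, the same id fails the check and the bad processor output is surfaced.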

Comment thread unsloth_zoo/mlx/utils.py
video_token_id=int(_config_get(config, "video_token_id", _config_get(config, "video_token_index"))),
spatial_merge_size=int(vision_config.get("spatial_merge_size", 2)),
)
batch_dict["_unsloth_collated_position_ids"] = True

P2 Badge Strip collator-only position flag before VLM forward

For Qwen/GLM batches this private marker is added to the batch, but the non-CCE VLM loss still forwards every extra key except input_ids/pixel_values/attention_mask/labels/the raw-label carrier into model(...). With use_cce=False, _unsloth_collated_position_ids=True is therefore passed as an unexpected model kwarg for these architectures, while the CCE path already filters private _unsloth_ keys; the baseline path needs the same filtering.

Member

Confirmed. The CCE path strips every _unsloth_* carrier, but the baseline (use_cce=False) path only dropped a fixed key list, so the _unsloth_collated_position_ids marker leaked into model(...) and would TypeError on Qwen/GLM forwards that reject unknown kwargs. The baseline filter now uses the same _unsloth_ prefix strip so the two paths stay in parity; position_ids still passes through. Fixed in c8ad7dd.
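
A prefix-based filter of this kind could be sketched as follows. `model_kwargs_from_batch` and the `CONSUMED_KEYS` set are hypothetical; only the `_unsloth_` prefix and the `_unsloth_collated_position_ids` key come from the thread.

```python
PRIVATE_PREFIX = "_unsloth_"
# Keys the loss function passes to the model explicitly (assumed list).
CONSUMED_KEYS = {"input_ids", "pixel_values", "attention_mask", "labels"}


def model_kwargs_from_batch(batch: dict) -> dict:
    """Collect extra model kwargs, dropping private collator carriers."""
    return {
        key: value
        for key, value in batch.items()
        if key not in CONSUMED_KEYS and not key.startswith(PRIVATE_PREFIX)
    }


batch = {
    "input_ids": [1, 2, 3],
    "attention_mask": [1, 1, 1],
    "position_ids": [0, 1, 2],
    "_unsloth_collated_position_ids": True,
}
extra = model_kwargs_from_batch(batch)
```

Here `extra` contains only `position_ids`, so the private marker never reaches `model(...)` while the real position ids still pass through.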

Comment thread unsloth_zoo/mlx/loader.py
Comment on lines +863 to +864
if "weight" in self:
y = y * (1.0 + self.weight.astype(mx.float32))

P1 Badge Apply Gemma3 norm affine parameters by attribute

In the Gemma3 fp32 norm patches, testing membership with "weight" in self/"bias" in norm is not a reliable way to detect MLX module parameters; for module-style norms this can skip the affine scale/bias (or error) even though the attributes exist. Since these patches are installed on Gemma3 load, affected runs either diverge from the pretrained LayerNorm/RMSNorm math or fail during the first forward; use attribute checks like getattr(norm, "weight", None) here and in the analogous vision LayerNorm helpers.

Member

self and norm here are mlx nn.Module instances (Gemma3 RMSNorm and SigLIP LayerNorm from mlx-vlm), and mlx.nn.Module subclasses dict, so "weight" in self / "bias" in norm are membership tests over the registered parameters and report presence exactly. Verified on CPU: LayerNorm with affine reports both weight and bias, RMSNorm reports weight and not bias, affine=False reports neither. So the affine scale and bias are applied correctly and there is no divergence or first-forward failure.

Comment on lines 690 to +693
elif opt_name == "sgd":
optimizer = optim.SGD(learning_rate=initial_lr, weight_decay=wd)
# HF parity: manual bias/norm-aware decoupled decay.
self._manual_weight_decay = float(wd or 0.0)
optimizer = optim.SGD(learning_rate=initial_lr, weight_decay=0.0)

P2 Badge Keep SGD weight decay coupled to the gradient

When optim="sgd" and weight_decay>0, this now disables the optimizer's built-in decay and applies the AdamW-style decoupled parameter shrink before optimizer.update. SGD weight decay in the HF/PyTorch optimizer is coupled into the gradient (and therefore into momentum/Nesterov state), so this changes training dynamics for SGD runs instead of providing the stated HF parity; keep built-in SGD decay or branch manual decay by optimizer semantics.

Member

You are right, fixed. SGD was getting AdamW-style decoupled decay (param shrink after the step), but PyTorch SGD and MLX SGD both couple weight decay into the gradient before momentum, so this changed SGD dynamics rather than matching HF. There is now a coupled-decay path that folds wd * param into the post-clip gradient for non-bias/non-norm leaves and SGD routes to it, keeping MLX SGD built-in decay off and preserving the HF bias/norm exemption. A new test steps MLX SGD with momentum and matches torch.optim.SGD(weight_decay=wd) numerically, and another pins that the old decoupled shrink diverges. Fixed in c8ad7dd.
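
The difference between the two decay styles shows up as soon as momentum carries state across steps. The scalar sketch below is an illustration only: both function names are hypothetical, dampening is taken as zero, the gradient is held constant, and the decoupled variant is assumed to shrink the parameter before the gradient step.

```python
def sgd_coupled(p, grad, lr, wd, momentum, steps):
    """SGD with weight decay folded into the gradient before momentum."""
    buf = None
    for _ in range(steps):
        g = grad + wd * p
        buf = g if buf is None else momentum * buf + g
        p = p - lr * buf
    return p


def sgd_decoupled(p, grad, lr, wd, momentum, steps):
    """SGD with an AdamW-style parameter shrink kept out of momentum."""
    buf = None
    for _ in range(steps):
        p = p * (1.0 - lr * wd)
        buf = grad if buf is None else momentum * buf + grad
        p = p - lr * buf
    return p
```

Starting from p = 1.0 with grad = 0.5, lr = 0.1, wd = 0.1 and momentum = 0.9, both variants reach 0.94 after one step, but after two steps the coupled update gives about 0.8266 and the decoupled one about 0.8356, because only the coupled form feeds the decay term into the momentum buffer.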

Comment thread unsloth_zoo/mlx/loader.py Outdated
denom = mx.where(absmax > 0, absmax, mx.ones_like(absmax))
scaled = groups / denom
indices = mx.argmin(mx.abs(scaled[..., None] - codebook), axis=-1)
absmax = _bnb_nested_absmax(absmax.reshape((-1,))).reshape((-1, 1))

P2 Badge Keep NF4 scales un-nested unless requested

For q_mode="nf4_dense", every group absmax is passed through _bnb_nested_absmax, which simulates nested/double quantization of the scale values. The allowed BitsAndBytes config path explicitly rejects bnb_4bit_use_double_quant=True, so the default/accepted NF4 case should use the raw per-group absmax; with varying group magnitudes this extra scale quantization changes the dense dequantized weights and breaks the intended CUDA NF4 parity.

Member

Agreed. The accepted BitsAndBytesConfig path rejects bnb_4bit_use_double_quant=True, and CUDA bitsandbytes only applies the nested absmax when a nested quant state is present, so default NF4 should keep plain scales. The dense dequant was always nesting; it is now gated behind a use_double_quant flag that defaults off, so default NF4 matches the single-quant reference while the nested simulation stays reachable explicitly. Fixed in c8ad7dd.

@danielhanchen

Member

Correctness verification round complete. All 28 review threads now have replies with verdicts and evidence; summary:

  • 14 of the earlier findings were already fixed on this branch (verified against the current head, each reply cites the code).
  • 3 were invalid, all variants of the same claim that "weight" in norm membership tests are unreliable on MLX modules. mlx.nn.Module subclasses dict, so membership reports registered parameter presence exactly (verified empirically: LayerNorm affine yields weight and bias, affine=False yields neither, RMSNorm has weight only).
  • 9 were valid against the current head and are fixed in c8ad7dd, each with a regression test that fails when the fix is reverted: the prebuilt-epoch flag clobber (a 3-epoch completion-only run executed 9 epochs of steps), the labeled-path seed=None TypeError, step-based runs materializing all epoch blocks, SGD weight decay now coupled into the gradient like torch and MLX SGD (with the HF bias/norm exemption kept; numerically matched against torch.optim.SGD with momentum), the fast-tokenizer Rust-backend unwrap hazard in two places, portrait image downscaling, wide label preservation through VLM masking, the collator marker leaking into non-CCE VLM forwards, and NF4 plain absmax for the accepted non-double-quant config.

Validation on Apple Silicon (staging fork, macos-14): https://github.com/danielhanchen/unsloth-zoo-staging/actions/runs/27354619482 is fully green with the fixes: qwen3_5 VJP regressions, the real LoRA training e2e (loss improves, adapters save at end), and the full MLX suite breadth pass, one pytest process per file. CPU-side, the combined suites pass 148 tests plus the full per-file sweep.

One review note for whoever lands this: the default grad clip changed from main's elementwise 1.0 to a per-leaf L2 norm cap of 1.0. The resolution priority is coherent (explicit max_grad_value wins, then max_grad_leaf_norm, then max_grad_norm for HF-style global norm, explicit 0 disables) and it is direction-preserving and memory-cheap, so I think it is the right default, but it is a deliberate behavior change worth knowing about.
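
To make the "direction-preserving" point concrete, here is a small sketch comparing elementwise clipping with a per-leaf L2 norm cap on a single gradient leaf. It operates on plain Python lists rather than MLX arrays, and both function names are hypothetical.

```python
import math


def clip_by_value(leaf, max_value):
    """Clamp every element independently to [-max_value, max_value]."""
    return [max(-max_value, min(max_value, g)) for g in leaf]


def clip_by_leaf_norm(leaf, max_norm):
    """Rescale the whole leaf so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in leaf))
    if norm == 0.0 or norm <= max_norm:
        return list(leaf)
    scale = max_norm / norm
    return [g * scale for g in leaf]


leaf = [3.0, 4.0]
by_value = clip_by_value(leaf, 1.0)     # ratio between elements is lost
by_norm = clip_by_leaf_norm(leaf, 1.0)  # ratio between elements is kept
```

Elementwise clipping maps the leaf to [1.0, 1.0], which changes the gradient direction, whereas the norm cap maps it to roughly [0.6, 0.8] and keeps the 3:4 ratio. The norm cap also needs only one reduction per leaf.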

@danielhanchen

Member

Synced with main again and resolved the conflicts in 819bc46. Two MLX changes landed on main since the last sync:

  • #750 (fix(mlx): disable fused MRoPE for Qwen3-VL training to allow VJP): the branch relocates the model setup inside the try/finally in train(), so the new qwen3_vl block was ported into that section right after the existing qwen3_5 wiring. Same behavior, one location.
  • #751 (feat(mlx): implement resume_from_checkpoint for MLX training): merged cleanly apart from the setup hunk. The resume block sits after _build_optimizer, the loop fast-forward and batch_idx offset are in the reworked loop, and the checkpoint writes for optimizer plus trainer state are in the save path. The four state helpers in utils.py auto-merged.

Validation on the merge commit:

PR is mergeable again.

@danielhanchen

Member

Synced with main once more in 055a7f4. Only #763 (llama.cpp CMake-only install fix) had landed since the last sync; the merge was clean with no conflicts and does not touch the MLX surface.

Re-validation on the new head:

Also re-checked the diff surface against main: only unsloth_zoo/mlx/*, unsloth_zoo/mlx/cce/runtime_cce.py, the MLX test files, and the two workflow updates (renamed test reference in consolidated-tests-ci.yml plus the two new Metal test steps in mlx-ci.yml). No unrelated files.

No new review feedback since the last round; all 28 threads remain answered. PR is mergeable.

danielhanchen added a commit to danielhanchen/unsloth-zoo-staging that referenced this pull request Jun 12, 2026

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 0120a7c5af

Comment on lines +761 to +766
lr_value = optimizer.learning_rate
if hasattr(lr_value, "astype"):
    lr = lr_value.astype(mx.float32)
else:
    lr = mx.array(lr_value, dtype=mx.float32)
scale = mx.array(1.0, dtype=mx.float32) - lr * mx.array(wd, dtype=mx.float32)

P2 Badge Apply per-parameter LR to decoupled decay

When lora_plus_ratio or embedding_learning_rate is used with AdamW/Muon/Lion and weight_decay > 0, gradients are scaled per leaf in _grad_leaf_scale, but this manual decoupled decay still shrinks every decayed parameter using only the optimizer's single base LR. Decoupled weight decay is multiplied by the parameter group's effective LR, so lora_b and embedding/lm_head parameters are under- or over-decayed relative to the LR mode being requested; carry the same per-leaf LR multiplier into this scale.

Comment thread unsloth_zoo/mlx/utils.py
if not isinstance(content, list):
    continue
for part in content:
    if isinstance(part, dict) and part.get("type") == "image":

P2 Badge Count image_url parts when repairing DeepSeek prompts

For DeepSeek-style processors this repair decides whether to prepend missing image tokens from _count_vlm_image_parts, but the counter only recognizes parts whose type is exactly "image". Datasets using the common VLM schema {"type": "image_url", ...} or input_image can still supply images through the item/process_vision_info paths, yet image_count stays zero, so the rendered prompt is sent to the processor with too few image placeholders and can fail or misalign images with text.
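
What this suggestion asks for could be sketched as below. This thread has no reply, so the sketch shows the reviewer's proposal rather than merged code; `count_image_parts` and the `IMAGE_PART_TYPES` set are hypothetical, with the three type names taken from the comment.

```python
IMAGE_PART_TYPES = {"image", "image_url", "input_image"}


def count_image_parts(messages) -> int:
    """Count image-like content parts across a list of chat messages."""
    count = 0
    for message in messages:
        content = message.get("content")
        if not isinstance(content, list):
            continue  # plain-string content carries no image parts
        for part in content:
            if isinstance(part, dict) and part.get("type") in IMAGE_PART_TYPES:
                count += 1
    return count


messages = [
    {"role": "user", "content": [
        {"type": "image_url", "image_url": {"url": "file:///cat.png"}},
        {"type": "text", "text": "What is in this picture?"},
    ]},
    {"role": "assistant", "content": "A cat."},
]
```

With the wider type set, the example conversation counts one image, where a check for `"image"` alone would count zero.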

Comment thread unsloth_zoo/mlx/utils.py
is_global = not sliding_window_pattern or (
    i % sliding_window_pattern == sliding_window_pattern - 1
)
local_mask = token_type_mask if is_global else sliding_token_type_mask

P2 Badge Fall back when sliding window mask is unavailable

If a Gemma-style stack/config exposes sliding_window_pattern but does not expose window_size/sliding_window (or sets it to 0), sliding_token_type_mask remains None; every non-global layer then receives None here instead of the causal/token-type mask. In that configuration the local layers lose both causal and padding masking during CCE hidden-stack execution, so this branch should only select the sliding mask when it was actually built, otherwise fall back to token_type_mask.

@danielhanchen

Member

Ran a deeper behavioral validation round on real Apple Silicon before we land this. Added tests/test_mlx_pr684_full_validation_metal.py, which proves on Metal the paths the unit suite could only simulate:

  • resume_from_checkpoint determinism: a stop+resume run reproduces the fresh run's losses step for step (restored Adam moments, batch fast-forward, LR schedule offset).
  • train_on_responses_only completion-only training: a 3 epoch run executes exactly 18 steps through the labeled-batch path.
  • Epoch-based runs: num_train_epochs drives the step count when max_steps is disabled.
  • SGD with gradient-coupled weight decay end to end.
  • Real VLM LoRA training: Qwen2-VL-2B-Instruct-4bit through collation, label masking, VLM CCE, and adapter save.

The new tests caught two real bugs, both fixed:

  1. train_on_responses_only crashed with TypeError: 'TokenizerWrapper' object is not callable (b49a056). mlx-lm's TokenizerWrapper proxies attributes via __getattr__, so the unwrap guard saw convert_tokens_to_ids and kept the wrapper, but the HF masking implementation calls tokenizer(...) and the wrapper defines no __call__. The guard now requires the object to be callable as well; a revert-verified regression test covers a proxying non-callable wrapper and a plain fast tokenizer.
  2. Qwen2-VL training crashed with ValueError: [Primitive::vjp] Not implemented for CustomKernel (0120a7c). mlx-vlm's qwen2_vl language tower inherits the same fused MRoPE Metal kernel that has no gradient, but the trainer only disabled it for qwen3_5 and qwen3_vl. The gate now covers qwen2_vl and qwen2_5_vl too, so training takes the differentiable cos/sin fallback. The Metal e2e trains real Qwen2-VL as the regression pin.
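
The first bug comes down to how attribute proxying interacts with `callable`. The sketch below reproduces the pattern with stand-in classes: `ProxyWrapper`, `FakeTokenizer` and `unwrap_tokenizer` are hypothetical, and the inner attribute name `_tokenizer` is an assumption about the wrapper layout.

```python
class FakeTokenizer:
    """Stand-in for a fast tokenizer: has the method and is callable."""

    def convert_tokens_to_ids(self, token):
        return 0

    def __call__(self, text):
        return {"input_ids": [0] * len(text.split())}


class ProxyWrapper:
    """Forwards attribute lookups to the inner tokenizer, no __call__."""

    def __init__(self, inner):
        self._tokenizer = inner

    def __getattr__(self, name):
        # Only reached when normal lookup fails, so this proxies methods.
        return getattr(self._tokenizer, name)


def unwrap_tokenizer(tok):
    """Keep `tok` only if it has the method AND can be called directly."""
    if hasattr(tok, "convert_tokens_to_ids") and callable(tok):
        return tok
    inner = getattr(tok, "_tokenizer", None)
    return inner if inner is not None else tok
```

`hasattr` succeeds on the wrapper because of the proxy, but `callable` inspects the type and reports False, so the guard falls through to the inner tokenizer instead of keeping an object that fails on `tokenizer(...)`.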

Validation state, all on macos-14 M1 with real Metal:

Branch remains synced with main and mergeable.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: cd8efb9a4f

Comment thread unsloth_zoo/mlx/utils.py
# Qwen3-VL first-step loss from ~6.45 to ~6.90 on the real-cat fixture.
if use_collated_position_ids and "position_ids" in extra_kwargs:
backbone_kwargs["position_ids"] = extra_kwargs["position_ids"]
if "token_type_ids" in extra_kwargs:

P2 Badge Route mm_token_type_ids into the hidden stack

When a VLM processor returns only mm_token_type_ids (the Gemma 4-style key that this collator preserves and requests), this CCE path never aliases it to the token_type_ids key consumed by _run_hidden_stack. In hidden-stack CCE runs, those batches therefore fall back to the plain causal/available mask instead of the Gemma image token-type/sliding mask, changing attention semantics for processors that emit mm_token_type_ids rather than token_type_ids.

4 participants