MLX Update Training#684
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces several enhancements and fixes for MLX training, focusing on VLM support and parity with HuggingFace's trainer behavior. Key updates include a manual AdamW weight decay implementation that filters out bias and normalization parameters, a diagnostic 'nf4_dense' quantization mode, and logic to maintain normalization parameters in float32. Additionally, it refines VLM collation, fixes a loss masking off-by-one error, and prevents automatic EOS appending in datasets. Feedback from the review identified a bug in the Qwen3-VL LayerNorm parameter check, precision and memory issues in the manual weight decay logic, and a performance regression in the compiler's logit handling.
| if "weight" in norm: | ||
| y = y * norm.weight.astype(mx.float32) | ||
| if "bias" in norm: | ||
| y = y + norm.bias.astype(mx.float32) |
There was a problem hiding this comment.
The check if "weight" in norm: is not a valid way to verify parameter existence on an mlx.nn.Module. This will likely evaluate to False or raise a TypeError, causing the LayerNorm calculation to skip applying the weight and bias tensors. This will result in incorrect numerical output for the Qwen3-VL vision blocks.
| if "weight" in norm: | |
| y = y * norm.weight.astype(mx.float32) | |
| if "bias" in norm: | |
| y = y + norm.bias.astype(mx.float32) | |
| if hasattr(norm, "weight"): | |
| y = y * norm.weight.astype(mx.float32) | |
| if hasattr(norm, "bias"): | |
| y = y + norm.bias.astype(mx.float32) |
There was a problem hiding this comment.
Thanks. The current code already does exactly this: _qwen3_torch_like_layer_norm reads getattr(norm, "weight", None) and getattr(norm, "bias", None) and only applies them when present, so the affine scale and bias are honored for the Qwen3-VL vision blocks. Worth noting for the record that mlx.nn.Module subclasses dict, so even "weight" in norm would have reported parameter presence correctly here.
| lr = optimizer.learning_rate.astype(flat_grad[name].dtype) | ||
| scale = mx.array(1.0, dtype=lr.dtype) - lr * mx.array(wd, dtype=lr.dtype) | ||
| decayed.append((name, parameter * scale)) |
There was a problem hiding this comment.
This manual weight decay implementation has two significant issues:
- Precision Underflow: Calculating the
scalein the parameter's native dtype (e.g.,float16orbfloat16) will cause the weight decay to be ignored. For typical values likelr=2e-4andwd=0.01, the termlr * wd(2e-6) is smaller than the machine epsilon forfloat16/bfloat16relative to 1.0, so1.0 - 2e-6rounds back to1.0. - Unintended Parameter Promotion: If
scaleis calculated infloat32(to fix the precision issue), the operationparameter * scalewill promote the model parameters tofloat32. Since these parameters (LoRA weights and norms) are explicitly excluded from the restoration logic in_restore_trainable_storage_dtypes, they will remain infloat32, doubling their memory footprint for the rest of the training session.
The calculation should be done in float32 and explicitly cast back to the original dtype.
| lr = optimizer.learning_rate.astype(flat_grad[name].dtype) | |
| scale = mx.array(1.0, dtype=lr.dtype) - lr * mx.array(wd, dtype=lr.dtype) | |
| decayed.append((name, parameter * scale)) | |
| lr = optimizer.learning_rate.astype(mx.float32) | |
| scale = mx.array(1.0, dtype=mx.float32) - lr * mx.array(wd, dtype=mx.float32) | |
| decayed.append((name, (parameter * scale).astype(parameter.dtype))) |
There was a problem hiding this comment.
Good catch on both points. The current _apply_manual_weight_decay computes the scale in fp32 and casts the result back to the parameter dtype: (parameter.astype(mx.float32) * scale).astype(parameter.dtype), so there is no float16/bfloat16 underflow and no silent promotion to fp32. This matches the suggested fix.
| logit_softcapping = None if (\\4) == () else (\\4), | ||
| ) | ||
| elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and NOT_RETURN_LOGITS: | ||
| elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: |
There was a problem hiding this comment.
Removing the NOT_RETURN_LOGITS check here, combined with the deletion of the specialized branch for returned logits, introduces a performance and memory regression. When UNSLOTH_RETURN_LOGITS=1 is active, logits are materialized early; forcing the code into the unsloth_fused_ce_loss branch causes a redundant matmul. Furthermore, ensure that logit scaling (multiply/divide) and softcapping transformations are applied to these materialized logits before they are passed to the loss function to maintain consistency with standard execution paths.
References
- When implementing code paths that return materialized logits (e.g., using UNSLOTH_RETURN_LOGITS=1), ensure that logit scaling and softcapping transformations are applied before the loss function.
There was a problem hiding this comment.
The NOT_RETURN_LOGITS guard is still in place, so fused CE only runs when we are not returning logits. There is now a dedicated UNSLOTH_RETURN_LOGITS=1 branch that reuses the already-materialized logits and applies the logit scale-multiply, scale-divide and softcap transforms before calling the loss function, so we avoid the duplicate matmul and keep transform parity.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: a93449f518
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| "pad_token_id", | ||
| "image_token_id", | ||
| "video_token_id", | ||
| "audio_token_id", | ||
| ): | ||
| _append_unique_int(ids, getattr(tokenizer, attr, None)) |
There was a problem hiding this comment.
Exclude pad token ID from global VLM ignore list
Adding pad_token_id to _get_vlm_ignore_token_ids and then masking via _mask_label_token_ids masks that token everywhere, not only in padded positions. For tokenizers where pad_token_id == eos_token_id (common in Llama-derived setups), valid in-sequence EOS targets are turned into -100, so EOS supervision is silently removed across training/eval batches. Padding should be masked only through attention_mask == 0, not by globally ignoring the pad ID.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
_get_vlm_ignore_token_ids only collects image/video/audio/boi/eoi token ids; it never adds pad_token_id to the global ignore list. Padding is masked separately via attention_mask == 0, so in-sequence EOS targets are preserved even when pad_token_id == eos_token_id.
| logit_softcapping = None if (\\4) == () else (\\4), | ||
| ) | ||
| elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and NOT_RETURN_LOGITS: | ||
| elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: |
There was a problem hiding this comment.
Keep return-logits mode out of fused CE loss branch
This condition now routes all labeled ForCausalLMLoss calls through unsloth_fused_ce_loss, including UNSLOTH_RETURN_LOGITS=1. In that mode the compiler replacement already materializes logits up front, so this branch computes a second LM-head projection inside fused CE, increasing memory/compute and risking OOM on large models. The previous NOT_RETURN_LOGITS guard avoided that duplicate matmul.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
The duplicate matmul concern no longer applies. The fused CE branch is gated by NOT_RETURN_LOGITS, and the UNSLOTH_RETURN_LOGITS=1 path is handled separately by reusing the materialized logits rather than re-projecting inside fused CE.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: dcd0a9001c
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if "weight" in norm: | ||
| y = y * norm.weight.astype(mx.float32) | ||
| if "bias" in norm: | ||
| y = y + norm.bias.astype(mx.float32) |
There was a problem hiding this comment.
Apply LayerNorm affine params in Qwen3 vision block patch
The new _qwen3_torch_like_layer_norm gate checks use membership ("weight" in norm, "bias" in norm) instead of attribute existence. For module-style LayerNorm objects, this test does not reliably indicate parameter presence, so the affine scale/bias can be skipped (or error), making every patched Qwen3-VL vision block run a non-affine normalization that diverges from pretrained/Transformers behavior and can degrade training and loss parity.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
The current helper uses getattr(norm, "weight", None) and getattr(norm, "bias", None), so the affine parameters are applied. For reference, mlx.nn.Module subclasses dict, so the membership form also reflects parameter presence exactly (LayerNorm with affine has both weight and bias; without affine has neither), so this path runs the correct affine LayerNorm.
| if dataset_order == "torch_randperm": | ||
| indices = _torch_randperm_order(len(dataset), int(seed) + epoch) | ||
| elif dataset_order in (None, "default"): | ||
| np.random.seed(int(seed) + epoch) |
There was a problem hiding this comment.
Normalize nullable seeds before VLM epoch reshuffle
When create_vlm_batches rolls to a new epoch, it unconditionally computes int(seed) + epoch for both torch_randperm and default reshuffling. If callers pass seed=None (which is already supported in _torch_randperm_order and by NumPy seeding), multi-epoch materialization will raise TypeError on the first rollover and stop training/eval batch creation.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
create_vlm_batches now normalizes the seed up front with base_seed = _normalize_seed(seed) (None defaults to 3407) before doing any base_seed + epoch arithmetic, so a seed=None caller no longer raises on epoch rollover.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e16efc0f9c
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if dataset_order == "torch_randperm": | ||
| indices = _torch_randperm_order(len(dataset), int(seed) + epoch) |
There was a problem hiding this comment.
Normalize nullable seed in streaming torch_randperm path
When dataset_order="torch_randperm", the streaming iterator now does int(seed) + epoch before calling _torch_randperm_order. If callers pass seed=None (which _torch_randperm_order explicitly supports by defaulting to 3407), this raises TypeError on the first epoch and stops batch generation. Please normalize seed the same way as _torch_randperm_order (or pass it through directly) before adding epoch.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Same fix as the VLM path: the streaming and ordered text paths normalize the seed with _normalize_seed before + epoch, matching _torch_randperm_order having its own None default, so seed=None is handled.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b0a83b52cf
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| continue | ||
| if not self._should_apply_weight_decay(name, parameter): | ||
| continue | ||
| lr = optimizer.learning_rate.astype(flat_grad[name].dtype) |
There was a problem hiding this comment.
Handle scalar learning rates in manual AdamW decay
_apply_manual_adamw_weight_decay assumes optimizer.learning_rate has an .astype(...) method, but in valid configurations it can be a Python float (notably the constant scheduler path without warmup, and the MLX simulation optimizer stubs). In that case this line raises AttributeError before any optimizer step, so AdamW training fails as soon as manual decay runs. Please normalize LR to an mx.array (or branch on scalar) before dtype-casting.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
_apply_manual_weight_decay now branches on whether the learning rate exposes .astype: scalar floats are wrapped via mx.array(lr_value, dtype=mx.float32), so the constant-schedule and simulation-optimizer paths no longer hit an AttributeError.
| if not tokenized: | ||
| return [] |
There was a problem hiding this comment.
Raise on empty ordered token stream instead of returning []
When create_ordered_batches drops all rows via if len(ids) >= 2 (for example very small max_seq_length or single-token rows), it returns an empty batch list. The new ordered-data path in MLXTrainer then indexes batches with batch_idx % len(batches), which crashes at runtime with division-by-zero instead of surfacing a data error. This should raise a clear ValueError here, matching the other dataset-prep guards.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
create_ordered_batches now raises a clear ValueError when no trainable sequences survive (need at least two tokens after formatting and truncation), so we surface a data error instead of producing an empty list that later divides by zero.
| centered = x_f - mean | ||
| var = mx.mean(centered * centered, axis=-1, keepdims=True) | ||
| y = centered * mx.rsqrt(var + norm.eps) | ||
| if "weight" in norm: |
There was a problem hiding this comment.
If this is specific to qwen3 do we need the if check?
or even in qwen3 family some norms have weight/bias and some dont?
There was a problem hiding this comment.
Yes, the check is needed, and yes, even within the qwen3 family the norms differ. In mlx-vlm qwen3_vl the vision blocks use nn.LayerNorm (weight and bias both present) while the language stack uses nn.RMSNorm (weight only, no bias). Because mlx.nn.Module subclasses dict, "weight" in norm / "bias" in norm report exactly which parameters exist, so the conditional lets the same helper apply bias only where it exists and skip it for RMSNorm. The current code uses getattr(norm, "weight", None) which is equivalent. Keeping the check also keeps the helper correct if an affine=False norm ever appears.
| flat = mx.concatenate([flat, mx.zeros((pad,), dtype=mx.float32)]) | ||
| groups = flat.reshape((-1, group_size)) | ||
| absmax = mx.max(mx.abs(groups), axis=1, keepdims=True) | ||
| denom = mx.maximum(absmax, mx.array(1e-12, dtype=mx.float32)) |
There was a problem hiding this comment.
I'm thinking we do this to avoid 0 division. But div by 1e-12 might cause the numbers to blow up?
For a similar thing what I did was set the scale to 1
There was a problem hiding this comment.
Agreed, and the current code does exactly what you suggested. The NF4 dequant uses denom = mx.where(absmax > 0, absmax, mx.ones_like(absmax)) (and the same for the nested-scale denom), so zero scales divide by 1 instead of a tiny epsilon. There is no 1e-12 left anywhere in the MLX path, so there is no blow-up risk.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 964be34ec2
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if num_batches is None and seen >= len(tokenized): | ||
| break |
There was a problem hiding this comment.
Keep torch_randperm active across epoch-based text training
When dataset_order="torch_randperm" is used with epoch-based training (max_steps == 0), this early break materializes only one epoch of randomized order and returns. MLXTrainer then reuses that fixed batch list with batch_idx % len(batches) for subsequent epochs, so later epochs never get a fresh randperm order. This diverges from the expected per-epoch reshuffle behavior of a torch-style random sampler and can bias training/eval results for multi-epoch runs.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
The text ordered path now reseeds per epoch: make_order(epoch) calls _torch_randperm_order(len, base_seed + epoch) and is re-invoked at each epoch boundary, with the loop bounded by num_epochs * len(dataset), so multi-epoch runs get a fresh permutation each epoch.
| if num_batches is None: | ||
| break |
There was a problem hiding this comment.
Rebuild VLM torch_randperm order for each epoch run
In pre-materialized VLM batching, this branch exits as soon as one dataset pass completes whenever num_batches is None (the normal epoch-based path). With dataset_order="torch_randperm", that means only the first epoch gets randomized; later epochs in MLXTrainer reuse the same stored batches instead of generating a new permutation. This breaks sampler parity for multi-epoch VLM training and can skew convergence behavior.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
create_vlm_batches rebuilds the order on each epoch rollover with _torch_randperm_order(len, base_seed + epoch), so later epochs no longer reuse the first epoch permutation.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ca08652226
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if order_pos >= len(order): | ||
| epoch += 1 | ||
| order = make_order(epoch) | ||
| order_pos = 0 | ||
| batch_items.append(tokenized[order[order_pos]]) |
There was a problem hiding this comment.
Stop batches from spanning torch_randperm epoch boundaries
create_ordered_batches wraps to a new permutation inside the same batch (order_pos >= len(order)), so when num_epochs is used and len(dataset) % batch_size != 0, one batch can contain samples from two different epochs. This changes both sample order and batch count (e.g., 5 samples, batch size 4, 2 epochs yields 3 batches instead of the expected 4), so MLXTrainer under-trains in epoch mode and no longer mirrors CUDA/DataLoader epoch semantics for dataset_order="torch_randperm".
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
create_ordered_batches now breaks the batch at the epoch boundary: when order_pos >= len(order) it emits the partial batch and restarts at epoch+1 rather than wrapping inside one batch, so no batch spans two epochs and the batch count matches CUDA DataLoader epoch semantics.
|
Ok pretty sure resolved review feedback |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 7e0bee546f
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if any("norm" in part for part in parts): | ||
| return False |
There was a problem hiding this comment.
Exclude ln_ normalization weights from AdamW decay*
The new HF-style decay filter only checks for path segments containing "norm", so normalization parameters named like ln_1.weight or ln_f.weight still get decayed. This repo already treats ln_* as norm parameters in _ensure_lora_frozen (see _NORM_FRAGMENTS), so models using those names will now receive unintended weight decay on norm scales, diverging from the stated “skip bias and norms” behavior and changing optimization dynamics.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
_should_apply_weight_decay uses _part_is_norm, which matches "norm", ln_*, and ln_f, so ln_1.weight / ln_f.weight style norm scales are excluded from decay as intended.
| parts = str(path).lower().split(".") | ||
| return any("norm" in part for part in parts[:-1]) |
There was a problem hiding this comment.
Include ln_ params in fp32 norm-parameter preservation*
_keep_norm_parameters_float32 claims to keep normalization parameters in fp32, but _is_norm_parameter_path only matches components containing "norm". Any normalization layer named with ln_* (which this codebase already recognizes as norm-like elsewhere) is skipped and left in lower precision, undermining the stabilization this pass is meant to provide for FT/LoRA/QLoRA training.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
The loader _is_norm_parameter_path matches "norm", ln_*, and ln_f as well, so GPT-2 / GPT-OSS style ln_* normalization parameters are kept in fp32 by _keep_norm_parameters_float32.
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
Merges 5 main-side mlx fixes (unslothai#673 zero-token CCE, unslothai#679 + unslothai#692 LoRA save metadata, unslothai#682 invalid label NaN-poisoning, unslothai#688 tool mask). All 13 conflict regions in unsloth_zoo/mlx/utils.py resolved to keep PR unslothai#684's behavior where it conflicts on semantics: - half-open `<` length mask (PR unslothai#684 fix) wins over main's inclusive `<=` - `if labels is None` branch preserved (PR unslothai#684 generality) alongside main's `_normalize_cce_label_dtype` dtype widening - `_get_image_token_ids` legacy wrapper kept alongside main's new `_normalize_cce_label_dtype` / `_normalize_numpy_cce_labels` - `_mask_label_token_ids` calls `_normalize_cce_label_dtype` first so image masking honors main's uint-widening contract - HEAD's `_expand_token_replacements` dropped; main's three-function split (`_normalize_numpy_cce_labels` + `_expand_image_token_sequences` + `_expand_token_runs`) is canonical; duplicate HEAD wrappers removed - `_collate_vlm_prompt_completion_batch` reads back the masked labels in int64 so image + attention masking survives without narrowing - prompt-completion VLM collator routes through `_apply_vlm_label_masks` after dtype normalisation so ignore_token_ids and wide invalid ids both reach runtime CCE intact - `_to_mx_vlm_batch` uses main's `_normalize_cce_label_dtype` for labels while keeping PR unslothai#684's token_type_ids / mm_token_type_ids handling - `_unsloth_*` prefix filter preserved so the new collated_position_ids flag and main's raw-input-ids carrier both get stripped 152 MLX tests pass post-merge.
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
|
Status pass on this branch relative to today's merges and the open #738, since they share surfaces:
Recommended order: land #738 first (small, validated), then rebase this onto main and resolve the trainer/utils conflicts preferring the merged fixes. |
… training patches
|
Synced this branch with main via maintainer edit (merge commit e4f4d9f) so it is reviewable against the current tree. Resolution notes:
Verification: AST parse clean on all resolved files, workflow YAML parses, and the CPU suites pass: 16/16 behavioral gates and 38/38 across @mmathew23 please double-check the trainer.py resolution matches your intent, especially the grad-clip comment block and the deduplicated setup section. |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e4f4d9f444
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| print("Unsloth: Using standard cross-entropy loss.") | ||
|
|
||
| # Prepare data — determine total_steps first | ||
| self._prepared_batches_include_epochs = False |
There was a problem hiding this comment.
Preserve the prebuilt epoch flag before step counting
In completion-only text training (train_on_responses_only) with num_train_epochs > 1 and max_steps <= 0, _create_labeled_batches already materializes all epochs and sets trainer._prepared_batches_include_epochs = True, but this reset clears that flag before _prepare_data() returns the prebuilt self._batches. The total-step calculation then treats len(batches) as one epoch and multiplies by num_train_epochs again, so a 3-epoch run executes 9 epochs' worth of micro-batches and over-trains/reuses data unexpectedly.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Confirmed and fixed. _train_inner was zeroing _prepared_batches_include_epochs before _prepare_data returned the prebuilt batches from train_on_responses_only, and since the prebuilt path returns early it never re-derived the flag, so the step counter multiplied the already-materialized epochs again. The reset is now gated on self._batches is None, which preserves the flag for the completion-only path while normal runs still reset and re-derive it. Added a regression test pinning 18 steps for a 3-epoch run instead of 54. Fixed in c8ad7dd.
| if dataset_order == "torch_randperm": | ||
| from .utils import _torch_randperm_order | ||
| # Reseed per epoch (matches `create_ordered_batches`). | ||
| order = _torch_randperm_order(len(items), seed + epoch_idx) |
There was a problem hiding this comment.
Normalize nullable seed before labeled randperm ordering
When train_on_responses_only is used with dataset_order="torch_randperm" and a caller leaves seed=None, this addition raises TypeError before _torch_randperm_order can apply its own None default. The unlabeled ordered-batch path already normalizes nullable seeds, so completion-only text training is the path that still fails for the same valid seed configuration.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Good catch, fixed. The labeled torch_randperm path did seed + epoch_idx directly, which raised TypeError when seed was None, whereas the unlabeled path already runs _normalize_seed first. The seed is now normalized before the per-epoch add so seed=None falls back to the default and per-epoch reseeding still works. Added tests that build labeled batches with seed=None and assert each epoch is a distinct permutation. Fixed in c8ad7dd.
| sliding_window_pattern = getattr(stack, "sliding_window_pattern", None) | ||
| window_size = getattr(stack, "window_size", None) |
There was a problem hiding this comment.
Read Gemma sliding-window settings from config
For Gemma-style VLM language stacks that pass token_type_ids, the sliding-window settings live on stack.config (sliding_window_pattern / sliding_window) rather than as direct stack attributes. These getattr calls therefore resolve to None, so every layer is treated as global and uses the full token-type mask instead of the local sliding-window mask, changing Gemma3 attention semantics and training parity whenever this custom hidden-stack path is used.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Good catch on the robustness angle. On real mlx-vlm 0.6.2 both Gemma3Model and Gemma4TextModel copy window_size and sliding_window_pattern onto the module, so today the masks are correct. A config fallback was added anyway (reading sliding_window and sliding_window_pattern off stack.config) so a stack that only stores them there still builds windowed masks rather than treating every layer as global. Fixed in c8ad7dd.
|
Validated this branch end to end on real Apple Silicon via the staging fork (danielhanchen/unsloth-zoo-staging, branch staging/mlx-pr-684, macos-14). Summary: Green across the board: https://github.com/danielhanchen/unsloth-zoo-staging/actions/runs/27351187334
Two commits pushed to this branch on top of the main sync:
Staging PR with the full diff and CI history: danielhanchen#28 From my side this branch is now in good shape: synced with main, the qwen3_5 patch wiring preserved, training proven on Metal, and the full suite green. |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: fd8e742b3f
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| if image.size[0] > image_size: | ||
| width, height = image.size | ||
| new_width = (width * image_size + width // 2) // width | ||
| new_height = (height * image_size + width // 2) // width | ||
| image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) |
There was a problem hiding this comment.
Downscale tall VLM images before processor
When image_size is an int, this only checks/scales by width, so portrait images such as 512×2048 with a 512 limit are left unchanged, and taller images like 1024×4096 are resized to 512×2048. Those oversized images still reach the VLM processor and can generate far more patches than intended, causing avoidable OOMs or batch shape drift; the resize should account for the limiting dimension rather than only image.size[0].
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Confirmed. The check only looked at width, so a tall portrait image like 512x2048 under a 512 cap slipped through unscaled. The resize now scales on the larger of width and height, so both portrait and landscape images that exceed the cap shrink; landscape sizing is unchanged. Fixed in c8ad7dd.
| if labels is None: | ||
| labels = batch_dict["input_ids"].astype(mx.int32) | ||
| else: | ||
| labels = labels.astype(mx.int32) |
There was a problem hiding this comment.
Preserve wide VLM labels until CCE validation
When the collator supplies labels, this cast narrows them to int32 before _vlm_cce_forward sends them to runtime CCE. For the same unsigned/wide invalid IDs this file otherwise preserves via _RAW_INPUT_IDS_FOR_LABELS (e.g. values that wrap to -100 on int32), the invalid target is silently converted into an ignore label instead of reaching CCE's range check, so bad processor outputs can be hidden and train/eval loss becomes incorrect.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Right, the int32 cast here was undoing the wide-id preservation the rest of the file sets up via _RAW_INPUT_IDS_FOR_LABELS and _normalize_cce_label_dtype. A uint32 id near 2**32 wrapped on the narrow and silently turned into an ignore position instead of reaching CCE as an out-of-vocab sentinel. Labels are now widened instead of narrowed and derived from the raw pre-narrow carrier, while padding and ignore tokens are still masked. Fixed in c8ad7dd.
| video_token_id=int(_config_get(config, "video_token_id", _config_get(config, "video_token_index"))), | ||
| spatial_merge_size=int(vision_config.get("spatial_merge_size", 2)), | ||
| ) | ||
| batch_dict["_unsloth_collated_position_ids"] = True |
There was a problem hiding this comment.
Strip collator-only position flag before VLM forward
For Qwen/GLM batches this private marker is added to the batch, but the non-CCE VLM loss still forwards every extra key except input_ids/pixel_values/attention_mask/labels/the raw-label carrier into model(...). With use_cce=False, _unsloth_collated_position_ids=True is therefore passed as an unexpected model kwarg for these architectures, while the CCE path already filters private _unsloth_ keys; the baseline path needs the same filtering.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Confirmed. The CCE path strips every _unsloth_* carrier, but the baseline (use_cce=False) path only dropped a fixed key list, so the _unsloth_collated_position_ids marker leaked into model(...) and would TypeError on Qwen/GLM forwards that reject unknown kwargs. The baseline filter now uses the same _unsloth_ prefix strip so the two paths stay in parity; position_ids still passes through. Fixed in c8ad7dd.
| if "weight" in self: | ||
| y = y * (1.0 + self.weight.astype(mx.float32)) |
There was a problem hiding this comment.
Apply Gemma3 norm affine parameters by attribute
In the Gemma3 fp32 norm patches, testing membership with "weight" in self/"bias" in norm is not a reliable way to detect MLX module parameters; for module-style norms this can skip the affine scale/bias (or error) even though the attributes exist. Since these patches are installed on Gemma3 load, affected runs either diverge from the pretrained LayerNorm/RMSNorm math or fail during the first forward; use attribute checks like getattr(norm, "weight", None) here and in the analogous vision LayerNorm helpers.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
self and norm here are mlx nn.Module instances (Gemma3 RMSNorm and SigLIP LayerNorm from mlx-vlm), and mlx.nn.Module subclasses dict, so "weight" in self / "bias" in norm are membership tests over the registered parameters and report presence exactly. Verified on CPU: LayerNorm with affine reports both weight and bias, RMSNorm reports weight and not bias, affine=False reports neither. So the affine scale and bias are applied correctly and there is no divergence or first-forward failure.
| elif opt_name == "sgd": | ||
| optimizer = optim.SGD(learning_rate=initial_lr, weight_decay=wd) | ||
| # HF parity: manual bias/norm-aware decoupled decay. | ||
| self._manual_weight_decay = float(wd or 0.0) | ||
| optimizer = optim.SGD(learning_rate=initial_lr, weight_decay=0.0) |
There was a problem hiding this comment.
Keep SGD weight decay coupled to the gradient
When optim="sgd" and weight_decay>0, this now disables the optimizer's built-in decay and applies the AdamW-style decoupled parameter shrink before optimizer.update. SGD weight decay in the HF/PyTorch optimizer is coupled into the gradient (and therefore into momentum/Nesterov state), so this changes training dynamics for SGD runs instead of providing the stated HF parity; keep built-in SGD decay or branch manual decay by optimizer semantics.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
You are right, fixed. SGD was getting AdamW-style decoupled decay (param shrink after the step), but PyTorch SGD and MLX SGD both couple weight decay into the gradient before momentum, so this changed SGD dynamics rather than matching HF. There is now a coupled-decay path that folds wd * param into the post-clip gradient for non-bias/non-norm leaves and SGD routes to it, keeping MLX SGD built-in decay off and preserving the HF bias/norm exemption. A new test steps MLX SGD with momentum and matches torch.optim.SGD(weight_decay=wd) numerically, and another pins that the old decoupled shrink diverges. Fixed in c8ad7dd.
| denom = mx.where(absmax > 0, absmax, mx.ones_like(absmax)) | ||
| scaled = groups / denom | ||
| indices = mx.argmin(mx.abs(scaled[..., None] - codebook), axis=-1) | ||
| absmax = _bnb_nested_absmax(absmax.reshape((-1,))).reshape((-1, 1)) |
There was a problem hiding this comment.
Keep NF4 scales un-nested unless requested
For q_mode="nf4_dense", every group absmax is passed through _bnb_nested_absmax, which simulates nested/double quantization of the scale values. The allowed BitsAndBytes config path explicitly rejects bnb_4bit_use_double_quant=True, so the default/accepted NF4 case should use the raw per-group absmax; with varying group magnitudes this extra scale quantization changes the dense dequantized weights and breaks the intended CUDA NF4 parity.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Agreed. The accepted BitsAndBytesConfig path rejects bnb_4bit_use_double_quant=True, and CUDA bitsandbytes only applies the nested absmax when a nested quant state is present, so default NF4 should keep plain scales. The dense dequant was always nesting; it is now gated behind a use_double_quant flag that defaults off, so default NF4 matches the single-quant reference while the nested simulation stays reachable explicitly. Fixed in c8ad7dd.
|
Correctness verification round complete. All 28 review threads now have replies with verdicts and evidence; summary:
Validation on Apple Silicon (staging fork, macos-14): https://github.com/danielhanchen/unsloth-zoo-staging/actions/runs/27354619482 is fully green with the fixes: qwen3_5 VJP regressions, the real LoRA training e2e (loss improves, adapters save at end), and the full MLX suite breadth pass, one pytest process per file. CPU-side, the combined suites pass 148 tests plus the full per-file sweep. One review note for whoever lands this: the default grad clip changed from main's elementwise 1.0 to a per-leaf L2 norm cap of 1.0. The resolution priority is coherent (explicit max_grad_value wins, then max_grad_leaf_norm, then max_grad_norm for HF-style global norm, explicit 0 disables) and it is direction-preserving and memory-cheap, so I think it is the right default, but it is a deliberate behavior change worth knowing about. |
|
Synced with main again and resolved the conflicts in 819bc46. Two MLX changes landed on main since the last sync:
Validation on the merge commit:
PR is mergeable again. |
|
Synced with main once more in 055a7f4. Only #763 (llama.cpp CMake-only install fix) had landed since the last sync; the merge was clean with no conflicts and does not touch the MLX surface. Re-validation on the new head:
Also re-checked the diff surface against main: only No new review feedback since the last round; all 28 threads remain answered. PR is mergeable. |
…SGD, and VLM training for PR unslothai#684
… use Qwen2-VL for the VLM e2e test
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 0120a7c5af
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| lr_value = optimizer.learning_rate | ||
| if hasattr(lr_value, "astype"): | ||
| lr = lr_value.astype(mx.float32) | ||
| else: | ||
| lr = mx.array(lr_value, dtype=mx.float32) | ||
| scale = mx.array(1.0, dtype=mx.float32) - lr * mx.array(wd, dtype=mx.float32) |
There was a problem hiding this comment.
Apply per-parameter LR to decoupled decay
When lora_plus_ratio or embedding_learning_rate is used with AdamW/Muon/Lion and weight_decay > 0, gradients are scaled per leaf in _grad_leaf_scale, but this manual decoupled decay still shrinks every decayed parameter using only the optimizer's single base LR. Decoupled weight decay is multiplied by the parameter group's effective LR, so lora_b and embedding/lm_head parameters are under- or over-decayed relative to the LR mode being requested; carry the same per-leaf LR multiplier into this scale.
Useful? React with 👍 / 👎.
| if not isinstance(content, list): | ||
| continue | ||
| for part in content: | ||
| if isinstance(part, dict) and part.get("type") == "image": |
There was a problem hiding this comment.
Count image_url parts when repairing DeepSeek prompts
For DeepSeek-style processors this repair decides whether to prepend missing image tokens from _count_vlm_image_parts, but the counter only recognizes parts whose type is exactly "image". Datasets using the common VLM schema {"type": "image_url", ...} or input_image can still supply images through the item/process_vision_info paths, yet image_count stays zero, so the rendered prompt is sent to the processor with too few image placeholders and can fail or misalign images with text.
Useful? React with 👍 / 👎.
| is_global = not sliding_window_pattern or ( | ||
| i % sliding_window_pattern == sliding_window_pattern - 1 | ||
| ) | ||
| local_mask = token_type_mask if is_global else sliding_token_type_mask |
There was a problem hiding this comment.
Fall back when sliding window mask is unavailable
If a Gemma-style stack/config exposes sliding_window_pattern but does not expose window_size/sliding_window (or sets it to 0), sliding_token_type_mask remains None; every non-global layer then receives None here instead of the causal/token-type mask. In that configuration the local layers lose both causal and padding masking during CCE hidden-stack execution, so this branch should only select the sliding mask when it was actually built, otherwise fall back to token_type_mask.
Useful? React with 👍 / 👎.
|
Ran a deeper behavioral validation round on real Apple Silicon before we land this. Added
The new tests caught two real bugs, both fixed:
Validation state, all on macos-14 M1 with real Metal:
Branch remains synced with main and mergeable. |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: cd8efb9a4f
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| # Qwen3-VL first-step loss from ~6.45 to ~6.90 on the real-cat fixture. | ||
| if use_collated_position_ids and "position_ids" in extra_kwargs: | ||
| backbone_kwargs["position_ids"] = extra_kwargs["position_ids"] | ||
| if "token_type_ids" in extra_kwargs: |
There was a problem hiding this comment.
Route mm_token_type_ids into the hidden stack
When a VLM processor returns only mm_token_type_ids (the Gemma 4-style key that this collator preserves and requests), this CCE path never aliases it to the token_type_ids key consumed by _run_hidden_stack. In hidden-stack CCE runs, those batches therefore fall back to the plain causal/available mask instead of the Gemma image token-type/sliding mask, changing attention semantics for processors that emit mm_token_type_ids rather than token_type_ids.
Useful? React with 👍 / 👎.
Reworks the MLX training path to match the behavior of Unsloth's CUDA/transformers path, with a focus on making VLM fine-tuning work properly on Apple Silicon.
What changed
HF Trainer parity
max_grad_valuetakes priority, then per-leaf L2 norm (max_grad_leaf_norm), then global norm (max_grad_norm); an explicit 0 disables clipping.torch.optim.SGD.torch_randpermwith deterministic per-epoch reseeding, so a seeded MLX run sees the same batch order semantics as a CUDA DataLoader run. Epoch boundaries no longer span batches.train_on_responses_only(completion-only training) is implemented for both text and VLM models, mirroring the HF/unsloth API.VLM training
pad_token_id == eos_token_id.Tests
Behavior change to be aware of
The default gradient clipping changes from an elementwise clamp at 1.0 to a per-leaf L2 norm clip at 1.0. The new default is closer to standard norm-based clipping and is what the e2e runs validate, but it is a deliberate change from the previous MLX default.
Validation
All on macos-14 (M1, real Metal) plus a Linux CPU sweep of every suite file:
train_on_responses_onlywith mlx-lm's TokenizerWrapper (not callable, so the HF masking impl failed) and Qwen2-VL training hitting the fused MRoPE kernel with no VJP.