
MLX trainer: max_grad_value default silently overrides max_grad_norm, breaks HF parity #662

@danielhanchen


Summary

After PR #634, the MLX trainer defaults MLXTrainingConfig.max_grad_value to 5.0 (it was 1.0 in an earlier revision of the same PR). When the training config is resolved, it also silently zeroes out a user-supplied max_grad_norm whenever both values are non-zero. This breaks HuggingFace/TRL parity on the MLX path: a fine-tune that converges and emits a sensible greedy completion under trl.SFTTrainer on CUDA produces gibberish on MLX given identical hyperparameters.

Repro

Identical 7-step LoRA fine-tune of unsloth/gemma-3-270m-it on the train row "<<HELLO!!>> My name is Unsloth!", with bs=2, grad_accum=3, lr=1e-3, lr_scheduler=constant, warmup=0, optim=adamw, weight_decay=0, max_seq=64, seed=3407, and LoRA r=8 on q/k/v/o.

| Run | Clip mode | Loss, step 1 → 7 | Greedy completion of "<<HELLO!!>> My name is " |
|---|---|---|---|
| CUDA (torch+TRL) | max_grad_norm=1.0 only | 7.64 → 1.16 | " 1! ... My name is Unsloth! ..." (contains "Unsloth") |
| CUDA (torch+TRL) | elementwise clip_grad_value_(1.0) only | 7.64 → 1.19 | " 1! What are you doing?! ..." (no "Unsloth") |
| MLX (post-#634) | default (max_grad_value=1.0/5.0 overrides max_grad_norm) | 10.55 → 0.10 | '5 lbs!' |
| MLX (post-#634) | user sets max_grad_value=0, so max_grad_norm=1.0 wins | 10.55 → 0.17 | '5 lbs!' |
| MLX (pre-#634, last green @ unsloth 12295c1f) | trainer default | 10.55 → 5.04 (non-monotone) | " Unsloth!\n\nMy name is Unsloth! ..." (contains "Unsloth") |

The CUDA mirror script is at temp/torchcodec_test/cuda_mirror.py in my local workspace, with the results JSON in the same directory. The MLX numbers come from the "MLX CI on Mac M1" workflow on unslothai/unsloth:

Bisection

Only one unsloth-zoo commit landed between the last green MLX CI run (2026-05-14T10:52Z, unsloth 12295c1f) and the first red one (2026-05-14T12:24Z, unsloth a9322946): #634 (e6d8f7f, 2026-05-14T12:10:03Z).

What changed in #634 that broke parity

  1. MLXTrainingConfig.max_grad_value introduced and defaulted to a non-zero value (1.0, later 5.0 in the same PR).
  2. unsloth_zoo/mlx/trainer.py:733-738: when both max_grad_norm > 0 and max_grad_value > 0, max_grad_norm is forced to 0 with a printed notice. Users passing the HF/TRL-standard max_grad_norm=1.0 get it silently dropped.
  3. bias_correction=True was correctly added to match torch.optim.AdamW. That change restores PyTorch/HF parity and should stay.
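For context on item 3, here is a scalar sketch of how much bias correction changes Adam's very first update. It is not MLX code: the function name is hypothetical, and it assumes zero weight decay, zero-initialised moments, and PyTorch's default betas and eps.

```python
import math


def adamw_first_update(g, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
                       bias_correction=True):
    """Size of the first update (t=1) for one scalar parameter under
    Adam/AdamW with weight_decay=0, using the PyTorch-style formula."""
    t = 1
    m = (1 - beta1) * g        # first moment after one step (starts at 0)
    v = (1 - beta2) * g * g    # second moment after one step (starts at 0)
    if bias_correction:
        # Undo the bias towards the zero initialisation of both moments.
        m /= 1 - beta1 ** t
        v /= 1 - beta2 ** t
    return lr * m / (math.sqrt(v) + eps)


with_bc = adamw_first_update(0.5, bias_correction=True)
without_bc = adamw_first_update(0.5, bias_correction=False)
print(f"with bias correction:    {with_bc:.2e}")
print(f"without bias correction: {without_bc:.2e}")
```

With bias correction the first step is about lr, whatever the gradient's scale. Without it, the first step is about (1 - beta1) / sqrt(1 - beta2) ≈ 3.16 times larger. The gap shrinks only slowly with beta2 = 0.999, so it is still present across a 7-step run.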

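Elementwise clipping clamps each gradient component independently. Whenever it is active, it therefore changes the direction of the update, both within each leaf (parameter tensor) and for the gradient as a whole. clip_grad_norm instead rescales every component by one shared factor, which preserves direction. The two are mathematically different, and elementwise capping is not what HF/TRL users opt into when they set max_grad_norm. The CUDA mirror above shows the same direction-rotation effect under torch with clip_grad_value_(1.0) only: the loss curve is nearly identical (7.64 → 1.19 vs 7.64 → 1.16 with norm clipping), but the completion no longer contains "Unsloth".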
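To make the direction argument concrete, here is a minimal pure-Python sketch that represents a gradient as a list of leaves (lists of floats). The helper names are hypothetical and the numbers are made up. It compares global-norm clipping with elementwise value clipping, using cosine similarity to the unclipped gradient to measure direction change.

```python
import math


def global_norm(grads):
    """L2 norm over all leaves, treated as one flat vector."""
    return math.sqrt(sum(x * x for leaf in grads for x in leaf))


def clip_by_global_norm(grads, max_norm):
    """Rescale every leaf by one shared factor so the global norm is <= max_norm.
    Same rule as torch.nn.utils.clip_grad_norm_ (its small epsilon omitted)."""
    total = global_norm(grads)
    scale = min(1.0, max_norm / total) if total > 0 else 1.0
    return [[x * scale for x in leaf] for leaf in grads]


def clip_by_value(grads, max_value):
    """Clamp each component independently to [-max_value, max_value],
    the rule used by torch.nn.utils.clip_grad_value_."""
    return [[max(-max_value, min(max_value, x)) for x in leaf] for leaf in grads]


def cosine(a, b):
    """Cosine similarity between two gradients flattened across leaves."""
    fa = [x for leaf in a for x in leaf]
    fb = [x for leaf in b for x in leaf]
    dot = sum(x * y for x, y in zip(fa, fb))
    return dot / (global_norm(a) * global_norm(b))


# Toy gradient: two leaves, one containing a single large component.
grads = [[8.0, 0.5], [0.25, -0.1]]

by_norm = clip_by_global_norm(grads, max_norm=1.0)
by_value = clip_by_value(grads, max_value=1.0)

print(f"norm clip:  |g|={global_norm(by_norm):.3f}, cos={cosine(grads, by_norm):.3f}")
print(f"value clip: |g|={global_norm(by_value):.3f}, cos={cosine(grads, by_value):.3f}")
```

Norm clipping keeps the cosine at 1.000 (pure rescaling). Value clipping drops it to roughly 0.90, because only the large component was cut, and it does not even bound the global norm to 1.0.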

Recommended fix

In unsloth_zoo/mlx/trainer.py:

  1. MLXTrainingConfig.max_grad_value: float | None = None (off by default).
  2. At resolution time (lines 727-739) treat None as "feature disabled":
    • None → _clip_grad_value = False, never override max_grad_norm.
    • 0 → same as None (off).
    • explicit float > 0 → opt-in, only then warn-and-override if max_grad_norm is also set.
  3. Leave bias_correction=True (PyTorch parity).
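The resolution rules above could look roughly like the following. This is a minimal sketch, not the actual trainer.py code. ClipConfig and resolve_clipping are hypothetical names. Treating negative values as "off" and using warnings.warn instead of a plain print are assumptions of this sketch.

```python
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class ClipConfig:
    """Hypothetical stand-in for the two MLXTrainingConfig fields involved."""
    max_grad_norm: float = 1.0
    max_grad_value: Optional[float] = None  # proposed default: elementwise clipping off


def resolve_clipping(cfg: ClipConfig) -> Tuple[float, bool, Optional[float]]:
    """Return (effective max_grad_norm, elementwise clipping enabled?, max_grad_value)."""
    norm = cfg.max_grad_norm or 0.0
    value = cfg.max_grad_value
    # None and 0 both mean "feature disabled": max_grad_norm is left untouched.
    # (Negative values are treated the same way; an assumption of this sketch.)
    if value is None or value <= 0:
        return norm, False, None
    # Explicit opt-in: only now may max_grad_value override max_grad_norm.
    if norm > 0:
        warnings.warn(
            "Unsloth: max_grad_norm and max_grad_value are both enabled; "
            "ignoring max_grad_norm in favor of max_grad_value."
        )
        norm = 0.0
    return norm, True, value


print(resolve_clipping(ClipConfig()))                    # (1.0, False, None)
print(resolve_clipping(ClipConfig(max_grad_value=0)))    # (1.0, False, None)
print(resolve_clipping(ClipConfig(max_grad_value=5.0)))  # (0.0, True, 5.0), plus a warning
```

With this shape, a default config keeps HF/TRL norm clipping. The override path is reachable only through an explicit, positive max_grad_value.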

Effect: a default MLXTrainingConfig honors args.max_grad_norm and matches CUDA HF/TRL semantics. Power users can still opt into elementwise clipping by passing max_grad_value explicitly.

Why this matters

Unsloth's MLX path is sold as a drop-in for TRL's SFTTrainer on Apple Silicon. Today, a user fine-tuning with an identical config on CUDA and on MLX gets different gradient-clipping semantics and visibly different convergence basins. The MLX side does print an "Unsloth: max_grad_norm and max_grad_value are both enabled; ignoring max_grad_norm in favor of max_grad_value." line, but it is one of dozens of lines in the training log and is easy to miss.

I will open a follow-up PR with the change above; filing this first so the rationale is captured separately from the patch.
