Skip to content

feat(mlx): implement resume_from_checkpoint for MLX training#751

Merged
danielhanchen merged 2 commits into
unslothai:mainfrom
BardiaKoopah:feat/mlx-resume-from-checkpoint
Jun 11, 2026
Merged

feat(mlx): implement resume_from_checkpoint for MLX training#751
danielhanchen merged 2 commits into
unslothai:mainfrom
BardiaKoopah:feat/mlx-resume-from-checkpoint

Conversation

@BardiaKoopah

Copy link
Copy Markdown
Contributor

Summary

Studio's Resume button silently restarted MLX runs from step 0. The MLX worker called trainer.train() with no args, and the MLX trainer didn't accept resume_from_checkpoint. The frontend already exposes a Resume action (and can_resume_run doesn't gate by hardware), so this was a real user-facing bug.

This PR adds the trainer-side support. The Studio worker change (passing the kwarg through) is a companion PR in unslothai/unsloth.

What gets saved

The checkpoint loop already wrote adapters.safetensors + adapter_config.json via save_trainable_adapters. After a successful adapter save we now also write, best-effort:

  • optimizer_state.safetensorsmx.save_safetensors(tree_flatten(optimizer.state)). For Adam/AdamW this is per-parameter m,v + scalar step and learning_rate. Round-trip is byte-equal in mlx 0.31.2.
  • trainer_state.json{global_step, train_loss_history}. JSON rather than safetensors because these are scalars/lists, not tensors.

If the new writes fail, we print a warning and keep going. The adapter save was already successful at that point; only resume is unavailable.

This matches the contract Studio's backend already expects: studio/backend/core/training/resume.py:has_resume_state() looks for trainer_state.json in either output_dir/ or output_dir/checkpoint-N/.

What gets restored

train() now accepts resume_from_checkpoint. _train_inner does three things, in order, after building the optimizer:

  1. model.load_weights(adapters.safetensors, strict=False) — the model already has LoRA wrappers (Studio's pipeline does get_peft_model before training), so we only need to load the trained adapter tensors back into them.

  2. load_optimizer_state — restores Adam moments and the optimizer step. Without this, Adam would silently restart its moment estimates and the first post-resume step would be effectively a learning-rate-scaled gradient with zero momentum.

  3. load_trainer_state — restores global_step and train_loss_history. global_step is used to fast-forward the loop counter and batch_idx; train_loss_history is restored so the UI's loss curve stays continuous across the resume boundary.

If trainer_state.json or optimizer_state.safetensors is missing on resume, we raise rather than silently restart from step 0 with a fresh optimizer. Silent restart is the worse outcome — it gives the user false confidence.

Loop fast-forward

On resume we change the loop start to begin at _resume_step * grad_accum + 1, set batch_idx = _resume_step * grad_accum, and (for streaming mode where there is no random-access batch list) consume _resume_step * grad_accum items from the generator before entering the loop. All four data paths (create_batches / iterate_training_batches / create_vlm_batches / iterate_vlm_training_batches) take seed=args.seed and produce deterministic ordering, so fast-forward is sufficient — we don't need to save iterator position to disk.

The LR scheduler is a pure function of step, so once it starts at the right value the schedule is automatically correct.

Verification

Stop-resume harness on M2 16GB, Qwen3-0.6B + unsloth/LaTeX_OCR, max_steps=10, save_steps=5, grad_accum=4:

Step Fresh loss Resumed loss
6 2.1686279773712158 2.168627977371216
7 1.6476788520812988 1.6476788520812988
8 1.4659109115600586 1.4659109115600586
9 1.4719936847686768 1.4719936847686768
10 1.4772168397903442 1.4772168397903442

Every loss after the resume boundary matches the fresh run bit for bit. LR schedule matches (2e-04, 1.6e-04, 1.2e-04, 8e-05, 4e-05). Checkpoint dir contains the expected 4 files.

Note on Studio integration

This PR alone isn't enough to unblock the Studio Resume button — Studio's worker must also pass resume_from_checkpoint through to trainer.train(). That's a 2-line change in a companion PR.

Studio's Resume button silently restarted MLX runs from step 0: the
MLX worker called trainer.train() with no args and the MLX trainer
did not accept resume_from_checkpoint. The frontend already supports
a Resume action (and can_resume_run does not gate by hardware), so
this was a real user-facing bug.

Adds the trainer side. The Studio worker side (passing the kwarg
through) is a separate PR.

## What gets saved

The checkpoint loop already wrote adapters.safetensors and
adapter_config.json via save_trainable_adapters. After a successful
adapter save we also write, best-effort:

  - optimizer_state.safetensors: mx.save_safetensors of
    tree_flatten(optimizer.state). For Adam/AdamW this is the
    per-parameter m,v plus the scalar step and learning_rate; for
    other optimizers it is whatever .state exposes. Round-trip is
    byte-equal in mlx 0.31.2.

  - trainer_state.json: { global_step, train_loss_history }. Kept
    in JSON rather than safetensors because these are scalars/lists,
    not tensors, and JSON is easier to inspect.

If the new writes fail (disk full, permission, future MLX API change)
we print a one-line warning and keep going. The adapter save was
already successful at that point, so the run is not lost; only resume
is unavailable.

This matches the contract Studio's backend already expects:
studio/backend/core/training/resume.py has_resume_state() looks for
trainer_state.json in either the output_dir or a checkpoint-N
subdirectory.

## What gets restored

train() now accepts resume_from_checkpoint and stashes it on self.
_train_inner does three things, in order, after building the
optimizer:

  1. model.load_weights(adapters.safetensors, strict=False) -- the
     model already has LoRA wrappers (Studio's pipeline does
     get_peft_model before training), so we only need to load the
     trained adapter tensors back into them.

  2. load_optimizer_state -- restores Adam moments and the optimizer
     step. Without this, Adam would silently restart its moment
     estimates and the first post-resume step would be effectively
     a learning-rate-scaled gradient with zero momentum, drastically
     different from what fresh training would have done.

  3. load_trainer_state -- restores global_step and train_loss_history.
     global_step is used to fast-forward the loop counter and
     batch_idx; train_loss_history is restored so the UI's loss curve
     stays continuous across the resume boundary.

If trainer_state.json or optimizer_state.safetensors is missing on a
resume, we raise instead of silently restarting from step 0 with a
fresh optimizer. Silent restart is the worse outcome -- it gives the
user false confidence that their resume worked.

## Loop fast-forward

The training loop is:

  for it in range(1, total_steps * grad_accum + 1):
      batch_data = batches[batch_idx % len(batches)]
      batch_idx += 1
      ...
      self._set_optimizer_lr_for_step(optimizer, it // grad_accum - 1)

On resume we change the loop start to begin at
_resume_step * grad_accum + 1, set batch_idx = _resume_step *
grad_accum, and (for streaming mode where there is no random-access
batch list) consume _resume_step * grad_accum items from the
generator before entering the loop. The four data paths
(create_batches / iterate_training_batches / create_vlm_batches /
iterate_vlm_training_batches) all take seed=args.seed and produce
deterministic ordering, so fast-forward is sufficient -- we do not
need to save the dataloader's iterator position to disk.

The LR scheduler is _set_optimizer_lr_for_step(optimizer, step) and
is a pure function of step, so once it starts at the right value the
schedule is automatically correct from there on out.

## Verification

Stop-resume harness on M2 16GB, Qwen3-0.6B + unsloth/LaTeX_OCR,
max_steps=10, save_steps=5, grad_accum=4:

  Fresh   step 6 loss: 2.1686279773712158
  Resume  step 6 loss: 2.168627977371216    (byte-equal)
  Fresh   step 7 loss: 1.6476788520812988
  Resume  step 7 loss: 1.6476788520812988   (byte-equal)
  Fresh   step 8 loss: 1.4659109115600586
  Resume  step 8 loss: 1.4659109115600586   (byte-equal)
  Fresh   step 9 loss: 1.4719936847686768
  Resume  step 9 loss: 1.4719936847686768   (byte-equal)
  Fresh   step 10 loss: 1.4772168397903442
  Resume  step 10 loss: 1.4772168397903442  (byte-equal)

Every loss after the resume boundary matches the fresh run bit for
bit. The LR schedule also matches (2e-04, 1.6e-04, 1.2e-04, 8e-05,
4e-05). Checkpoint dir contains the expected 4 files
(adapters.safetensors, adapter_config.json,
optimizer_state.safetensors, trainer_state.json).

Companion PR in unslothai/unsloth wires Studio's _run_mlx_training
to pass resume_from_checkpoint through from config to trainer.train().
@gemini-code-assist

Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 136bd3a5ad

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +1267 to +1271
save_trainer_state(
{
"global_step": current_step,
"train_loss_history": list(self._train_loss_history),
},

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve MLX RNG state in checkpoints

When training uses dropout or any other stochastic MLX op, restoring only the adapter weights, optimizer state, and loss history is not enough because the step function tracks mx.random.state as mutable state. A resumed run starts from whatever RNG state the new process has, so the next dropout masks differ from the uninterrupted run even though the optimizer step is restored; include the MLX random state in the checkpoint/resume state to make resumes continue the same training trajectory for stochastic models.

Useful? React with 👍 / 👎.

Comment on lines +1071 to +1075
# Streaming mode: fast-forward the iterator to the resume position.
# The seed is the same and create_batches/iterate_*_batches is
# deterministic, so consuming N batches gives us the same data
# ordering the killed run would have produced.
if _resume_step > 0 and batch_iter is not None:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Seed VLM streaming before relying on fast-forward

This fast-forward assumes every streaming iterator is deterministic from args.seed, but iterate_vlm_training_batches() currently ignores its seed argument and uses np.random.permutation for finite VLM datasets without seeding. For args.streaming=True VLM resumes after a process restart, consuming _resume_step * grad_accum batches from a newly randomized iterator lands on a different data order than the killed run, so training does not actually resume from the same batch sequence.

Useful? React with 👍 / 👎.

Comment thread unsloth_zoo/mlx/trainer.py Outdated
Comment on lines +548 to +550
def train(self, resume_from_checkpoint: str | None = None):
# Stash for _train_inner. None = fresh start, a path = resume.
self._resume_from_checkpoint = resume_from_checkpoint

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Handle boolean resume_from_checkpoint requests

The new signature accepts only path strings, but Trainer.train(resume_from_checkpoint=True) is the common HF/TRL-compatible way to resume from the latest checkpoint under output_dir. With this implementation that boolean is stashed and later interpolated as True/adapters.safetensors, so callers using the expected boolean form get a missing-file failure instead of resuming; handle True by resolving the latest checkpoint or reject it before treating it as a path.

Useful? React with 👍 / 👎.

Comment on lines +1267 to +1271
save_trainer_state(
{
"global_step": current_step,
"train_loss_history": list(self._train_loss_history),
},

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Save pending loss accumulators for mid-log resumes

When save_steps is not aligned with logging_steps, a checkpoint can be written after several optimizer steps have contributed to the current logging window but before that window is logged. The resume state only stores the already-appended loss history, so after resuming from such a checkpoint the next logged loss averages only the post-resume steps instead of the full original window, making the loss curve and final average diverge for common configurations like save_steps=5, logging_steps=10.

Useful? React with 👍 / 👎.

@danielhanchen

Copy link
Copy Markdown
Member

Reviewed and validated. The restore order (adapters, optimizer state, trainer state), the hard error instead of a silent restart when state files are missing, the LR being re-derived from the loop counter, and the fast-forward arithmetic across the four batch paths all check out. Verified the optimizer and trainer state round-trips and the missing-state errors on the test shim; the MLX test files pass.

Pushed one cleanup commit on top: moved the train() docstring back above the first statement (it had become a dead string literal after the resume stash) and chained the streaming fast-forward RuntimeError with from None so the bare StopIteration does not add noise. No behavior change.

@danielhanchen danielhanchen merged commit c1d4a47 into unslothai:main Jun 11, 2026
1 of 11 checks passed
danielhanchen pushed a commit to unslothai/unsloth that referenced this pull request Jun 11, 2026
Studio's frontend exposes a Resume action and submits requests with
resume_from_checkpoint set to a previous run's output_dir. The CUDA
training paths in worker.py read this field from config and pass it to
trainer.train() (see lines 2729-2787 and 3108-3229). The MLX path
_run_mlx_training did neither: it never read config['resume_from_checkpoint']
and called trainer.train() with no args. The MLX trainer also did not
accept the kwarg, so even threading it through would have been a no-op.

With this PR + the unsloth-zoo companion PR adding the trainer-side
support (saves optimizer_state + trainer_state, accepts and applies
resume_from_checkpoint in MLXTrainer.train()), MLX Resume now works
end-to-end. Verified on M2 16GB with Qwen3-0.6B + unsloth/LaTeX_OCR:
loss at every post-resume step matches a fresh run bit for bit
(2.168627977371216 == 2.168627977371216 at step 6, etc).

Two lines: read the field near the other config.get() extractions in
_run_mlx_training, pass it as a kwarg at the trainer.train() call site.

Companion PR: unslothai/unsloth-zoo#751
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants