feat(mlx): implement resume_from_checkpoint for MLX training#751
Conversation
Studio's Resume button silently restarted MLX runs from step 0: the
MLX worker called trainer.train() with no args and the MLX trainer
did not accept resume_from_checkpoint. The frontend already supports
a Resume action (and can_resume_run does not gate by hardware), so
this was a real user-facing bug.
Adds the trainer side. The Studio worker side (passing the kwarg
through) is a separate PR.
## What gets saved
The checkpoint loop already wrote adapters.safetensors and
adapter_config.json via save_trainable_adapters. After a successful
adapter save we also write, best-effort:
- optimizer_state.safetensors: mx.save_safetensors of
tree_flatten(optimizer.state). For Adam/AdamW this is the
per-parameter m,v plus the scalar step and learning_rate; for
other optimizers it is whatever .state exposes. Round-trip is
byte-equal in mlx 0.31.2.
- trainer_state.json: { global_step, train_loss_history }. Kept
in JSON rather than safetensors because these are scalars/lists,
not tensors, and JSON is easier to inspect.
If the new writes fail (disk full, permission, future MLX API change)
we print a one-line warning and keep going. The adapter save was
already successful at that point, so the run is not lost; only resume
is unavailable.
This matches the contract Studio's backend already expects:
studio/backend/core/training/resume.py has_resume_state() looks for
trainer_state.json in either the output_dir or a checkpoint-N
subdirectory.
## What gets restored
train() now accepts resume_from_checkpoint and stashes it on self.
_train_inner does three things, in order, after building the
optimizer:
1. model.load_weights(adapters.safetensors, strict=False) -- the
model already has LoRA wrappers (Studio's pipeline does
get_peft_model before training), so we only need to load the
trained adapter tensors back into them.
2. load_optimizer_state -- restores Adam moments and the optimizer
step. Without this, Adam would silently restart its moment
estimates and the first post-resume step would be effectively
a learning-rate-scaled gradient with zero momentum, drastically
different from what fresh training would have done.
3. load_trainer_state -- restores global_step and train_loss_history.
global_step is used to fast-forward the loop counter and
batch_idx; train_loss_history is restored so the UI's loss curve
stays continuous across the resume boundary.
If trainer_state.json or optimizer_state.safetensors is missing on a
resume, we raise instead of silently restarting from step 0 with a
fresh optimizer. Silent restart is the worse outcome -- it gives the
user false confidence that their resume worked.
## Loop fast-forward
The training loop is:
for it in range(1, total_steps * grad_accum + 1):
batch_data = batches[batch_idx % len(batches)]
batch_idx += 1
...
self._set_optimizer_lr_for_step(optimizer, it // grad_accum - 1)
On resume we change the loop start to begin at
_resume_step * grad_accum + 1, set batch_idx = _resume_step *
grad_accum, and (for streaming mode where there is no random-access
batch list) consume _resume_step * grad_accum items from the
generator before entering the loop. The four data paths
(create_batches / iterate_training_batches / create_vlm_batches /
iterate_vlm_training_batches) all take seed=args.seed and produce
deterministic ordering, so fast-forward is sufficient -- we do not
need to save the dataloader's iterator position to disk.
The LR scheduler is _set_optimizer_lr_for_step(optimizer, step) and
is a pure function of step, so once it starts at the right value the
schedule is automatically correct from there on out.
## Verification
Stop-resume harness on M2 16GB, Qwen3-0.6B + unsloth/LaTeX_OCR,
max_steps=10, save_steps=5, grad_accum=4:
Fresh step 6 loss: 2.1686279773712158
Resume step 6 loss: 2.168627977371216 (byte-equal)
Fresh step 7 loss: 1.6476788520812988
Resume step 7 loss: 1.6476788520812988 (byte-equal)
Fresh step 8 loss: 1.4659109115600586
Resume step 8 loss: 1.4659109115600586 (byte-equal)
Fresh step 9 loss: 1.4719936847686768
Resume step 9 loss: 1.4719936847686768 (byte-equal)
Fresh step 10 loss: 1.4772168397903442
Resume step 10 loss: 1.4772168397903442 (byte-equal)
Every loss after the resume boundary matches the fresh run bit for
bit. The LR schedule also matches (2e-04, 1.6e-04, 1.2e-04, 8e-05,
4e-05). Checkpoint dir contains the expected 4 files
(adapters.safetensors, adapter_config.json,
optimizer_state.safetensors, trainer_state.json).
Companion PR in unslothai/unsloth wires Studio's _run_mlx_training
to pass resume_from_checkpoint through from config to trainer.train().
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 136bd3a5ad
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| save_trainer_state( | ||
| { | ||
| "global_step": current_step, | ||
| "train_loss_history": list(self._train_loss_history), | ||
| }, |
There was a problem hiding this comment.
Preserve MLX RNG state in checkpoints
When training uses dropout or any other stochastic MLX op, restoring only the adapter weights, optimizer state, and loss history is not enough because the step function tracks mx.random.state as mutable state. A resumed run starts from whatever RNG state the new process has, so the next dropout masks differ from the uninterrupted run even though the optimizer step is restored; include the MLX random state in the checkpoint/resume state to make resumes continue the same training trajectory for stochastic models.
Useful? React with 👍 / 👎.
| # Streaming mode: fast-forward the iterator to the resume position. | ||
| # The seed is the same and create_batches/iterate_*_batches is | ||
| # deterministic, so consuming N batches gives us the same data | ||
| # ordering the killed run would have produced. | ||
| if _resume_step > 0 and batch_iter is not None: |
There was a problem hiding this comment.
Seed VLM streaming before relying on fast-forward
This fast-forward assumes every streaming iterator is deterministic from args.seed, but iterate_vlm_training_batches() currently ignores its seed argument and uses np.random.permutation for finite VLM datasets without seeding. For args.streaming=True VLM resumes after a process restart, consuming _resume_step * grad_accum batches from a newly randomized iterator lands on a different data order than the killed run, so training does not actually resume from the same batch sequence.
Useful? React with 👍 / 👎.
| def train(self, resume_from_checkpoint: str | None = None): | ||
| # Stash for _train_inner. None = fresh start, a path = resume. | ||
| self._resume_from_checkpoint = resume_from_checkpoint |
There was a problem hiding this comment.
Handle boolean resume_from_checkpoint requests
The new signature accepts only path strings, but Trainer.train(resume_from_checkpoint=True) is the common HF/TRL-compatible way to resume from the latest checkpoint under output_dir. With this implementation that boolean is stashed and later interpolated as True/adapters.safetensors, so callers using the expected boolean form get a missing-file failure instead of resuming; handle True by resolving the latest checkpoint or reject it before treating it as a path.
Useful? React with 👍 / 👎.
| save_trainer_state( | ||
| { | ||
| "global_step": current_step, | ||
| "train_loss_history": list(self._train_loss_history), | ||
| }, |
There was a problem hiding this comment.
Save pending loss accumulators for mid-log resumes
When save_steps is not aligned with logging_steps, a checkpoint can be written after several optimizer steps have contributed to the current logging window but before that window is logged. The resume state only stores the already-appended loss history, so after resuming from such a checkpoint the next logged loss averages only the post-resume steps instead of the full original window, making the loss curve and final average diverge for common configurations like save_steps=5, logging_steps=10.
Useful? React with 👍 / 👎.
|
Reviewed and validated. The restore order (adapters, optimizer state, trainer state), the hard error instead of a silent restart when state files are missing, the LR being re-derived from the loop counter, and the fast-forward arithmetic across the four batch paths all check out. Verified the optimizer and trainer state round-trips and the missing-state errors on the test shim; the MLX test files pass. Pushed one cleanup commit on top: moved the train() docstring back above the first statement (it had become a dead string literal after the resume stash) and chained the streaming fast-forward RuntimeError with from None so the bare StopIteration does not add noise. No behavior change. |
Studio's frontend exposes a Resume action and submits requests with resume_from_checkpoint set to a previous run's output_dir. The CUDA training paths in worker.py read this field from config and pass it to trainer.train() (see lines 2729-2787 and 3108-3229). The MLX path _run_mlx_training did neither: it never read config['resume_from_checkpoint'] and called trainer.train() with no args. The MLX trainer also did not accept the kwarg, so even threading it through would have been a no-op. With this PR + the unsloth-zoo companion PR adding the trainer-side support (saves optimizer_state + trainer_state, accepts and applies resume_from_checkpoint in MLXTrainer.train()), MLX Resume now works end-to-end. Verified on M2 16GB with Qwen3-0.6B + unsloth/LaTeX_OCR: loss at every post-resume step matches a fresh run bit for bit (2.168627977371216 == 2.168627977371216 at step 6, etc). Two lines: read the field near the other config.get() extractions in _run_mlx_training, pass it as a kwarg at the trainer.train() call site. Companion PR: unslothai/unsloth-zoo#751
Summary
Studio's Resume button silently restarted MLX runs from step 0. The MLX worker called
trainer.train()with no args, and the MLX trainer didn't acceptresume_from_checkpoint. The frontend already exposes a Resume action (andcan_resume_rundoesn't gate by hardware), so this was a real user-facing bug.This PR adds the trainer-side support. The Studio worker change (passing the kwarg through) is a companion PR in
unslothai/unsloth.What gets saved
The checkpoint loop already wrote
adapters.safetensors+adapter_config.jsonviasave_trainable_adapters. After a successful adapter save we now also write, best-effort:optimizer_state.safetensors—mx.save_safetensors(tree_flatten(optimizer.state)). For Adam/AdamW this is per-parameterm,v+ scalarstepandlearning_rate. Round-trip is byte-equal in mlx 0.31.2.trainer_state.json—{global_step, train_loss_history}. JSON rather than safetensors because these are scalars/lists, not tensors.If the new writes fail, we print a warning and keep going. The adapter save was already successful at that point; only resume is unavailable.
This matches the contract Studio's backend already expects:
studio/backend/core/training/resume.py:has_resume_state()looks fortrainer_state.jsonin eitheroutput_dir/oroutput_dir/checkpoint-N/.What gets restored
train()now acceptsresume_from_checkpoint._train_innerdoes three things, in order, after building the optimizer:model.load_weights(adapters.safetensors, strict=False)— the model already has LoRA wrappers (Studio's pipeline doesget_peft_modelbefore training), so we only need to load the trained adapter tensors back into them.load_optimizer_state— restores Adam moments and the optimizer step. Without this, Adam would silently restart its moment estimates and the first post-resume step would be effectively a learning-rate-scaled gradient with zero momentum.load_trainer_state— restoresglobal_stepandtrain_loss_history.global_stepis used to fast-forward the loop counter andbatch_idx;train_loss_historyis restored so the UI's loss curve stays continuous across the resume boundary.If
trainer_state.jsonoroptimizer_state.safetensorsis missing on resume, we raise rather than silently restart from step 0 with a fresh optimizer. Silent restart is the worse outcome — it gives the user false confidence.Loop fast-forward
On resume we change the loop start to begin at
_resume_step * grad_accum + 1, setbatch_idx = _resume_step * grad_accum, and (for streaming mode where there is no random-access batch list) consume_resume_step * grad_accumitems from the generator before entering the loop. All four data paths (create_batches/iterate_training_batches/create_vlm_batches/iterate_vlm_training_batches) takeseed=args.seedand produce deterministic ordering, so fast-forward is sufficient — we don't need to save iterator position to disk.The LR scheduler is a pure function of step, so once it starts at the right value the schedule is automatically correct.
Verification
Stop-resume harness on M2 16GB, Qwen3-0.6B + unsloth/LaTeX_OCR, max_steps=10, save_steps=5, grad_accum=4:
Every loss after the resume boundary matches the fresh run bit for bit. LR schedule matches (
2e-04, 1.6e-04, 1.2e-04, 8e-05, 4e-05). Checkpoint dir contains the expected 4 files.Note on Studio integration
This PR alone isn't enough to unblock the Studio Resume button — Studio's worker must also pass
resume_from_checkpointthrough totrainer.train(). That's a 2-line change in a companion PR.