fix(distill): save trained GPU weights + periodic checkpoints (PMAT-699 P0)#1856
Merged
Merged
Conversation
…99 P0)
Two P0 defects surfaced by Stage D 2026-05-20: 25h of GB10 training
produced final_loss=3.58 (real convergence) but a 200-byte empty
model.safetensors. Root cause is a defect-pair:
Defect 1: pipeline.export() serializes the `student_weights` HashMap,
which PMAT-698f's APR short-circuit returns empty (the read
side correctly handles APR; the write side then ships
nothing). The trained GPU weights in CudaStudentProvider
were never pulled back to disk.
Defect 2: zero intermediate checkpoints across the 25h run. A crash
at step 49999 would have produced identical loss to what
we observed.
Fix:
1. New trait method `StudentLogitsProvider::save_checkpoint(&mut self,
path: &Path) -> Result<()>` with a no-op default (preserves
FixtureStudent behavior).
2. `CudaStudentProvider::save_checkpoint(path)` delegates to
`trainer.save_apr(path, "albor-distilled-v2", "Qwen2ForCausalLM")`.
GPU weights → APR v2 file with full metadata.
3. `pipeline.export()` calls `self.student.save_checkpoint(<output_dir>/model.apr)`
after the metadata sidecar write. CUDA path produces a real APR;
fixture path still writes only the metadata-only safetensors.
4. Periodic checkpointing in `pipeline.train()`:
- Every APR_DISTILL_CHECKPOINT_EVERY steps (default 5000)
- Files: `<output_dir>/ckpt-step-{N:06}.apr`
- Write failures are logged but don't fail training (loss progress
is more valuable than pristine intermediate snapshots)
- Set to 0 to disable (smoke tests)
Test plan:
- [x] cargo check (both with/without cuda feature) clean
- [x] 61 distill lib tests pass (FALSIFY-APR-DISTILL-TRAIN-001/002
unchanged; fixture path semantics preserved via no-op default)
- [ ] Live gx10 re-dispatch (Phase 4 short re-train, per user direction):
verify model.apr is written with non-trivial size AND
ckpt-step-NNNNNN.apr appears every 5000 steps
Unblocks Phase 5 (HumanEval) and Phase 6 (publish), which both require
a model with real weights.
Stage D 2026-05-20 produced loss-verdict evidence (-66% over 25h) but
no usable model. Per user direction (2026-05-21): land this P0 fix,
then dispatch a 5-10K-step re-train (~5h) to produce a usable Phase 5/6
checkpoint. Future runs benefit from periodic checkpoints automatically.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
P0 — Stage D 2026-05-20 produced 200-byte empty checkpoint despite 25h compute
Two P0 defects surfaced by Stage D Phase 4 dispatch: training ran for 25h on Blackwell GB10, final_loss=3.58 (real convergence), but
student-trained.apr/model.safetensorswas a 200-byte empty placeholder. The trained weights lived only inCudaStudentProvider's GPU memory and are gone now that the process exited.Root cause — defect pair
Defect 1 — empty export:
pipeline.export()serializes thestudent_weightsHashMap. PMAT-698f's APR short-circuit (correct on the read side) returns empty maps. The write side then ships nothing.Defect 2 — zero intermediate checkpoints:
A 25h run with no periodic saves. A crash at step 49999 would have had identical outcome to what we observed.
Fix
New trait method
StudentLogitsProvider::save_checkpoint(&mut self, path: &Path) -> Result<()>with no-op default (preservesFixtureStudentbehavior).CudaStudentProvider::save_checkpoint(path)delegates totrainer.save_apr(path, "albor-distilled-v2", "Qwen2ForCausalLM"). GPU weights → APR v2 file with full metadata.pipeline.export()callsself.student.save_checkpoint(<output_dir>/model.apr)after the metadata sidecar write. CUDA path produces a real APR; fixture path still writes only the metadata-only safetensors.Periodic checkpointing in
pipeline.train():APR_DISTILL_CHECKPOINT_EVERYsteps (default 5000)<output_dir>/ckpt-step-{N:06}.aprTest plan
cargo check(both with/without cuda feature) cleanmodel.aprnon-empty ANDckpt-step-NNNNNN.aprevery 5000 stepsWhat this unblocks
Phase 5 (HumanEval) + Phase 6 (publish) — both require a model with real weights. Per user direction (2026-05-21): land this P0 fix, then dispatch a 5-10K-step re-train (~5h) to produce a usable Phase 5/6 checkpoint.
🤖 Generated with Claude Code