Fix: AsyncCheckpointIO snapshots tensors to avoid race with parameter mutation#21079
Merged
Borda merged 2 commits intoLightning-AI:masterfrom Aug 18, 2025
Merged
Conversation
… mutation Summary - Root cause: Background thread serialized live tensor references; the training thread mutated tensors after scheduling the async save, leading to mixed-step checkpoints. - Fix: Snapshot all tensors on the main thread before submitting the async save using `apply_to_collection(..., torch.Tensor, lambda t: t.detach().clone())`. Implementation - Reproduce the issue in unit test - Clone all tensors in the checkpoint payload on the caller thread to take a point-in-time snapshot. - Supports both positional and keyword `checkpoint` parameters. - Preserves non-tensor values; handles nested containers. - Continues to surface background exceptions on teardown.
Borda
approved these changes
Aug 18, 2025
Borda
added a commit
that referenced
this pull request
Aug 28, 2025
… mutation (#21079) * Fix: AsyncCheckpointIO snapshots tensors to avoid race with parameter mutation Summary - Root cause: Background thread serialized live tensor references; the training thread mutated tensors after scheduling the async save, leading to mixed-step checkpoints. - Fix: Snapshot all tensors on the main thread before submitting the async save using `apply_to_collection(..., torch.Tensor, lambda t: t.detach().clone())`. Implementation - Reproduce the issue in unit test - Clone all tensors in the checkpoint payload on the caller thread to take a point-in-time snapshot. - Supports both positional and keyword `checkpoint` parameters. - Preserves non-tensor values; handles nested containers. - Continues to surface background exceptions on teardown. * chlog --------- Co-authored-by: Jirka B <j.borovec+github@gmail.com> (cherry picked from commit 2c74bee)
lantiga
pushed a commit
that referenced
this pull request
Aug 29, 2025
… mutation (#21079) * Fix: AsyncCheckpointIO snapshots tensors to avoid race with parameter mutation Summary - Root cause: Background thread serialized live tensor references; the training thread mutated tensors after scheduling the async save, leading to mixed-step checkpoints. - Fix: Snapshot all tensors on the main thread before submitting the async save using `apply_to_collection(..., torch.Tensor, lambda t: t.detach().clone())`. Implementation - Reproduce the issue in unit test - Clone all tensors in the checkpoint payload on the caller thread to take a point-in-time snapshot. - Supports both positional and keyword `checkpoint` parameters. - Preserves non-tensor values; handles nested containers. - Continues to surface background exceptions on teardown. * chlog --------- Co-authored-by: Jirka B <j.borovec+github@gmail.com> (cherry picked from commit 2c74bee)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fix: AsyncCheckpointIO snapshots tensors to avoid race with parameter mutation
Fixes #20953
Summary
apply_to_collection(..., torch.Tensor, lambda t: t.detach().clone()).Implementation
checkpointparameters.📚 Documentation preview 📚: https://pytorch-lightning--21079.org.readthedocs.build/en/21079/