Titan changes to use DCP ZOC instead of titan default Async + Pinned Memory#1287
Conversation
Synced offline; we can land this PR after 1) performance is on par or better, and 2) the loss curve matches training without any checkpointing.
Summary: Pull Request resolved: #156207. This diff updates DCP driver code/APIs to support Zero Overhead Checkpointing.
Test Plan: Test with TorchTitan on this PR: pytorch/torchtitan#1287
Differential Revision: D72391401
fegin
left a comment
Overall, LGTM. We need to wait for the PyTorch PR to get the testing signals here. I'm not sure whether this PR will break the TorchTitan checkpoint unit tests.
torchtitan/components/checkpoint.py
Outdated
self.mp = None
self.async_future = None
self.upload_future = None
The naming is not clear now that we have two different futures. Can you change async_future to staging_future? upload_future is also unclear; maybe save_future, since some people may use local storage, where "upload" is a little confusing.
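The distinction behind the suggested names can be sketched in plain Python (hypothetical class and method names, not the actual torchtitan code): a staging_future that completes once the state dict has been copied off the GPU, and a save_future that completes once it has been persisted, whether to local disk or remote storage.

```python
from concurrent.futures import ThreadPoolExecutor
import time


class CheckpointerSketch:
    """Illustrative only: two futures with distinct lifetimes."""

    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=2)
        self.staging_future = None  # done when staging (e.g. GPU -> pinned host memory) finishes
        self.save_future = None     # done when persistence (local disk or remote) finishes

    def async_save(self, state_dict):
        # Stage first; training can resume as soon as the cheap copy completes.
        self.staging_future = self._pool.submit(self._stage, state_dict)
        staged = self.staging_future.result()  # wait only for staging
        # Persist in the background; callers poll/wait on save_future.
        self.save_future = self._pool.submit(self._persist, staged)

    def _stage(self, state_dict):
        return dict(state_dict)  # stand-in for a device-to-host copy

    def _persist(self, staged):
        time.sleep(0.01)  # stand-in for slow storage I/O
        return "checkpoint-path"


ckpt = CheckpointerSketch()
ckpt.async_save({"step": 500})
print(ckpt.save_future.result())  # prints "checkpoint-path"
```

With this split, "is the previous checkpoint fully written?" and "can training safely mutate the model again?" are answered by two different futures, which is exactly why a single async_future name becomes ambiguous.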
torchtitan/components/checkpoint.py
Outdated
self.purge_queue.put(Terminate())
self.purge_thread.join()

if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
This check is redundant. Can we just check if self.stager is None or not?
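The suggestion amounts to keying the cleanup path off the presence of the stager object rather than re-checking the async mode. A minimal sketch (hypothetical attribute and class names, assuming the stager owns the pinned-memory resources):

```python
class CheckpointerCloseSketch:
    """Sketch of the close/shutdown path being reviewed."""

    def __init__(self, stager=None):
        self.stager = stager

    def close(self):
        # Instead of: if self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: ...
        # check the object that actually needs cleanup:
        if self.stager is not None:
            self.stager.close()
            self.stager = None


class FakeStager:
    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


s = FakeStager()
c = CheckpointerCloseSketch(s)
c.close()
print(s.closed)  # prints True
CheckpointerCloseSketch(None).close()  # no-op, no error when no stager was created
```

Checking the attribute directly keeps the shutdown logic correct even if a future mode other than ASYNC_WITH_PINNED_MEM also creates a stager.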
Summary: X-link: meta-pytorch/tnt#1010. Pull Request resolved: #156207. This diff updates DCP driver code/APIs to support Zero Overhead Checkpointing.
Test Plan: Test with TorchTitan on this PR: pytorch/torchtitan#1287
Reviewed By: diego-urgell
Differential Revision: D72391401
@pytorchbot rebase
Test Titan changes to use DCP ZOC instead of titan default
Loss Curve, DCP + ZOC: train with DCP ZOC until step 500, delete the last checkpoint (since it's a sync save), then load and run training until step 1000.
Loss Curve TorchTitan Async + Pinned Memory Without DCP ZOC
DCP + Titan training on the LLAMA3 8B model for 500 steps.
Delete the checkpoint at step 500 (since it was sync saved), then run LLAMA3 8B model training to 1000 steps.
TorchTitan training (Async + Pinned Memory) without DCP ZOC for 1000 steps on the LLAMA3 8B model.