[DCP] OSS Zero Overhead Checkpointing Implementation#156207
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/156207
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure) As of commit 040df63 with merge base 2eb744c. UNSTABLE: the following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D72391401
Does async DCP also work for (and is it practical for) sending checkpoints directly to S3? (I guess using https://github.com/awslabs/s3-connector-for-pytorch? Or is there some native support?)
fegin left a comment
A general question: what if users mix async_save() with save()? Does this PR handle that case?
Should we have a unittest for this file if possible?
See #155192; this is now added to torch.cuda. I think you can remove these changes, as that PR has also updated state_dict_utils.
It's probably not a good idea to ask users to selectively call close(); I would suggest that users always call close(). Also, if this is a public API, the docstring should follow the template; you can check other docstrings.
Should we also wait for the last async_save inside this API?
I think we should just leave that to the users. In general, I want to limit global training state within DCP as much as possible.
Can we cache the state in AsyncStager and let the user manage the lifetime of the async stager? This would allow users to init the async stager and destroy it as needed, either on every checkpoint or at the end of the job, as they see fit.
As it stands, the close() method is out of context: it is not on any resource/object, so it is hard for users to understand what close means and why they have to call it.
@vadimkantorov It should work. The async_save logic and the underlying storage are decoupled.
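A rough illustration of that decoupling (names here are illustrative, not the actual DCP API): staging snapshots the state dict on the critical path, while the upload step just invokes whatever storage backend is supplied, so an S3-backed writer slots in the same way a filesystem writer does.

```python
import copy
from concurrent.futures import Future, ThreadPoolExecutor


def async_save_sketch(state_dict: dict, write_fn, executor: ThreadPoolExecutor) -> Future:
    """Stage (snapshot) synchronously, then upload in the background.

    `write_fn` stands in for any storage backend: local disk, S3, etc.
    """
    staged = copy.deepcopy(state_dict)        # staging: copy off the training state
    return executor.submit(write_fn, staged)  # upload: runs off the critical path
```

Because the snapshot is taken before the future is returned, training can keep mutating the live state dict while the upload proceeds.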
Force-pushed from ac1c1f4 to fd4bf88, then from fd4bf88 to aa0e4b3.
Force-pushed from aa0e4b3 to 3bb6615, then from 3bb6615 to 3d5128c.
fegin left a comment
Overall, looks good. We should remove the logging-level setting. Also, we should have at least one unit test for this feature.
It would be nice if both paths returned AsyncSaveResponse: one path with only upload_future and the other with both. But I understand that this would break BC. Not sure if we can do some trick, like making AsyncSaveResponse inherit from Future. Just a thought; it may not work.
Yeah, it's unfortunate, but I don't think there is a clean way to do this :/
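One hypothetical shape for the trick mentioned above (purely a sketch; these are not the actual DCP types): make the response object itself a Future that mirrors the upload future, so legacy callers that expect a bare Future keep working while new callers can reach the staging future.

```python
from concurrent.futures import Future


class AsyncSaveResponseSketch(Future):
    """Acts like the old return type (the upload Future) while also
    exposing staging_future for zero-overhead checkpointing."""

    def __init__(self, staging_future: Future, upload_future: Future):
        super().__init__()
        self.staging_future = staging_future
        self.upload_future = upload_future
        # Mirror the upload future's outcome so legacy .result() callers work.
        upload_future.add_done_callback(self._mirror)

    def _mirror(self, fut: Future) -> None:
        exc = fut.exception()
        if exc is not None:
            self.set_exception(exc)
        else:
            self.set_result(fut.result())
```

The cost of this trick is that the object now carries all of Future's surface area, which is part of why it may not be worth the BC convenience.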
We should remove this line. Users should be able to control the logging level, not the module.
Oh nice :) It might be good to demonstrate a practical recipe with S3 checkpointing in torchtitan.
Also, a typical issue with checkpointing is the need to prune/delete existing checkpoints to save space. Does torchtitan provide any built-in policies for pruning existing checkpoints? E.g. keep rare regular checkpoints + the K best checkpoints + several latest checkpoints. (And all of this must interface with S3, as keeping many local checkpoints of large models is not feasible.)
Summary: This diff updates DCP driver code/APIs to support Zero Overhead Checkpointing.
Test Plan: Test with TorchTitan on this PR: pytorch/torchtitan#1287
Differential Revision: D72391401
Force-pushed from 3d5128c to 980c3c9.
teja-rao left a comment
I added a few comments, but I think I have a fundamental question: should we support zero-overhead copy in the AsyncStager implementation instead of in the legacy staging path?
```python
_iterate_state_dict(
ret = []
for idx, v in enumerate(iter_object):
    obj = _iterate_state_dict(
```
Why change this? If it is not needed, can we revert it?
```python
obj = _iterate_state_dict(
    value,
    sharded_tensor_func,
    dtensor_func,
    tensor_func,
    pg=pg,
    device=device,
    cpu_offload=cpu_offload,
    companion_obj=(
        companion_obj[key] if companion_obj is not None else None
    ),
    ranks_only=ranks_only,
    type_check=type_check,
    non_blocking=non_blocking,
)
ret[key] = obj
```
nit: `ret[key] = _iterate_state_dict(...)`
```python
)

class _ThreadBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor):
```
Unrelated to this PR, but we should deprecate this in favor of async_process_executor.
```python
from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner
from torch.distributed.checkpoint.staging import AsyncStager
```
```python
executor: _AsyncCheckpointExecutor = (
def stage_state_dict() -> Future[STATE_DICT_TYPE]:
    staging_executor = ThreadPoolExecutor(max_workers=1)
    if isinstance(storage_writer, AsyncStager) and not use_default_staging:
```

Suggested change:

```diff
-    if isinstance(storage_writer, AsyncStager) and not use_default_staging:
+    if storage_writer is not None and isinstance(storage_writer, AsyncStager):
```
```python
use_default_staging = False
if storage_writer is None:
    use_default_staging = True
```
Remove? See suggestion on L321.
```python
if isinstance(storage_writer, AsyncStager) and not use_default_staging:
    staging_future = staging_executor.submit(storage_writer.stage, state_dict)
else:
    # provides bwc for storage_writers not implementing AsyncStager
```
Do we need to handle this case? Can we ask the user to implement AsyncStager if they need to use zero-copy? I think that is simpler to support and cleaner from an API point of view.
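The BC path under discussion could look roughly like this (a sketch: a duck-typed `stage` check stands in for the AsyncStager isinstance check, and a deep copy stands in for DCP's CPU offload):

```python
import copy
from concurrent.futures import Future, ThreadPoolExecutor


def stage_sketch(state_dict: dict, storage_writer, staging_executor: ThreadPoolExecutor) -> Future:
    """Delegate staging to the writer when it implements the AsyncStager
    protocol; otherwise fall back to a plain background deep copy so old
    storage writers keep working."""
    if hasattr(storage_writer, "stage"):
        return staging_executor.submit(storage_writer.stage, state_dict)
    # BC fallback for storage writers that are not AsyncStagers.
    return staging_executor.submit(copy.deepcopy, state_dict)
```

Dropping the fallback branch, as suggested, would shrink the API surface at the cost of requiring every writer to adopt the protocol.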
```python
if not block_on_staging:
    global _CACHED_STATE_DICT
    if not _CACHED_STATE_DICT:
        _CACHED_STATE_DICT = _create_cpu_state_dict(
            state_dict, pin_memory=True, share_memory=True
        )
```
pin_memory and share_memory need to be controllable options, so the user has a choice to disable them; they come with drawbacks that might not work for every model or every system.
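A sketch of what controllable options plus the one-time cached allocation might look like (hypothetical names; real DCP would thread these flags into `_create_cpu_state_dict`):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StagingOptionsSketch:
    # On by default for zero-overhead copies, but user-disableable:
    # pinned memory is a scarce resource, and shared memory is not
    # available or desirable on every system.
    pin_memory: bool = True
    share_memory: bool = True


class CachedStagingBuffer:
    """Allocate the CPU staging copy once and reuse it across saves."""

    def __init__(self, allocate):
        self._allocate = allocate  # stands in for _create_cpu_state_dict
        self._cached = None

    def get(self, state_dict: dict, opts: StagingOptionsSketch):
        if self._cached is None:
            self._cached = self._allocate(state_dict, opts)
        return self._cached
```

Caching the buffer this way pays the (slow) pinned/shared allocation only on the first checkpoint of the job.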
Force-pushed from 980c3c9 to c8bc4e5.
teja-rao left a comment
Overall, the PR looks good to me after the updates. I will stamp the PR once testing is complete and CI shows green! Thank you for making the changes.
Should this be STATE_DICT_TYPE | Future[STATE_DICT_TYPE]?
nit: I see we are converting the state_dict to a Future to pass in, which is okay, but I think it is more readable to just keep passing the union of state_dict and Future.
That makes sense. I was originally thinking that, in the long term, stage would only return a Future (which would work for sync as well), but I don't see that happening, because it would introduce breaking changes and would be hard to cleanly deprecate the old return type.
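Handling the union-typed value on the consumer side is cheap, which is part of why the union signature reads well; a sketch (illustrative helper name, not DCP API):

```python
from concurrent.futures import Future
from typing import Union


def resolve_staged(staged: Union[dict, "Future[dict]"]) -> dict:
    """Accept either an already-materialized state_dict or a Future of
    one, per the union-typed stage() signature discussed above."""
    if isinstance(staged, Future):
        return staged.result()
    return staged
```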
nit:

```diff
 if async_stager is None:
-    if storage_writer is None or not isinstance(storage_writer, AsyncStager):
-        async_stager = DefaultStager(StagingOptions(not block_on_staging, not block_on_staging, not block_on_staging, not block_on_staging))
-    elif isinstance(storage_writer, AsyncStager):
-        # bwc with old storage_writers
-        async_stager = storage_writer
+    if storage_writer is not None and isinstance(storage_writer, AsyncStager):
+        # bwc with old storage_writers
+        async_stager = storage_writer
+    else:
+        async_stager = DefaultStager(StagingOptions(not block_on_staging, not block_on_staging, not block_on_staging, not block_on_staging))
```
Is the save method still used for sync save? Why not change it to support the union?
nit: What do you think about this? We could eliminate save_wrapper and add the isinstance check in the save method.
teja-rao left a comment
Sending back for updating the docs and for consideration of the nits.
Is this assert needed? mypy typechecks should catch it if you aren't returning a Future.
Without this, we introduce a linter error, because async_save now returns either a tuple of staging_future/upload_future or just an upload future.
Clean up/update CheckpointStager? I think these are left over from the DCP evolution work.
I think we do not want users to create a stager each time; the stager caches the storage, so maybe this needs an update.
Throw an exception and suggest synchronizing using the future, i.e. calling staging_future.result(), here?
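A sketch of the suggested fail-loudly behavior (illustrative names, not DCP API): refuse to start a new staging pass while the previous one is still in flight, and point the caller at staging_future.result().

```python
from concurrent.futures import ThreadPoolExecutor


class StagerSketch:
    """Rejects overlapping saves: two concurrent staging passes would
    race on the shared staging buffer."""

    def __init__(self):
        self._executor = ThreadPoolExecutor(max_workers=1)
        self.staging_future = None

    def stage(self, state_dict: dict, copy_fn):
        if self.staging_future is not None and not self.staging_future.done():
            raise RuntimeError(
                "Previous staging is still in flight; synchronize with "
                "staging_future.result() before the next async_save()."
            )
        self.staging_future = self._executor.submit(copy_fn, state_dict)
        return self.staging_future
```

The alternative design, silently blocking until the previous staging finishes, would hide the stall from the user, so raising with a remediation hint is arguably the friendlier contract.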
teja-rao left a comment
Approving to unblock; please fix the mypy error before landing.
Summary: X-link: meta-pytorch/tnt#1010
This diff updates DCP driver code/APIs to support Zero Overhead Checkpointing.
Test Plan: Test with TorchTitan on this PR: pytorch/torchtitan#1287. Add new UT.
Reviewed By: diego-urgell
Differential Revision: D72391401
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged)
This PR has pending changes requested. Please address the comments and update the PR before merging.
@pytorchbot merge
This PR has pending changes requested. Please address the comments and update the PR before merging.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Summary: This diff updates DCP driver code/APIs to support Zero Overhead Checkpointing.
Test Plan: Test with TorchTitan on this PR: pytorch/torchtitan#1287
Differential Revision: D72391401
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k