[DCP] OSS Zero Overhead Checkpointing Implementation#156207

Closed
Saiteja64 wants to merge 1 commit into main from export-D72391401
Conversation


@Saiteja64 Saiteja64 commented Jun 17, 2025

Summary: This diff updates DCP driver code/APIs to support Zero Overhead Checkpointing

Test Plan: Test with TorchTitan on this PR: pytorch/torchtitan#1287

Differential Revision: D72391401

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k


pytorch-bot bot commented Jun 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/156207

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 040df63 with merge base 2eb744c:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (checkpoint) labels Jun 17, 2025
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D72391401

@vadimkantorov

Does async DCP also work / is it practical for sending checkpoints directly to S3? (I guess using https://github.com/awslabs/s3-connector-for-pytorch? Or maybe there is some native support?)

@fegin previously requested changes Jun 17, 2025


A general question: what if users mix async_save() with save()? Does this PR handle that case?


Should we have a unittest for this file if possible?


See #155192; this is now added to torch.cuda. I think you can remove these changes, as that PR has also updated state_dict_utils.


Probably not a good idea to ask users to selectively call close(); I would suggest that users always call close(). Also, if this is a public API, the docstring should follow the template; you can check the other docstrings.


Should we also wait for the last async_save inside this API as well?


I think we should just leave that to the users. In general, I want to limit global training state within DCP as much as possible.
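The caller-side pattern implied here (the training loop, not DCP, tracks the outstanding save) can be sketched in plain Python. This is a toy stand-in using a thread pool; `fake_async_save` is illustrative and not a DCP API:

```python
from concurrent.futures import Future, ThreadPoolExecutor
import time

# Hypothetical stand-in for dcp.async_save: kicks off an upload in the
# background and returns a Future that resolves when it completes.
_pool = ThreadPoolExecutor(max_workers=1)

def fake_async_save(step: int) -> Future:
    def upload():
        time.sleep(0.01)  # simulate I/O
        return f"checkpoint-{step}"
    return _pool.submit(upload)

# The training loop owns the previous future and waits on it before
# starting the next save, so DCP itself keeps no global state about
# outstanding checkpoints.
prev_future = None
saved = []
for step in range(3):
    if prev_future is not None:
        saved.append(prev_future.result())  # block only on the *previous* save
    prev_future = fake_async_save(step)
saved.append(prev_future.result())
```

The key point is that waiting for the last async_save is an application-level decision, which matches the author's preference to keep global training state out of DCP.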


Can we cache the state in AsyncStager and let the user manage the lifetime of the async stager? This would allow users to init the async stager and destroy it as needed, either on every checkpoint or at the end of the job, as they see fit.

As it stands, the close() method is out of context: it is not tied to any resource/object, so it is hard for users to understand what close means and why they have to call it.
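A minimal sketch of the object-scoped lifetime being suggested, in plain Python; `CachingStager` is an illustrative class, not the DCP API:

```python
from concurrent.futures import Future, ThreadPoolExecutor

class CachingStager:
    """Illustrative stager that owns its cached staging buffers and its
    worker thread, so teardown is an explicit method on the object
    rather than a free-floating module-level close()."""

    def __init__(self):
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._cached = {}  # reused staging buffers, keyed by tensor name

    def stage(self, state_dict: dict) -> Future:
        def copy_to_cache():
            for k, v in state_dict.items():
                self._cached[k] = v  # a real impl would copy into pinned CPU memory
            return dict(self._cached)
        return self._executor.submit(copy_to_cache)

    def close(self):
        # Lifetime is user-managed: call per checkpoint or once at job end.
        self._executor.shutdown(wait=True)
        self._cached.clear()

stager = CachingStager()
staged = stager.stage({"w": [1.0, 2.0]}).result()
stager.close()
```

With this shape, close() is clearly "release this stager's cached buffers and worker", resolving the "what does close mean" concern.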


fegin commented Jun 17, 2025

@vadimkantorov It should work; the async_save logic and the underlying storage are decoupled.


@fegin left a comment

Overall, looks good. We should remove the set-logging-level call. Also, we should have at least one unit test for this feature.


It would be nice if both paths returned AsyncSaveResponse -- one path with only upload_future and the other with both. But I understand that this would break BC. Not sure if we can do some trick, like making AsyncSaveResponse inherit from Future. Just a thought; it may not work.


Yeah, it's unfortunate, but I don't think there is a clean way to do this :/
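For illustration, one shape the Future-inheritance trick floated above could take (plain Python; this `AsyncSaveResponse` is a hypothetical BC shim, not the DCP class):

```python
from concurrent.futures import Future

class AsyncSaveResponse(Future):
    """Illustrative BC shim: behaves like the upload future that old
    callers expect, while also exposing the new staging future."""

    def __init__(self, staging_future: Future, upload_future: Future):
        super().__init__()
        self.staging_future = staging_future
        self.upload_future = upload_future
        # Mirror the upload future so resp.result() keeps working.
        upload_future.add_done_callback(
            lambda f: self.set_exception(f.exception())
            if f.exception() else self.set_result(f.result())
        )

staging, upload = Future(), Future()
resp = AsyncSaveResponse(staging, upload)
staging.set_result("staged")
upload.set_result("uploaded")
```

Old call sites that do `async_save(...).result()` would keep working, while new call sites can unpack `staging_future`/`upload_future`. Whether this is worth the subtlety of subclassing Future is exactly the open question in the thread.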


We should remove this line. Users should be able to control the logging level, not the module.


vadimkantorov commented Jun 18, 2025

Oh nice :) It might be good to demonstrate a practical recipe with S3 checkpointing in torchtitan.

Also, a typical issue with checkpointing is the need to prune/delete existing checkpoints to save space. Does torchtitan provide any built-in policies to prune existing checkpoints? E.g. keep sparse regular checkpoints + the K best checkpoints + several of the latest checkpoints (and all of this must interface with S3, as keeping many local checkpoints of large models is not feasible...)
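As a sketch of the kind of retention policy being asked about (a hypothetical helper, not a torchtitan or DCP API; the actual deletion would go through the filesystem or S3 DeleteObject):

```python
def select_checkpoints_to_keep(steps, losses, every_nth=100, keep_latest=2, keep_best=1):
    """Hypothetical retention policy: keep sparse regular checkpoints,
    the K best by loss, and the newest M; everything else is eligible
    for deletion."""
    keep = {s for s in steps if s % every_nth == 0}   # sparse regulars
    keep.update(sorted(steps)[-keep_latest:])          # newest M
    by_loss = sorted(steps, key=lambda s: losses[s])
    keep.update(by_loss[:keep_best])                   # K best
    return sorted(keep)

steps = [50, 100, 150, 200, 250]
losses = {50: 3.0, 100: 2.5, 150: 2.7, 200: 2.1, 250: 2.2}
keep = select_checkpoints_to_keep(steps, losses)  # → [100, 200, 250]
```

Here step 100 survives as a regular checkpoint, 200 as both latest and best-loss, and 250 as latest; 50 and 150 would be pruned.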

facebook-github-bot pushed a commit that referenced this pull request Jun 18, 2025
Summary:

This diff updates DCP driver code/APIs to support Zero Overhead Checkpointing

Test Plan: Test with TorchTitan on this PR: pytorch/torchtitan#1287

Differential Revision: D72391401

@teja-rao left a comment

I added a few comments, but I think I have a fundamental question: should we support zero-overhead copy in the AsyncStager implementation instead of in the legacy staging path?


_iterate_state_dict(
ret = []
for idx, v in enumerate(iter_object):
obj = _iterate_state_dict(

why change this? if it is not needed, can we revert it?

Comment on lines +148 to +163
obj = _iterate_state_dict(
    value,
    sharded_tensor_func,
    dtensor_func,
    tensor_func,
    pg=pg,
    device=device,
    cpu_offload=cpu_offload,
    companion_obj=(
        companion_obj[key] if companion_obj is not None else None
    ),
    ranks_only=ranks_only,
    type_check=type_check,
    non_blocking=non_blocking,
)
ret[key] = obj

nit: ret[key] = _iterate_state_dict(...)

)


class _ThreadBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor):

Unrelated to this PR, but we should deprecate this in favor of async_process_executor.

from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner
from torch.distributed.checkpoint.staging import AsyncStager
from torch.distributed.checkpoint.staging import AsyncStager

repeated, remove?

executor: _AsyncCheckpointExecutor = (
def stage_state_dict() -> Future[STATE_DICT_TYPE]:
staging_executor = ThreadPoolExecutor(max_workers=1)
if isinstance(storage_writer, AsyncStager) and not use_default_staging:

Suggested change
    if isinstance(storage_writer, AsyncStager) and not use_default_staging:
with:
    if storage_writer is not None and isinstance(storage_writer, AsyncStager):

Comment on lines +308 to +310
use_default_staging = False
if storage_writer is None:
    use_default_staging = True

remove? see suggestion on L321


if isinstance(storage_writer, AsyncStager) and not use_default_staging:
    staging_future = staging_executor.submit(storage_writer.stage, state_dict)
else:
    # provides bwc for storage_writers not implementing AsyncStager

Do we need to handle this case? Can we ask users to implement AsyncStager if they need zero-copy? I think that is simpler to support and cleaner from an API point of view.

if not block_on_staging:
    global _CACHED_STATE_DICT
    if not _CACHED_STATE_DICT:
        _CACHED_STATE_DICT = _create_cpu_state_dict(
            state_dict, pin_memory=True, share_memory=True
        )

pin_memory and share_memory need to be controllable options so the user has a choice to disable them, as they come with drawbacks that might not work for every model or every system.
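A sketch of what user-controllable staging options could look like (names assumed for illustration; this is not the actual DCP StagingOptions):

```python
from dataclasses import dataclass

@dataclass
class StagingOptions:
    """Illustrative knobs letting users opt out of pinned or shared
    staging memory: pinning physical RAM can fail or hurt on some
    systems, and shared memory is bounded by /dev/shm size limits."""
    use_pinned_memory: bool = True
    use_shared_memory: bool = True

def create_staging_buffer(nbytes: int, opts: StagingOptions) -> dict:
    # A real implementation would dispatch to pinned/shared allocators;
    # here we just record which path was chosen.
    return {
        "nbytes": nbytes,
        "pinned": opts.use_pinned_memory,
        "shared": opts.use_shared_memory,
    }

buf = create_staging_buffer(1024, StagingOptions(use_pinned_memory=False))
```

Defaulting both flags to True keeps the fast path, while still giving users an escape hatch for constrained systems, which is the concern raised here.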


@teja-rao left a comment

Overall, the PR looks good to me after the updates. I will stamp the PR once testing is complete and CI shows green. Thank you for making the changes.


Should this be STATE_DICT_TYPE | Future[STATE_DICT_TYPE]?

nit: I see we are converting the state_dict to a Future to pass in, which is okay, but I think it is more readable to just keep passing in the union of STATE_DICT_TYPE and Future.

@Saiteja64 Jun 20, 2025

That makes sense. I was originally thinking that, in the long term, stage would only return a Future (which would work for sync as well), but I don't see that happening because it would introduce breaking changes and would be hard to cleanly deprecate the old return type.

Comment on lines 299 to 298

nit:

Suggested change
    if async_stager is None:
        if (storage_writer is None or not isinstance(storage_writer, AsyncStager)):
            async_stager = DefaultStager(StagingOptions(not block_on_staging, not block_on_staging, not block_on_staging, not block_on_staging))
        elif isinstance(storage_writer, AsyncStager):
            # bwc with old storage_writers
            async_stager = storage_writer
with:
    if async_stager is None:
        if (storage_writer is not None and isinstance(storage_writer, AsyncStager)):
            # bwc with old storage_writers
            async_stager = storage_writer
        else:
            async_stager = DefaultStager(StagingOptions(not block_on_staging, not block_on_staging, not block_on_staging, not block_on_staging))


Is the save method still used for sync save? Why not change it to support a union?

@teja-rao Jun 26, 2025

nit: What do you think about this? We could eliminate save_wrapper and add the isinstance check in the save method itself.
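A sketch of the suggested shape, with the isinstance check folded into save() itself (toy code illustrating the control flow, not the DCP signature):

```python
from concurrent.futures import Future

def save(state_dict):
    """Illustrative version of the suggestion: instead of routing staged
    saves through a separate save_wrapper, save() unwraps a staging
    Future itself before writing."""
    if isinstance(state_dict, Future):
        state_dict = state_dict.result()  # wait for staging to finish
    return {"written": sorted(state_dict)}

# Works both for a plain state dict and a staged (Future) one.
fut = Future()
fut.set_result({"weight": 1})
direct = save({"bias": 0})
staged = save(fut)
```

Collapsing the wrapper this way keeps one public entry point for both sync and staged inputs, which is what the union-type discussion above is circling around.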


@teja-rao left a comment

Sending back for doc updates and for consideration of the nits.


Is this assert needed? mypy typechecks should catch it if you aren't returning a Future.


Without this we introduce a linter error, because async_save now returns either a tuple of staging_future/upload_future or just an upload future.

Comment on lines 25 to 26
@teja-rao Jun 26, 2025

Clean up/update CheckpointStager? I think these are from the DCP evolution work.

Comment on lines 149 to 154

I think we do not want users to create a stager each time; the stager caches the storages. Maybe this needs an update.


Throw an exception and suggest synchronizing via the future, or call staging_future.result() here?
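A sketch of the fail-fast behavior being suggested (illustrative helper, not DCP code): raise with actionable guidance when the previous staging has not finished, rather than silently blocking or proceeding.

```python
from concurrent.futures import Future

def check_staging_complete(staging_future: Future) -> None:
    """Illustrative guard: fail fast with a hint to synchronize when the
    previous checkpoint is still staging."""
    if not staging_future.done():
        raise RuntimeError(
            "Previous checkpoint is still staging; call "
            "staging_future.result() (or otherwise wait on it) "
            "before issuing the next save."
        )

fut = Future()
try:
    check_staging_complete(fut)
    raised = False
except RuntimeError:
    raised = True

fut.set_result(None)
check_staging_complete(fut)  # no exception once staging is done
```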

@teja-rao left a comment

Approving to unblock; please fix the mypy error before landing.


Summary:
X-link: meta-pytorch/tnt#1010


This diff updates DCP driver code/APIs to support Zero Overhead Checkpointing

Test Plan:
Test with TorchTitan on this PR: pytorch/torchtitan#1287

Add new UT

Reviewed By: diego-urgell

Differential Revision: D72391401

@facebook-github-bot

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)


pytorch-bot bot commented Jun 28, 2025

This PR has pending changes requested. Please address the comments and update the PR before merging.

@Saiteja64

@pytorchbot merge


pytorch-bot bot commented Jun 29, 2025

This PR has pending changes requested. Please address the comments and update the PR before merging.

@Saiteja64

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


Labels: ciflow/trunk, fb-exported, Merged, oncall: distributed, release notes: distributed (checkpoint)
