[DCP][OSS] Rank local checkpointing in DCP without collectives#147758

Closed
saumishr wants to merge 1 commit into pytorch:main from saumishr:export-D70112642

Conversation

@saumishr
Contributor

@saumishr saumishr commented Feb 24, 2025

Summary:
DCP metadata collectives become prohibitively expensive as job scale grows. This PR introduces rank-local checkpointing, which saves and loads the checkpoint without any collectives. The trade-off, for now, is the loss of dedupe and re-sharding; support for these will be added in a follow-up.
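To illustrate the idea (this is a simplified sketch, not the actual DCP implementation; the function names and the `.metadata.{rank}` file layout are hypothetical): each rank writes and later reads its own metadata file, so no cross-rank collective is needed at save or load time.

```python
# Hypothetical sketch of rank-local checkpointing: every rank persists its
# own metadata file and reads only that file back, so save/load requires no
# cross-rank coordination. Names and file layout are illustrative only.
import json
import os
import tempfile


def save_rank_local(checkpoint_dir: str, rank: int, state: dict) -> None:
    """Each rank writes its own metadata; no collective is involved."""
    with open(os.path.join(checkpoint_dir, f".metadata.{rank}"), "w") as f:
        json.dump({"rank": rank, "keys": sorted(state)}, f)


def load_rank_local(checkpoint_dir: str, rank: int) -> dict:
    """Each rank reads back only its own metadata file."""
    with open(os.path.join(checkpoint_dir, f".metadata.{rank}")) as f:
        return json.load(f)


# Simulate 4 ranks in a single process: each touches only its own file.
with tempfile.TemporaryDirectory() as ckpt_dir:
    for r in range(4):
        save_rank_local(ckpt_dir, r, {"layer.weight": r})
    meta = load_rank_local(ckpt_dir, 2)

print(meta)
```

The cost of dropping the collective is that no rank ever sees the global view of the checkpoint, which is exactly why dedupe and re-sharding are deferred.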

Differential Revision: D70112642

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @LucasLLC @pradeepfn @kwen2501 @c-p-i-o @MeetVadakkanchery @mhorowitz @ekr0

@pytorch-bot

pytorch-bot bot commented Feb 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147758

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 147fb0e with merge base 8eee08d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: distributed_checkpoint and oncall: distributed labels Feb 24, 2025
@meetv18 meetv18 added the oncall: distributed checkpointing label Feb 25, 2025
gkroiz added a commit to gkroiz/pytorch that referenced this pull request Mar 9, 2025
@saumishr saumishr force-pushed the export-D70112642 branch 2 times, most recently from 36ee9c7 to eb8d1a7 on April 2, 2025
@saumishr saumishr force-pushed the export-D70112642 branch 2 times, most recently from 78dcad0 to d3c31ff on April 6, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D70112642

@pytorch pytorch deleted a comment from facebook-github-bot Apr 6, 2025
saumishr added a commit to saumishr/tnt that referenced this pull request Apr 18, 2025
Summary:

X-link: pytorch/pytorch#147758



Differential Revision: D70112642
saumishr added a commit to saumishr/tnt that referenced this pull request Apr 20, 2025
Summary:
Pull Request resolved: meta-pytorch#991

X-link: pytorch/pytorch#147758


Differential Revision: D70112642
pytorch-bot bot pushed a commit that referenced this pull request Apr 24, 2025
Summary:
X-link: meta-pytorch/tnt#991




Test Plan: E2E UTs

Differential Revision: D70112642
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D70112642

11 similar comments
Contributor

@meetv18 meetv18 left a comment

Overall the save path LGTM, but I have some concerns about the load path. Happy to approve once they're resolved. Thanks!

Comment on lines +228 to +246
nonlocal use_collectives
nonlocal metadata

if "kwargs" in inspect.signature(storage_reader.read_metadata).parameters:
    try:
        metadata = storage_reader.read_metadata(rank=distW.rank)  # noqa: F841
        if metadata:
            use_collectives = False
            logger.info(
                "Rank local metadata is found. Using no rank coordination for checkpoint loading."
            )
    except Exception:
        logger.info(
            "Rank local metadata is not found. Falling back to global metadata."
        )

if use_collectives:
    metadata = storage_reader.read_metadata()
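A minimal, self-contained sketch of the backward-compatibility pattern in the snippet above (the reader classes here are hypothetical stand-ins, not DCP's real `StorageReader` implementations): inspect the reader's `read_metadata` signature and only pass `rank=` when the method can accept it, falling back to the global call otherwise.

```python
# Sketch of the signature-inspection fallback used in the PR snippet.
# LegacyReader/RankAwareReader are made-up examples, not DCP classes.
import inspect


class LegacyReader:
    def read_metadata(self):  # old signature: no per-rank support
        return "global-metadata"


class RankAwareReader:
    def read_metadata(self, **kwargs):  # accepts rank=... via **kwargs
        rank = kwargs.get("rank")
        return f"rank-{rank}-metadata" if rank is not None else "global-metadata"


def read_for_rank(reader, rank):
    """Return (metadata, use_collectives) using the PR's fallback logic."""
    params = inspect.signature(reader.read_metadata).parameters
    if "kwargs" in params:  # same check as in the snippet above
        try:
            meta = reader.read_metadata(rank=rank)
            if meta:
                return meta, False  # rank-local metadata found: no collectives
        except Exception:
            pass  # rank-local metadata missing: fall back to global below
    return reader.read_metadata(), True  # use_collectives stays True


print(read_for_rank(LegacyReader(), 3))
print(read_for_rank(RankAwareReader(), 3))
```

Note the check keys on a literal `**kwargs` parameter name; a reader that declared `rank` as an explicit keyword argument would need a slightly different test.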
Contributor

I am a bit unsure about this. Consider a user's implementation where every rank reads the global metadata from storage but still has some global planning to do, e.g. to assign read items to only one rank and then have it broadcast to the others. Wouldn't this fundamentally break that logic, making this change non-backward-compatible?

Contributor Author

I believe backward compatibility within the DCP API covers its own behavior. Currently it checks for both: if rank-local metadata is present, it assumes the no-collective mode is on; if not, it falls back to the global metadata. If no metadata is found at all, it raises the same exception it does today. A user with a more complex interaction model will need to handle backward compatibility in their own storage components as well. The read_metadata API lets callers specify a rank to read either rank-local or global metadata, so users can customize the behavior through it.

Contributor

Oh yes that makes sense to me. L245 is doing exactly that, thank you for the clarification!

 # Check whether combined chunks cover the whole tensor
 tensor_volume = reduce(operator.mul, value.size, 1)
-if chunks_volume != tensor_volume:
+if len(global_plan) > 1 and chunks_volume != tensor_volume:
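To illustrate why the check is gated on `len(global_plan) > 1` (a standalone sketch with made-up shapes, not DCP's actual planner code): the summed chunk volume only matches the full tensor volume when chunks from all ranks are present; under rank-local planning each rank sees just its own chunk, so the comparison would spuriously fail.

```python
# Illustrative volume check: chunk volumes sum to the tensor volume only in
# the global-planning case, which is why the check is skipped when a rank
# plans alone. All shapes and plan names here are made up for the example.
import operator
from functools import reduce

tensor_size = (4, 6)                                   # full (global) tensor shape
tensor_volume = reduce(operator.mul, tensor_size, 1)   # 4 * 6 = 24

# Global planning: two ranks each cover one half of the tensor.
rank_chunks = [(2, 6), (2, 6)]
chunks_volume = sum(reduce(operator.mul, c, 1) for c in rank_chunks)

global_plan = ["plan_rank0", "plan_rank1"]             # plans from 2 ranks
if len(global_plan) > 1 and chunks_volume != tensor_volume:
    raise ValueError("chunks do not cover the tensor")

# Rank-local planning: a rank sees only its own plan (len == 1), so its
# chunk volume (12) legitimately differs from the global volume (24) and
# the coverage check must be skipped.
local_plan = ["plan_rank0"]
local_volume = reduce(operator.mul, rank_chunks[0], 1)
assert not (len(local_plan) > 1 and local_volume != tensor_volume)

print(tensor_volume, chunks_volume, local_volume)
```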
Contributor

n00b q: why do we need this?

Contributor Author

The tensor volume check makes sense only in a global context. When every rank does its own planning, the check doesn't add much value. I plan to refactor plan validation into local and global validation, which will make this cleaner.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D70112642

4 similar comments
[DCP][OSS] Rank local checkpointing in DCP without collectives (pytorch#147758)

Summary:
X-link: meta-pytorch/tnt#991




Test Plan:
E2E UTs

Save and load test with internal DCP components: https://www.internalfb.com/mlhub/pipelines/runs/mast/torchx-textray-pretrain_mlm-lv5d7qcfmnqzkd

Save and load test with OSS DCP components: https://www.internalfb.com/mlhub/pipelines/runs/mast/torchx-textray-pretrain_mlm-z1vz46vkkgtcld
https://www.internalfb.com/mlhub/pipelines/runs/mast/torchx-textray-pretrain_mlm-njvvbn07rv5ckd

Reviewed By: meetv18

Differential Revision: D70112642
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D70112642

@saumishr
Contributor Author

@pytorchmergebot merge

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/rexporting the PR!


@facebook-github-bot
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here


Labels

ciflow/h100-distributed, ciflow/trunk, fb-exported, Merged, oncall: distributed checkpointing, oncall: distributed, release notes: distributed (checkpoint)
