[DCP][OSS] Rank local checkpointing in DCP without collectives #147758
saumishr wants to merge 1 commit into pytorch:main
Conversation
✅ No failures as of commit 147fb0e with merge base 8eee08d. Artifacts and rendered test results: hud.pytorch.org/pr/147758
This pull request was exported from Phabricator. Differential Revision: D70112642
Summary: X-link: pytorch/pytorch#147758, meta-pytorch/tnt#991

Context: DCP metadata collectives become prohibitively expensive as job scale grows. This PR introduces rank-local checkpointing, which saves and loads the checkpoint without any collectives. The trade-off for now is the loss of dedupe and re-sharding; support for these will be added soon.

Test Plan: E2E UTs

Differential Revision: D70112642
meetv18 left a comment:

Overall the save path LGTM, but I have some concerns about load. Happy to approve once they're resolved. Thanks!
```python
nonlocal use_collectives
nonlocal metadata

if "kwargs" in inspect.signature(storage_reader.read_metadata).parameters:
    try:
        metadata = storage_reader.read_metadata(rank=distW.rank)  # noqa: F841
        if metadata:
            use_collectives = False
            logger.info(
                "Rank local metadata is found. Using no rank coordination for checkpoint loading."
            )
    except Exception:
        logger.info(
            "Rank local metadata is not found. Falling back to global metadata."
        )

if use_collectives:
    metadata = storage_reader.read_metadata()
```
I am a bit unsure about this. Consider a user's implementation where every rank reads the global metadata from storage but still has some global planning to do, e.g. assigning read items to only one rank and having it broadcast to the others. Wouldn't this fundamentally break that logic, making the change non-BC?
I believe backward compatibility within the DCP API covers its own behavior. Currently it checks for both: if rank-local metadata is present, it assumes no-collective mode is on; if not, it falls back to global metadata. If no metadata is found at all, it raises the same exception as today. A user with a more complex interaction model will need to handle backward compatibility in their own storage components as well. The read_metadata API lets callers specify a rank to read either rank-local or global metadata, so users can customize the behavior through it.
Oh yes, that makes sense to me. L245 is doing exactly that, thank you for the clarification!
```diff
 # Check whether the combined chunks cover the whole tensor
 tensor_volume = reduce(operator.mul, value.size, 1)
-if chunks_volume != tensor_volume:
+if len(global_plan) > 1 and chunks_volume != tensor_volume:
```
n00b q: why do we need this?
The tensor volume check only makes sense in a global context. When every rank does its own planning, the check doesn't add much value. I plan to refactor plan validation into local and global validation, which will make this cleaner.
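The coverage check being discussed can be sketched in isolation. This is a hypothetical helper, assuming `value.size` and the chunk sizes are plain shape sequences; it mirrors the `reduce(operator.mul, ...)` volume computation from the hunk above:

```python
import operator
from functools import reduce


def covers_tensor(tensor_size, chunk_sizes):
    """Return True if the chunk volumes sum to the full tensor volume.

    This is the global-coverage invariant: it only holds when the plans
    from all ranks are combined, which is why a single-rank (local) plan
    skips the check.
    """
    tensor_volume = reduce(operator.mul, tensor_size, 1)
    chunks_volume = sum(reduce(operator.mul, s, 1) for s in chunk_sizes)
    return chunks_volume == tensor_volume
```

With rank-local planning, each rank only sees its own chunks, so `chunks_volume` is a partial sum and a mismatch is expected rather than an error.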
…ch#147758) Summary: X-link: meta-pytorch/tnt#991 Context: DCP metadata collectives become prohibitively expensive as job scale grows. This PR introduces rank-local checkpointing, which saves and loads the checkpoint without any collectives. The trade-off for now is the loss of dedupe and re-sharding; support for these will be added soon. Test Plan: E2E UTs. Save and load test with internal DCP components: https://www.internalfb.com/mlhub/pipelines/runs/mast/torchx-textray-pretrain_mlm-lv5d7qcfmnqzkd Save and load test with OSS DCP components: https://www.internalfb.com/mlhub/pipelines/runs/mast/torchx-textray-pretrain_mlm-z1vz46vkkgtcld https://www.internalfb.com/mlhub/pipelines/runs/mast/torchx-textray-pretrain_mlm-njvvbn07rv5ckd Reviewed By: meetv18 Differential Revision: D70112642
@pytorchmergebot merge

Merge failed. Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/re-exporting the PR.

@pytorchbot merge (initiating merge automatically since the Phabricator diff has merged)

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Summary:
DCP metadata collectives become prohibitively expensive as job scale grows. This PR introduces rank-local checkpointing, which saves and loads the checkpoint without any collectives. The trade-off for now is the loss of dedupe and re-sharding; support for these will be added soon.
Differential Revision: D70112642
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @LucasLLC @pradeepfn @kwen2501 @c-p-i-o @MeetVadakkanchery @mhorowitz @ekr0