[DCP][DSD] Add a test case to demonstrate the workaround to load full state dict into a 2D model#135763
[DCP][DSD] Add a test case to demonstrate the workaround to load full state dict into a 2D model#135763wz337 wants to merge 4 commits intogh/wz337/29/basefrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135763
Note: Links to docs will display an error until the docs builds have been completed. ❌ 4 New Failures, 4 Unrelated FailuresAs of commit bd462ec with merge base 011cae9 ( NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but was present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
fegin
left a comment
There was a problem hiding this comment.
Thanks for adding the workaround test!
|
@pytorchmergebot merge -i |
Merge startedYour change will be merged while ignoring the following 5 checks: Lint / lintrunner-noclang / linux-job, pull / linux-docs / build-docs-python-false, periodic / ios-build-test / build (default, 1, 1, macos-14-xlarge, SIMULATOR, arm64, 1, 0, 1), periodic / linux-focal-rocm6.1-py3.8 / test (distributed, 1, 3, linux.rocm.gpu, unstable), periodic / linux-focal-rocm6.1-py3.8 / test (distributed, 2, 3, linux.rocm.gpu, unstable) Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / build Details for Dev Infra teamRaised by workflow job |
|
@pytorchmergebot merge -i |
Merge startedYour change will be merged while ignoring the following 8 checks: pull / linux-docs / build-docs-python-false, pull / linux-focal-cuda12.1-py3.10-gcc9-sm86 / test (default, 2, 5, linux.g5.4xlarge.nvidia.gpu), Lint / lintrunner-noclang / linux-job, periodic / ios-build-test / build (default, 1, 1, macos-14-xlarge, SIMULATOR, arm64, 1, 0, 1), periodic / linux-focal-cuda12.1-py3.10-gcc9 / test (nogpu_NO_AVX2, 1, 1, linux.2xlarge), periodic / linux-focal-rocm6.1-py3.8 / test (distributed, 1, 3, linux.rocm.gpu, unstable), periodic / linux-focal-cuda12.1-py3.10-gcc9-experimental-split-build / test (nogpu_NO_AVX2, 1, 1, linux.2xlarge), trunk / win-vs2019-cpu-py3 / test (default, 1, 3, windows.4xlarge.nonephemeral) Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Pull Request resolved: #136165 Approved by: https://github.com/kwen2501 ghstack dependencies: #135725, #135763
… state dict into a 2D model (pytorch#135763) Fix pytorch#134095 This is a workaround for loading full state dict into a FSDP1+TP 2D model. Since named_parameters() in FSDP1 does not return DTensor, we don't have the information to shard the full_state_dict and load it directly into the 2d model. In order to load a full state dict in FSDP1+TP 2D model, we need to do: - load the full state dict into a 1D FSDP model - dcp.save the full/shard state dict into storage - initialize a 2D FSDP1+TP model - get the default sharded state dict for the 2D model (full_state_dict=False) - dcp.load the state dict from storage - load the state dict into the 2D model Pull Request resolved: pytorch#135763 Approved by: https://github.com/fegin ghstack dependencies: pytorch#135725
Pull Request resolved: pytorch#136165 Approved by: https://github.com/kwen2501 ghstack dependencies: pytorch#135725, pytorch#135763
Stack from ghstack (oldest at bottom):
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wconstab @d4l3k @c-p-i-o
Fix #134095
This is a workaround for loading full state dict into a FSDP1+TP 2D model.
Since named_parameters() in FSDP1 does not return DTensor, we don't have the information to shard the full_state_dict and load it directly into the 2d model. In order to load a full state dict in FSDP1+TP 2D model, we need to do: