Skip to content

[DCP][DSD] Add a test case to demonstrate the workaround to load full state dict into a 2D model#135763

Closed
wz337 wants to merge 4 commits intogh/wz337/29/basefrom
gh/wz337/29/head
Closed

[DCP][DSD] Add a test case to demonstrate the workaround to load full state dict into a 2D model#135763
wz337 wants to merge 4 commits intogh/wz337/29/basefrom
gh/wz337/29/head

Conversation

@wz337
Copy link
Contributor

@wz337 wz337 commented Sep 11, 2024

Stack from ghstack (oldest at bottom):

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wconstab @d4l3k @c-p-i-o

Fix #134095

This is a workaround for loading full state dict into a FSDP1+TP 2D model.
Since named_parameters() in FSDP1 does not return DTensor, we don't have the information to shard the full_state_dict and load it directly into the 2d model. In order to load a full state dict in FSDP1+TP 2D model, we need to do:

  • load the full state dict into a 1D FSDP model
  • dcp.save the full/shard state dict into storage
  • initialize a 2D FSDP1+TP model
  • get the default sharded state dict for the 2D model (full_state_dict=False)
  • dcp.load the state dict from storage
  • load the state dict into the 2D model

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Sep 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135763

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 4 Unrelated Failures

As of commit bd462ec with merge base 011cae9 (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue topic: not user facing topic category labels Sep 11, 2024
@wz337 wz337 changed the title demonstrate how to load 2d [DCP][DSD] Add a test case to demonstrate the workaround to load full state dict into a 2D model Sep 11, 2024
@wz337 wz337 requested a review from fegin September 11, 2024 22:05
[ghstack-poisoned]
[ghstack-poisoned]
wz337 added a commit that referenced this pull request Sep 12, 2024
ghstack-source-id: bd1d283
Pull Request resolved: #135763
@wz337 wz337 added the ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR label Sep 12, 2024
Copy link
Contributor

@fegin fegin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the workaround test!

@wz337
Copy link
Contributor Author

wz337 commented Sep 12, 2024

@pytorchmergebot merge -i

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 12, 2024
@pytorchmergebot
Copy link
Collaborator

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / build

Details for Dev Infra team Raised by workflow job

[ghstack-poisoned]
wz337 added a commit that referenced this pull request Sep 12, 2024
ghstack-source-id: fa14330
Pull Request resolved: #135763
@wz337
Copy link
Contributor Author

wz337 commented Sep 13, 2024

@pytorchmergebot merge -i

pytorchmergebot pushed a commit that referenced this pull request Sep 17, 2024
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
… state dict into a 2D model (pytorch#135763)

Fix pytorch#134095

This is a workaround for loading full state dict into a FSDP1+TP 2D model.
Since named_parameters() in FSDP1 does not return DTensor, we don't have the information to shard the full_state_dict and load it directly into the 2d model. In order to load a full state dict in FSDP1+TP 2D model, we need to do:
- load the full state dict into a 1D FSDP model
- dcp.save the full/shard state dict into storage
- initialize a 2D FSDP1+TP model
- get the default sharded state dict for the 2D model (full_state_dict=False)
- dcp.load the state dict from storage
- load the state dict into the 2D model
Pull Request resolved: pytorch#135763
Approved by: https://github.com/fegin
ghstack dependencies: pytorch#135725
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
@github-actions github-actions bot deleted the gh/wz337/29/head branch October 14, 2024 06:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants