
[DSD] Fix distributed state dict full_state_dict option hang during set_state_dict #135725

Closed
wz337 wants to merge 6 commits into gh/wz337/28/base from gh/wz337/28/head

Conversation

wz337 (Contributor) commented Sep 11, 2024

Stack from ghstack (oldest at bottom):

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wconstab @d4l3k @c-p-i-o @LucasLLC @MeetVadakkanchery @mhorowitz @pradeepfn

Fixes #134095.
This fixes a hang in the distributed state dict `full_state_dict` option during `set_state_dict`. We switch `_distribute_tensors` in `_state_dict_utils.py` to use `DTensor.from_local` instead of `distribute_tensor` to support the FSDP2+TP 2D strided sharding use case, as `distribute_tensor` cannot handle strided sharding yet. `distribute_tensor` incurs a scatter behind the scenes, while `DTensor.from_local` takes the local slice from the full tensor on each rank to create the DTensor (no collective). This means it is the user's responsibility to ensure the `full_tensor` from the `full_state_dict` is the same across all ranks.
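The difference between the two code paths can be sketched with a minimal, stdlib-only example (not the actual DTensor API; `local_shard` is a hypothetical helper). Each rank slices its own chunk out of the replicated full tensor, which is why no collective is needed, and also why every rank must hold the identical full tensor:

```python
# Stdlib sketch of why DTensor.from_local avoids the collective that
# distribute_tensor needs. Assumes every rank already holds the identical
# full tensor, which is the invariant this fix places on the user.

def local_shard(full_tensor, rank, world_size):
    """Each rank independently slices out its own chunk: no communication.

    Mirrors the spirit of DTensor.from_local on a 1-D list; the real API
    operates on torch.Tensor with a DeviceMesh and placements.
    """
    n = len(full_tensor)
    # Even split, with the remainder going to the leading ranks
    # (the usual chunking convention).
    base, rem = divmod(n, world_size)
    start = rank * base + min(rank, rem)
    end = start + base + (1 if rank < rem else 0)
    return full_tensor[start:end]

full = list(range(8))   # same full tensor replicated on every rank
world_size = 4
shards = [local_shard(full, r, world_size) for r in range(world_size)]
# Each rank produced its shard with zero collectives:
assert shards == [[0, 1], [2, 3], [4, 5], [6, 7]]
```

A `distribute_tensor`-style path would instead have one rank scatter these chunks to the others, which is where the hang arose for strided sharding.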

pytorch-bot bot commented Sep 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135725

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 3 Unrelated Failures

As of commit ee2f244 with merge base 011cae9:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Sep 11, 2024
@wz337 wz337 changed the title fix [DSD] Fix distributed state dict full_state_dict option hang during set_state_dict Sep 11, 2024
@wz337 wz337 requested a review from fegin September 11, 2024 19:32
@wz337 wz337 added module: distributed_checkpoint release notes: distributed (checkpoint) ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR labels Sep 11, 2024
fegin (Contributor) left a comment:
Nice fix, thanks!

wz337 (Author) commented Sep 12, 2024

@pytorchmergebot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 12, 2024
pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.


wz337 (Author) commented Sep 12, 2024

@pytorchmergebot merge -i

wz337 (Author) commented Sep 13, 2024

@pytorchmergebot merge -i


pytorchmergebot pushed a commit that referenced this pull request Sep 13, 2024
… state dict into a 2D model (#135763)

Fix #134095

This is a workaround for loading a full state dict into an FSDP1+TP 2D model.
Since `named_parameters()` in FSDP1 does not return DTensor, we don't have the information needed to shard the `full_state_dict` and load it directly into the 2D model. To load a full state dict into an FSDP1+TP 2D model, we need to:
- load the full state dict into a 1D FSDP model
- dcp.save the full/shard state dict into storage
- initialize a 2D FSDP1+TP model
- get the default sharded state dict for the 2D model (full_state_dict=False)
- dcp.load the state dict from storage
- load the state dict into the 2D model
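The steps above can be sketched with plain dicts standing in for models and for DCP storage. Every helper name here is hypothetical; the real flow uses `dcp.save`/`dcp.load` and the torch state-dict APIs:

```python
# Stdlib sketch of the FSDP1+TP workaround above. Plain dicts stand in for
# the model state dicts and for DCP storage; all helper names are
# hypothetical, not the DCP API.

def shard_value(value, rank, world_size):
    # Slice this rank's portion of a flat list (stands in for one shard
    # of a parameter in the 2D model's default sharded state dict).
    chunk = len(value) // world_size
    return value[rank * chunk:(rank + 1) * chunk]

def load_full_into_2d(full_state_dict, rank, world_size):
    # Steps 1-2: "save" the full state dict from the 1D model into storage.
    storage = dict(full_state_dict)
    # Steps 3-4: the 2D model asks for its default *sharded* state dict
    # (full_state_dict=False), i.e. only this rank's slice per parameter.
    # Steps 5-6: "load" from storage, resharding each entry for this rank.
    return {k: shard_value(v, rank, world_size) for k, v in storage.items()}

full_sd = {"w": list(range(8))}
assert load_full_into_2d(full_sd, rank=1, world_size=4) == {"w": [2, 3]}
```

The point of the round trip through storage is that DCP, not FSDP1, performs the resharding, side-stepping the missing DTensor information in `named_parameters()`.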
Pull Request resolved: #135763
Approved by: https://github.com/fegin
ghstack dependencies: #135725
pytorchmergebot pushed a commit that referenced this pull request Sep 17, 2024
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
pytorchmergebot pushed a commit that referenced this pull request Sep 23, 2024
Fixes #136228.

This is a follow-up to #135725. We need to pass the shape and stride from the original DTensor, since in the uneven case `from_local` would calculate the shape and stride from the local tensor, assuming the tensor is evenly sharded.
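A small arithmetic sketch (stdlib only, not the DTensor API) of why the inferred global shape goes wrong for uneven shards: given only its local tensor, a `from_local`-style inference would multiply the local size by the world size, over-counting whenever ranks hold different amounts:

```python
# Why shape/stride must come from the original DTensor for uneven shards:
# inferring global_size = local_size * world_size from the local tensor
# alone is only correct when every rank holds an equal-sized shard.

world_size = 4
global_len = 10                      # uneven: 10 is not divisible by 4
base, rem = divmod(global_len, world_size)
local_lens = [base + (1 if r < rem else 0) for r in range(world_size)]
assert local_lens == [3, 3, 2, 2]    # ranks hold different amounts

# Rank 0's local length is 3; naive inference gives the wrong global size:
inferred = local_lens[0] * world_size
assert inferred == 12 and inferred != global_len
```

Passing the shape and stride recorded on the original DTensor avoids this inference entirely.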

Pull Request resolved: #136365
Approved by: https://github.com/fegin
BoyuanFeng pushed a commit to BoyuanFeng/pytorch that referenced this pull request Sep 25, 2024
wz337 added a commit to wz337/pytorch that referenced this pull request Sep 27, 2024
(cherry picked from commit 0cdc6a8)
wz337 added a commit to wz337/pytorch that referenced this pull request Sep 27, 2024
(cherry picked from commit 637d5c4)
kit1980 pushed a commit that referenced this pull request Sep 30, 2024
…hang during set_state_dict (#135725) and Fix loading uneven full tensor into sharded state dict (#136365) (#136903)

* [DSD] Fix distributed state dict full_state_dict option hang during set_state_dict (#135725)

Fix #134095
This fix distributed state dict full_state_dict option hang during set_state_dict. We switch `_distribute_tensors` in _state_dict_utils.py to use `DTensor.from_local` instead of `distribute_tensor` to support FSDP2+TP 2D strided sharding use case, as `distribute_tensor` cannot handle strided sharding yet. `distribute_tensor` incurs a scatter behind the scenes, while `DTensor.from_local` takes the local slice from the full tensor on each rank to create the DTensor (no collective).  This means it's the user's responsibility to make sure the full_tensor from the full_state_dict is the same across all ranks.
Pull Request resolved: #135725
Approved by: https://github.com/fegin

(cherry picked from commit 0cdc6a8)

* [DSD] Fix loading uneven full tensor into sharded state dict (#136365)

Fix #136228.

This is a follow up on #135725. We need to pass shape and stride from the original dtensor, since for uneven case, `from_local` would calculate shape and stride assuming the tensor is evenly-sharded based on the local tensor.

Pull Request resolved: #136365
Approved by: https://github.com/fegin

(cherry picked from commit 637d5c4)
@github-actions github-actions bot deleted the gh/wz337/28/head branch October 14, 2024 06:23
KnAwnime pushed a commit to KnAwnime/Biblioteka that referenced this pull request Oct 16, 2024
ghstack-source-id: 2f8959d
Pull Request resolved: pytorch/pytorch#135725

Labels

ciflow/periodic · ciflow/trunk · Merged · oncall: distributed · release notes: distributed (checkpoint) · Reverted
