
[Cherry-pick][DSD] Fix distributed state dict full_state_dict option hang during set_state_dict (#135725) and Fix loading uneven full tensor into sharded state dict (#136365) #136903

Merged
kit1980 merged 2 commits into pytorch:release/2.5 from wz337:release/2.5
Sep 30, 2024

Conversation


@wz337 wz337 commented Sep 27, 2024

Fix distributed state dict full_state_dict option hang during set_state_dict (pytorch#135725)

Fix pytorch#134095
This fixes the distributed state dict `full_state_dict` option hanging during `set_state_dict`. We switch `_distribute_tensors` in `_state_dict_utils.py` to use `DTensor.from_local` instead of `distribute_tensor` to support the FSDP2+TP 2D strided-sharding use case, as `distribute_tensor` cannot handle strided sharding yet. `distribute_tensor` incurs a scatter behind the scenes, while `DTensor.from_local` takes the local slice of the full tensor on each rank to create the DTensor (no collective). This means it is the user's responsibility to ensure that the full tensor from the `full_state_dict` is the same across all ranks.
Pull Request resolved: pytorch#135725
Approved by: https://github.com/fegin

(cherry picked from commit 0cdc6a8)
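The collective-free behavior described above can be sketched in plain Python: because every rank already holds an identical full tensor, each rank can slice out its own shard locally with no scatter. This is an illustrative sketch assuming even 1-D sharding; `local_shard_1d` is a hypothetical helper, not a PyTorch API.

```python
# Hypothetical sketch (plain Python, no torch) of what `DTensor.from_local`
# relies on: each rank slices its own shard out of a full tensor it
# already holds, so no communication happens. This is only correct if
# the full tensor is identical on every rank, as the PR notes.

def local_shard_1d(full_tensor, rank, world_size):
    """Return this rank's shard of `full_tensor` (a list),
    assuming the dimension divides evenly across ranks."""
    shard_size = len(full_tensor) // world_size
    start = rank * shard_size
    return full_tensor[start:start + shard_size]

full = [0, 1, 2, 3, 4, 5, 6, 7]  # assumed identical on every rank
shards = [local_shard_1d(full, r, world_size=4) for r in range(4)]
print(shards)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

By contrast, a `distribute_tensor`-style scatter would require one rank to hold the full tensor and send each shard over the wire, which is where the 2D strided-sharding case previously hung.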
Fix loading uneven full tensor into sharded state dict (pytorch#136365)

Fix pytorch#136228.

This is a follow up on pytorch#135725. We need to pass shape and stride from the original dtensor, since for uneven case, `from_local` would calculate shape and stride assuming the tensor is evenly-sharded based on the local tensor.

Pull Request resolved: pytorch#136365
Approved by: https://github.com/fegin

(cherry picked from commit 637d5c4)
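The uneven-sharding pitfall this follow-up fixes can be illustrated with a small sketch. With 10 elements over 4 ranks, shards are uneven, so inferring the global size from rank 0's local shard under the even-sharding assumption gives the wrong answer. This is plain Python; `uneven_shard_sizes` is a hypothetical helper mimicking chunked sharding, not the DTensor API.

```python
# Illustrative sketch of why `from_local` must be given the original
# shape and stride for unevenly sharded tensors.

def uneven_shard_sizes(dim_size, world_size):
    """Chunk-style shard sizes: ceil-sized shards first, so trailing
    ranks may get smaller (possibly empty) shards."""
    chunk = -(-dim_size // world_size)  # ceiling division
    sizes = []
    remaining = dim_size
    for _ in range(world_size):
        sizes.append(max(min(chunk, remaining), 0))
        remaining -= chunk
    return sizes

print(uneven_shard_sizes(10, 4))  # [3, 3, 3, 1] -- uneven shards

# If rank 0 hands its size-3 shard to `from_local` and the global shape
# is inferred under the even-sharding assumption, the reconstructed
# dimension would be 3 * 4 = 12 instead of the true 10.
inferred = uneven_shard_sizes(10, 4)[0] * 4
print(inferred)  # 12, not 10 -- hence pass the original shape/stride
```

Passing the original DTensor's shape and stride explicitly sidesteps this inference entirely, which is what the fix does.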

pytorch-bot bot commented Sep 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136903

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit 08ae534 with merge base b7eb725:

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added module: distributed_checkpoint oncall: distributed Add this issue/PR to distributed oncall triage queue labels Sep 27, 2024
@wz337 wz337 added this to the 2.5.0 milestone Sep 27, 2024
@wz337 wz337 marked this pull request as ready for review September 27, 2024 22:28
@kit1980 kit1980 merged commit 70298e9 into pytorch:release/2.5 Sep 30, 2024