Skip to content

[DSD] Fix loading uneven full tensor into sharded state dict#136365

Closed
wz337 wants to merge 3 commits intogh/wz337/32/basefrom
gh/wz337/32/head
Closed

[DSD] Fix loading uneven full tensor into sharded state dict#136365
wz337 wants to merge 3 commits intogh/wz337/32/basefrom
gh/wz337/32/head

Conversation

@wz337
Copy link
Contributor

@wz337 wz337 commented Sep 20, 2024

Stack from ghstack (oldest at bottom):

Fix #136228.

This is a follow up on #135725. We need to pass shape and stride from the original dtensor, since for uneven case, from_local would calculate shape and stride assuming the tensor is evenly-sharded based on the local tensor.

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wconstab @d4l3k @c-p-i-o @LucasLLC @MeetVadakkanchery @mhorowitz @pradeepfn

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Sep 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136365

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (8 Unrelated Failures)

As of commit e33a4b3 with merge base d3647d1 (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Sep 20, 2024
@wz337 wz337 mentioned this pull request Sep 20, 2024
[ghstack-poisoned]
[ghstack-poisoned]
wz337 added a commit that referenced this pull request Sep 20, 2024
ghstack-source-id: e6b781e
Pull Request resolved: #136365

fix

ghstack-source-id: e6b781e
Pull Request resolved: #136366
@wz337 wz337 marked this pull request as draft September 20, 2024 18:27
@wz337 wz337 added the topic: not user facing topic category label Sep 20, 2024
@wz337 wz337 changed the title add shape and stride from local state [DSD] add shape and stride from local state Sep 20, 2024
@wz337 wz337 changed the title [DSD] add shape and stride from local state [DSD] Fix loading uneven full tensor into sharded state dict Sep 20, 2024
@wz337 wz337 marked this pull request as ready for review September 20, 2024 19:58
@wz337 wz337 requested a review from fegin September 20, 2024 19:59
@wz337 wz337 added the topic: bug fixes topic category label Sep 20, 2024

@with_comms
@skip_if_lt_x_gpu(2)
def test_state_dict_util_distribute_tensors(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: comment on purpose of test, expected results, etc

@wz337
Copy link
Contributor Author

wz337 commented Sep 23, 2024

@pytorchmergebot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 23, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-focal-cuda12.4-py3.10-gcc9-experimental-split-build-test / test (nogpu_NO_AVX2, 1, 2, lf.linux.2xlarge)

Details for Dev Infra team Raised by workflow job

@wz337
Copy link
Contributor Author

wz337 commented Sep 23, 2024

@pytorchmergebot merge -i

BoyuanFeng pushed a commit to BoyuanFeng/pytorch that referenced this pull request Sep 25, 2024
…#136365)

Fix pytorch#136228.

This is a follow up on pytorch#135725. We need to pass shape and stride from the original dtensor, since for uneven case, `from_local` would calculate shape and stride assuming the tensor is evenly-sharded based on the local tensor.

Pull Request resolved: pytorch#136365
Approved by: https://github.com/fegin
wz337 added a commit to wz337/pytorch that referenced this pull request Sep 27, 2024
…#136365)

Fix pytorch#136228.

This is a follow up on pytorch#135725. We need to pass shape and stride from the original dtensor, since for uneven case, `from_local` would calculate shape and stride assuming the tensor is evenly-sharded based on the local tensor.

Pull Request resolved: pytorch#136365
Approved by: https://github.com/fegin

(cherry picked from commit 637d5c4)
kit1980 pushed a commit that referenced this pull request Sep 30, 2024
…hang during set_state_dict (#135725) and Fix loading uneven full tensor into sharded state dict (#136365) (#136903)

* [DSD] Fix distributed state dict full_state_dict option hang during set_state_dict (#135725)

Fix #134095
This fix distributed state dict full_state_dict option hang during set_state_dict. We switch `_distribute_tensors` in _state_dict_utils.py to use `DTensor.from_local` instead of `distribute_tensor` to support FSDP2+TP 2D strided sharding use case, as `distribute_tensor` cannot handle strided sharding yet. `distribute_tensor` incurs a scatter behind the scenes, while `DTensor.from_local` takes the local slice from the full tensor on each rank to create the DTensor (no collective).  This means it's the user's responsibility to make sure the full_tensor from the full_state_dict is the same across all ranks.
Pull Request resolved: #135725
Approved by: https://github.com/fegin

(cherry picked from commit 0cdc6a8)

* [DSD] Fix loading uneven full tensor into sharded state dict (#136365)

Fix #136228.

This is a follow up on #135725. We need to pass shape and stride from the original dtensor, since for uneven case, `from_local` would calculate shape and stride assuming the tensor is evenly-sharded based on the local tensor.

Pull Request resolved: #136365
Approved by: https://github.com/fegin

(cherry picked from commit 637d5c4)
@github-actions github-actions bot deleted the gh/wz337/32/head branch October 25, 2024 02:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue topic: bug fixes topic category topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants