[DSD] Fix distributed state dict full_state_dict option hang during set_state_dict #135725
wz337 wants to merge 6 commits into gh/wz337/28/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135725
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures, 3 Unrelated Failures. As of commit ee2f244 with merge base 011cae9:
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base: 👉 Rebase onto the `viable/strict` branch to avoid these failures
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchmergebot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 2 jobs have failed; the first few are: periodic / ios-build-test / build (default, 1, 1, macos-14-xlarge, SIMULATOR, arm64, 1, 0, 1), periodic / linux-focal-cuda12.1-py3.10-gcc9-experimental-split-build / test (nogpu_NO_AVX2, 1, 1, linux.2xlarge). Details for Dev Infra team: raised by workflow job.

@pytorchmergebot merge -i

Merge started. Your change will be merged while ignoring the following 7 checks: Lint / lintrunner-noclang / linux-job, pull / linux-docs / build-docs-python-false, pull / linux-focal-cuda12.1-py3.10-gcc9-sm86 / test (default, 2, 5, linux.g5.4xlarge.nvidia.gpu), periodic / ios-build-test / build (default, 1, 1, macos-14-xlarge, SIMULATOR, arm64, 1, 0, 1), periodic / linux-focal-rocm6.1-py3.8 / test (distributed, 1, 3, linux.rocm.gpu, unstable), periodic / linux-focal-rocm6.1-py3.8 / test (distributed, 2, 3, linux.rocm.gpu, unstable), periodic / linux-focal-cuda12.1-py3.10-gcc9-experimental-split-build / test (nogpu_NO_AVX2, 1, 1, linux.2xlarge). Learn more about merging in the wiki.

@pytorchmergebot merge -i

Merge started. Your change will be merged while ignoring the following 5 checks: Lint / lintrunner-noclang / linux-job, pull / linux-docs / build-docs-python-false, periodic / ios-build-test / build (default, 1, 1, macos-14-xlarge, SIMULATOR, arm64, 1, 0, 1), periodic / linux-focal-rocm6.1-py3.8 / test (distributed, 1, 3, linux.rocm.gpu, unstable), periodic / linux-focal-cuda12.1-py3.10-gcc9-experimental-split-build / test (nogpu_AVX512, 1, 1, linux.2xlarge). Learn more about merging in the wiki.
… state dict into a 2D model (#135763) Fix #134095. This is a workaround for loading a full state dict into an FSDP1+TP 2D model. Since `named_parameters()` in FSDP1 does not return DTensor, we don't have the information to shard the full_state_dict and load it directly into the 2D model. To load a full state dict into an FSDP1+TP 2D model, we need to:
- load the full state dict into a 1D FSDP model
- dcp.save the full/shard state dict into storage
- initialize a 2D FSDP1+TP model
- get the default sharded state dict for the 2D model (full_state_dict=False)
- dcp.load the state dict from storage
- load the state dict into the 2D model

Pull Request resolved: #135763 Approved by: https://github.com/fegin ghstack dependencies: #135725
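The steps above can be sketched as a single-process simulation. All names here (`dcp_save`, `dcp_load_shard`, `storage`) are illustrative stand-ins, not the real API; the actual flow uses `torch.distributed.checkpoint` save/load with state dicts from `get_model_state_dict`/`set_model_state_dict`:

```python
# Minimal sketch of the workaround: persist the full state dict to
# "storage" from the 1D model, then have each (dp_rank, tp_rank) of the
# 2D mesh read back only its own shard, instead of sharding the full
# dict in memory (which FSDP1 lacks the DTensor info to do).

storage = {}

def dcp_save(state_dict):
    # Stand-in for dcp.save: persist the full state dict.
    storage.update({k: list(v) for k, v in state_dict.items()})

def dcp_load_shard(key, dp_rank, dp, tp_rank, tp):
    # Stand-in for dcp.load into a sharded (full_state_dict=False)
    # state dict: shard rows across dp, then columns/rows across tp.
    full = storage[key]
    dp_part = len(full) // dp
    rows = full[dp_rank * dp_part : (dp_rank + 1) * dp_part]
    tp_part = len(rows) // tp
    return rows[tp_rank * tp_part : (tp_rank + 1) * tp_part]

full_sd = {"w": list(range(16))}        # full state dict on the 1D model
dcp_save(full_sd)                       # step 2: save to storage
shard = dcp_load_shard("w", 0, 2, 1, 2) # steps 4-6: 2D model reads its shard
print(shard)                            # [4, 5, 6, 7]
```

The point of the round trip through storage is that DCP's resharding-on-load does the work that FSDP1's `named_parameters()` cannot support in memory.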
Pull Request resolved: #136165 Approved by: https://github.com/kwen2501 ghstack dependencies: #135725, #135763
…et_state_dict (pytorch#135725) Fix pytorch#134095. This fixes the distributed state dict full_state_dict option hang during set_state_dict. We switch `_distribute_tensors` in `_state_dict_utils.py` to use `DTensor.from_local` instead of `distribute_tensor` to support the FSDP2+TP 2D strided-sharding use case, as `distribute_tensor` cannot handle strided sharding yet. `distribute_tensor` incurs a scatter behind the scenes, while `DTensor.from_local` takes the local slice from the full tensor on each rank to create the DTensor (no collective). This means it is the user's responsibility to make sure the full_tensor from the full_state_dict is the same across all ranks. Pull Request resolved: pytorch#135725 Approved by: https://github.com/fegin
…during set_state_dict (pytorch#135725)" This reverts commit 83c594e. Reverted pytorch#135725 on behalf of https://github.com/ZainRizvi due to This is breaking lint. See [GH job link](https://github.com/pytorch/pytorch/actions/runs/10835983999/job/30068709508) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/83c594ebd6dfa517fdd67ae23929cc60d5fa325d) ([comment](pytorch#135725 (comment)))
…et_state_dict (pytorch#135725) Pull Request resolved: pytorch#135725 Approved by: https://github.com/fegin
… state dict into a 2D model (pytorch#135763) Pull Request resolved: pytorch#135763 Approved by: https://github.com/fegin ghstack dependencies: pytorch#135725
Fix #136228. This is a follow-up on #135725. We need to pass the shape and stride from the original DTensor, since in the uneven case `from_local` would calculate shape and stride assuming the tensor is evenly sharded, based on the local tensor. Pull Request resolved: #136365 Approved by: https://github.com/fegin
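The shape-inference problem this fixes can be shown without any distributed setup. A tiny sketch (`infer_global_len` is a hypothetical stand-in for the default inference `from_local` performs, not a real API):

```python
# With uneven sharding, inferring the global size from the local shard
# as local_len * world_size (the even-sharding assumption) overcounts
# on every rank except the short one -- hence the fix passes the
# original DTensor's shape and stride explicitly.

def infer_global_len(local_len, world_size):
    # Even-sharding assumption: every rank holds the same-sized shard.
    return local_len * world_size

world_size = 4
full_len = 10                 # not divisible by world_size
shard_lens = [3, 3, 3, 1]     # last rank gets the remainder

assert sum(shard_lens) == full_len
# Rank 0 infers a global length of 12, but the true length is 10:
print(infer_global_len(shard_lens[0], world_size))  # 12
```

Passing the known global shape/stride sidesteps this inference entirely, which is exactly what the follow-up does.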
…et_state_dict (pytorch#135725) Pull Request resolved: pytorch#135725 Approved by: https://github.com/fegin (cherry picked from commit 0cdc6a8)
…#136365) Pull Request resolved: pytorch#136365 Approved by: https://github.com/fegin (cherry picked from commit 637d5c4)
…hang during set_state_dict (#135725) and Fix loading uneven full tensor into sharded state dict (#136365) (#136903)
* [DSD] Fix distributed state dict full_state_dict option hang during set_state_dict (#135725) (cherry picked from commit 0cdc6a8)
* [DSD] Fix loading uneven full tensor into sharded state dict (#136365) (cherry picked from commit 637d5c4)
ghstack-source-id: 2f8959d Pull Request resolved: pytorch/pytorch#135725
Stack from ghstack (oldest at bottom):
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wconstab @d4l3k @c-p-i-o @LucasLLC @MeetVadakkanchery @mhorowitz @pradeepfn
Fix #134095
This fixes the distributed state dict full_state_dict option hang during set_state_dict. We switch `_distribute_tensors` in `_state_dict_utils.py` to use `DTensor.from_local` instead of `distribute_tensor` to support the FSDP2+TP 2D strided-sharding use case, as `distribute_tensor` cannot handle strided sharding yet.

`distribute_tensor` incurs a scatter behind the scenes, while `DTensor.from_local` takes the local slice from the full tensor on each rank to create the DTensor (no collective). This means it is the user's responsibility to make sure the full_tensor from the full_state_dict is the same across all ranks.
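The semantic difference can be sketched in a single process, with no collectives. `local_slice` below is a hypothetical helper, not PyTorch API; it mimics the slice each rank would take when constructing a DTensor via `from_local`, versus `distribute_tensor` where a source rank scatters shards:

```python
# distribute_tensor: a source rank scatters shards to the others (a
# collective, so all ranks must participate -- the source of the hang
# when ranks disagree). DTensor.from_local: each rank independently
# takes its own slice of a replicated full tensor, no communication,
# but the full tensor must be identical on every rank.

def local_slice(full, rank, world_size):
    # Return this rank's contiguous shard of an evenly divisible list.
    shard = len(full) // world_size
    return full[rank * shard : (rank + 1) * shard]

full_tensor = list(range(8))   # assumed identical on every rank
world_size = 4
shards = [local_slice(full_tensor, r, world_size) for r in range(world_size)]
print(shards)                  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Because no rank ever sends data to another, the no-collective path cannot hang, but it also cannot detect divergent inputs: if one rank's `full_tensor` differs, it silently produces a wrong shard, which is why the description puts that invariant on the user.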