[DSD] Implement broadcast_from_rank0 option for optim state_dict #125339
fegin wants to merge 6 commits into gh/fegin/236/base from
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/125339
❌ 1 New Failure as of commit f3bb51d with merge base 196a0b1. The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
)
if equal:
    self.assertEqual(states, fsdp_states)
def check(equal):
Do we need to check get_optimizer_state_dict as well? In torchtune, we call the model and optimizer state_dicts separately.
Oh yes, somehow that was removed during rebasing. Sorry for the confusion; I will add it back.
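As discussed above, the test helper should compare both the model and the optimizer state_dicts. A minimal sketch of such a `check` helper, with `get_model_state_dict` / `get_optimizer_state_dict` stubbed out for illustration (the real APIs live in torch.distributed.checkpoint.state_dict; the signatures and dict shapes here are assumptions):

```python
# Hypothetical stubs standing in for the DSD APIs; the real functions
# return the model's state_dict and the FQN-keyed optimizer state.
def get_model_state_dict(model):
    return model["weights"]  # stub: real API returns model.state_dict()

def get_optimizer_state_dict(model, optim):
    return optim["state"]    # stub: real API returns FQN-keyed optim state

def check(equal, model, optim, fsdp_model, fsdp_optim):
    # Compare the model state_dicts AND the optimizer state_dicts,
    # per the review comment above.
    states = get_model_state_dict(model)
    fsdp_states = get_model_state_dict(fsdp_model)
    optim_states = get_optimizer_state_dict(model, optim)
    fsdp_optim_states = get_optimizer_state_dict(fsdp_model, fsdp_optim)
    if equal:
        assert states == fsdp_states
        assert optim_states == fsdp_optim_states
    else:
        assert states != fsdp_states
```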
)
elif info.broadcast_from_rank0:
    info.full_state_dict = False
    local_state_dict = _get_optim_state_dict(model, (optim,), info)
Are we using _get_optim_state_dict instead of torch.optim.Optimizer.state_dict because we want to align FQNs between local_state_dict and optim_state_dict? torch.optim.Optimizer.state_dict only gives us ID keys.
Yes, it is easier to proceed with FQN keys.
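To illustrate the distinction from this thread: torch.optim.Optimizer.state_dict() keys per-parameter state by integer IDs (positions within param_groups), while the DSD helpers key it by fully qualified names (FQNs) such as "layer.weight". A torch-free sketch of remapping ID keys to FQN keys, assuming the ID-to-FQN mapping is known (the dict shapes and helper name are illustrative, not the PR's actual code):

```python
def remap_ids_to_fqns(optim_state, id_to_fqn):
    """Rewrite an ID-keyed optimizer state dict to use FQN keys."""
    return {id_to_fqn[pid]: s for pid, s in optim_state["state"].items()}

# Shape mimicking torch.optim.Optimizer.state_dict(): integer param IDs.
id_keyed = {"state": {0: {"momentum": 1.0}, 1: {"momentum": 2.0}}}
id_to_fqn = {0: "layer.weight", 1: "layer.bias"}

fqn_keyed = remap_ids_to_fqns(id_keyed, id_to_fqn)
# fqn_keyed == {"layer.weight": {"momentum": 1.0},
#               "layer.bias": {"momentum": 2.0}}
```

FQN keys are stable across ranks and across wrapping (FSDP vs. plain modules), which is why aligning on them is easier than reconciling integer IDs.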
for optim in optimizers:
    optim_state_dict = _split_optim_state_dict(model, optim, state_dict, info)
    _init_optim_state(optim)
_init_optim_state seems to update model.parameters() for Adam even though we set grad=0?
Repro: pytest test_distributed.py P1233005758
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed; the first few of them are: periodic / win-vs2019-cuda11.8-py3 / test (default, 4, 4, windows.g5.4xlarge.nvidia.gpu). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge -f "The failing tests are not related."
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Summary:
This is useful when users would like to avoid CPU memory OOM while loading from a full state_dict: with broadcast_from_rank0, only rank0 needs to hold the full state_dict in CPU memory, and the values are broadcast to the other ranks.
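The flow the summary describes can be sketched as follows. This is a torch-free simulation, not the PR's implementation: the `broadcast` stub stands in for a torch.distributed collective, and all function names here are illustrative assumptions:

```python
def load_on_rank0_only(rank, load_full_state_dict):
    # Only rank 0 pays the CPU memory cost of loading the full state_dict;
    # every other rank starts with an empty dict.
    return load_full_state_dict() if rank == 0 else {}

def broadcast(obj_per_rank, src=0):
    # Stub for a distributed broadcast: copy rank `src`'s object to all ranks.
    src_obj = obj_per_rank[src]
    return [dict(src_obj) for _ in obj_per_rank]

world_size = 4
per_rank = [
    load_on_rank0_only(r, lambda: {"layer.weight": [1.0, 2.0]})
    for r in range(world_size)
]
# Before broadcast: only rank 0 holds the state; ranks 1..3 hold {}.
per_rank = broadcast(per_rank, src=0)
# After broadcast: every rank holds an identical copy of the state.
```

Compared with every rank loading the full state_dict from disk, this caps peak CPU memory at one full copy per node rather than one per rank.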
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @LucasLLC