[FSDP] Move the sharded_state_dict logic to the post hook to avoid OOM#82613
[FSDP] Move the sharded_state_dict logic to the post hook to avoid OOM#82613fegin wants to merge 2 commits intogh/fegin/20/basefrom
Conversation
The original implementation put the call of `_summon_full_params()` in `state_dict()`. However, because `state_dict()` is recursive, `_summon_full_params()` will also behave like the recursive version even if recursive is set to False. This PR put the logic in the post hook to solve the OOM issue. Differential Revision: [D38329396](https://our.internmc.facebook.com/intern/diff/D38329396/) [ghstack-poisoned]
🔗 Helpful links
❌ 1 New Failures, 5 PendingAs of commit d0f7f00 (more details on the Dr. CI page): Expand to see more
🕵️ 1 new failure recognized by patternsThe following CI failures do not appear to be due to upstream breakages
|
The original implementation put the call of `_summon_full_params()` in `state_dict()`. However, because `state_dict()` is recursive, `_summon_full_params()` will also behave like the recursive version even if recursive is set to False. This PR put the logic in the post hook to solve the OOM issue. Differential Revision: [D38329396](https://our.internmc.facebook.com/intern/diff/D38329396/) ghstack-source-id: 163196066 Pull Request resolved: #82613
rohan-varma
left a comment
There was a problem hiding this comment.
Thanks for the fix! It would be great to test it on a use case where full state dict fails but this one succeeds.
| state_dict[fqn] = init_from_local_shards( | ||
| local_shards, param.size(), process_group=self.process_group | ||
| ) # type: ignore[assignment] | ||
| state_dict.pop(f"{prefix}{FLAT_PARAM}") |
There was a problem hiding this comment.
I'm assuming that this is removing the key checkpointed by the super().state_dict() call?
| elif self._state_dict_type == StateDictType.LOCAL_STATE_DICT: | ||
| elif ( | ||
| self._state_dict_type == StateDictType.LOCAL_STATE_DICT or | ||
| self._state_dict_type == StateDictType.SHARDED_STATE_DICT |
There was a problem hiding this comment.
so it seems that sharded state dict calling state_dict is not meant to do much, we remove the checkpointed FLAT_PARAM and add new data to the state_dict for the sharded original parameters.
If this is the case, could we just remove the super().state_dict calls, recurse ourselves and call the post hook?
There was a problem hiding this comment.
We need the recursive calls for 1.) constructing the correct prefix, 2.) calling post_hook in a reversed order. We can do this ourselves but it will be better to reuse state_dict logic. The only thing that sharded_state_dict does not need is the detach one but that should not causing too many overheads.
| "not be SUMMON_FULL_PARAMS." | ||
| ) | ||
| with self._summon_full_params(recurse=False, writeback=False): | ||
| for fqn, _, _ in self._param_fqns: |
There was a problem hiding this comment.
wouldn't named_parameters() also just give the parameter names we need?
There was a problem hiding this comment.
named_parameters() will give us more parameters than what we need. We need to use recursive named_parameters() and this will give us more parameters: 1.) parameters in children FSDP modules 2.) parameters that are ignored by FSDP.
…to avoid OOM" The original implementation put the call of `_summon_full_params()` in `state_dict()`. However, because `state_dict()` is recursive, `_summon_full_params()` will also behave like the recursive version even if recursive is set to False. This PR put the logic in the post hook to solve the OOM issue. Differential Revision: [D38329396](https://our.internmc.facebook.com/intern/diff/D38329396/) [ghstack-poisoned]
Pull Request resolved: #82613 The original implementation put the call of `_summon_full_params()` in `state_dict()`. However, because `state_dict()` is recursive, `_summon_full_params()` will also behave like the recursive version even if recursive is set to False. This PR put the logic in the post hook to solve the OOM issue. ghstack-source-id: 163330033 Differential Revision: [D38329396](https://our.internmc.facebook.com/intern/diff/D38329396/)
|
@pytorchbot merge |
|
@pytorchbot successfully started a merge job. Check the current status here |
|
Hey @fegin. |
#82613) (#82613) Summary: The original implementation put the call of `_summon_full_params()` in `state_dict()`. However, because `state_dict()` is recursive, `_summon_full_params()` will also behave like the recursive version even if recursive is set to False. This PR put the logic in the post hook to solve the OOM issue. Pull Request resolved: #82613 Approved by: https://github.com/rohan-varma Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/b750c10fbe288a201e623e89473bd7ea0f485d56 Original Phabricator Test Plan: CI Reviewed By: rohan-varma Differential Revision: D38329396 Pulled By: fegin fbshipit-source-id: 2f560f9b7ba73ad515987a65a684076f605a7635
Stack from ghstack (oldest at bottom):
The original implementation put the call of
_summon_full_params()instate_dict(). However, becausestate_dict()is recursive,_summon_full_params()will also behave like the recursive version even if recursive is set to False. This PR put the logic in the post hook to solve the OOM issue.Differential Revision: D38329396