
[FSDP] Move the sharded_state_dict logic to the post hook to avoid OOM#82613

Closed
fegin wants to merge 2 commits into gh/fegin/20/base from gh/fegin/20/head

Conversation

@fegin
Contributor

@fegin fegin commented Aug 1, 2022

Stack from ghstack (oldest at bottom):

The original implementation put the call to `_summon_full_params()` inside `state_dict()`. However, because `state_dict()` is recursive, `_summon_full_params()` also behaves recursively even when `recurse` is set to `False`. This PR moves the logic into the post hook to solve the OOM issue.

Differential Revision: [D38329396](https://our.internmc.facebook.com/intern/diff/D38329396/)

[ghstack-poisoned]
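To make the memory argument in the description concrete, here is a minimal pure-Python sketch (not the real FSDP internals; `Module`, `tree_size`, and the peak counter are illustrative stand-ins) of why summoning full params inside a recursive `state_dict()` materializes the whole tree at once, while doing it in a per-module post hook with `recurse=False` only materializes one module at a time:

```python
# Toy model of peak memory under the two placements of the gather logic.
class Module:
    def __init__(self, name, size, children=()):
        self.name, self.size = name, size
        self.children = list(children)

    def tree_size(self):
        # Total parameter size of this module and all descendants.
        return self.size + sum(c.tree_size() for c in self.children)

    # Variant A: summon inside state_dict(). Because state_dict() recurses,
    # the summon at the root effectively covers the whole subtree, so peak
    # memory is the full tree size.
    def state_dict_inline(self, peak=None):
        if peak is None:
            peak = [0]
        peak[0] = max(peak[0], self.tree_size())  # whole subtree resident
        for c in self.children:
            c.state_dict_inline(peak)
        return peak[0]

    # Variant B: summon in a post hook with recurse=False. Each hook only
    # materializes its own module's parameters, so peak memory is the
    # largest single module, not the whole tree.
    def state_dict_hooked(self, peak=None):
        if peak is None:
            peak = [0]
        for c in self.children:
            c.state_dict_hooked(peak)
        peak[0] = max(peak[0], self.size)  # only this module resident
        return peak[0]

root = Module("root", 4, [Module("a", 8), Module("b", 2, [Module("c", 6)])])
inline_peak = root.state_dict_inline()  # full tree: 4 + 8 + 2 + 6
hooked_peak = root.state_dict_hooked()  # largest single module
```

Under this toy model, the inline variant peaks at the whole tree size while the hooked variant peaks at the largest single module, which is the OOM difference the PR targets.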
@facebook-github-bot
Contributor

facebook-github-bot commented Aug 1, 2022


❌ 1 New Failures, 5 Pending

As of commit d0f7f00 (more details on the Dr. CI page):

  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / linux-bionic-cuda11.6-py3.10-gcc7 / test (default, 1, 4, linux.4xlarge.nvidia.gpu) (1/1)

Step: "Test" (full log | diagnosis details)

2022-08-02T23:24:14.1286781Z   File "/opt/conda/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 1266, in set_rng_seed
2022-08-02T23:24:14.1287168Z     torch.manual_seed(seed)
2022-08-02T23:24:14.1287626Z   File "/opt/conda/lib/python3.10/site-packages/torch/random.py", line 40, in manual_seed
2022-08-02T23:24:14.1288001Z     torch.cuda.manual_seed_all(seed)
2022-08-02T23:24:14.1288479Z   File "/opt/conda/lib/python3.10/site-packages/torch/cuda/random.py", line 113, in manual_seed_all
2022-08-02T23:24:14.1288848Z     _lazy_call(cb, seed_all=True)
2022-08-02T23:24:14.1289321Z   File "/opt/conda/lib/python3.10/site-packages/torch/cuda/__init__.py", line 156, in _lazy_call
2022-08-02T23:24:14.1289650Z     callable()
2022-08-02T23:24:14.1290078Z   File "/opt/conda/lib/python3.10/site-packages/torch/cuda/random.py", line 111, in cb
2022-08-02T23:24:14.1290455Z     default_generator.manual_seed(seed)
2022-08-02T23:24:14.1290800Z RuntimeError: CUDA error: an illegal memory access was encountered
2022-08-02T23:24:14.1291273Z CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
2022-08-02T23:24:14.1291726Z For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
2022-08-02T23:24:14.1291942Z 
2022-08-02T23:24:14.1292773Z ----------------------------------------------------------------------
2022-08-02T23:24:14.1293098Z Ran 20054 tests in 4855.161s
2022-08-02T23:24:14.1293266Z 
2022-08-02T23:24:14.1293433Z FAILED (errors=1, skipped=3552, expected failures=246)
2022-08-02T23:24:14.1293640Z 
2022-08-02T23:24:14.1293763Z Generating XML reports...
2022-08-02T23:24:16.2159459Z Generated XML report: test-reports/python-unittest/test_ops/TEST-TestCommonCUDA-20220802220318.xml

This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@facebook-github-bot facebook-github-bot added the cla signed and oncall: distributed (Add this issue/PR to distributed oncall triage queue) labels Aug 1, 2022
fegin added a commit that referenced this pull request Aug 1, 2022
The original implementation put the call to `_summon_full_params()` inside `state_dict()`. However, because `state_dict()` is recursive, `_summon_full_params()` also behaves recursively even when `recurse` is set to `False`. This PR moves the logic into the post hook to solve the OOM issue.

Differential Revision: [D38329396](https://our.internmc.facebook.com/intern/diff/D38329396/)

ghstack-source-id: 163196066
Pull Request resolved: #82613
Contributor

@rohan-varma rohan-varma left a comment


Thanks for the fix! It would be great to test it on a use case where full state dict fails but this one succeeds.

state_dict[fqn] = init_from_local_shards(
    local_shards, param.size(), process_group=self.process_group
)  # type: ignore[assignment]
state_dict.pop(f"{prefix}{FLAT_PARAM}")
Contributor

I'm assuming that this is removing the key checkpointed by the super().state_dict() call?

Contributor Author


Yes
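The snippet and exchange above can be sketched in plain Python (hypothetical names: `FLAT_PARAM`, `make_sharded_entry`, and `sharded_post_hook` are illustrative stand-ins, not FSDP's actual API; `make_sharded_entry` replaces the real `init_from_local_shards()`): the post hook drops the flat-parameter key that `super().state_dict()` checkpointed and adds one sharded entry per original parameter.

```python
FLAT_PARAM = "_flat_param"  # stand-in for FSDP's flat-parameter key

def make_sharded_entry(local_shard, full_size):
    # Stand-in for init_from_local_shards(): records the local shard plus
    # the global size instead of building a real ShardedTensor.
    return {"local_shard": local_shard, "full_size": full_size}

def sharded_post_hook(state_dict, prefix, param_fqns, local_shards, full_sizes):
    # Add one entry per original parameter, keyed by its fully qualified name.
    for fqn in param_fqns:
        state_dict[prefix + fqn] = make_sharded_entry(
            local_shards[fqn], full_sizes[fqn]
        )
    # Remove the key that super().state_dict() checkpointed.
    state_dict.pop(f"{prefix}{FLAT_PARAM}")
    return state_dict

sd = {"layer1._flat_param": [1, 2, 3, 4]}
out = sharded_post_hook(
    sd, "layer1.", ["weight"], {"weight": [1, 2]}, {"weight": 4}
)
```

After the hook runs, `"layer1._flat_param"` is gone and `"layer1.weight"` holds the local shard plus the global size, mirroring the pop-then-insert pattern in the diff.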

elif self._state_dict_type == StateDictType.LOCAL_STATE_DICT:
elif (
    self._state_dict_type == StateDictType.LOCAL_STATE_DICT or
    self._state_dict_type == StateDictType.SHARDED_STATE_DICT
):
Contributor

So it seems that, for the sharded state dict, calling state_dict is not meant to do much: we remove the checkpointed FLAT_PARAM and add new entries to the state_dict for the sharded original parameters.

If this is the case, could we just remove the super().state_dict() calls, recurse ourselves, and call the post hook?

Contributor Author

We need the recursive calls for 1) constructing the correct prefix and 2) calling the post hooks in reverse order. We could do this ourselves, but it is better to reuse the state_dict logic. The only thing sharded_state_dict does not need is the detach, but that should not cause much overhead.
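The two benefits of reusing the recursion can be shown with a simplified model of `nn.Module.state_dict()` (this is a pure-Python sketch, not the real implementation; the tuple-based module tree is an illustrative stand-in): the recursion both builds the dotted key prefix and, because each module's post hook runs after its children's, fires the hooks bottom-up.

```python
def state_dict(module_tree, prefix="", destination=None, hook_order=None):
    # module_tree is (name, [param names], [child trees]).
    if destination is None:
        destination, hook_order = {}, []
    name, params, children = module_tree
    # The recursion threads the dotted prefix down to every key.
    for pname in params:
        destination[prefix + pname] = f"value-of-{prefix}{pname}"
    for child in children:
        state_dict(child, prefix + child[0] + ".", destination, hook_order)
    # Post hook: by this point every child's hook has already fired,
    # so hooks run in bottom-up (reverse) order.
    hook_order.append(prefix or name)
    return destination, hook_order

tree = ("root", ["w"], [("a", ["w"], []), ("b", ["w"], [])])
sd, order = state_dict(tree)
```

Here `sd` gets the correctly prefixed keys `"w"`, `"a.w"`, `"b.w"`, and `order` shows the hooks firing for the children before the root, which is exactly what the reply above relies on.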

    "not be SUMMON_FULL_PARAMS."
)
with self._summon_full_params(recurse=False, writeback=False):
    for fqn, _, _ in self._param_fqns:
Contributor

Wouldn't `named_parameters()` also give us just the parameter names we need?

Contributor Author

named_parameters() will give us more parameters than we need. We would have to use the recursive named_parameters(), and that yields extra parameters: 1) parameters in children FSDP modules and 2) parameters that are ignored by FSDP.
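A toy illustration of that reply (all structures here are hypothetical stand-ins for FSDP's bookkeeping; the tuple-based module tree and `managed_fqns` set are not the real API): a recursive `named_parameters()` on a wrapper also yields its children's and ignored parameters, so the hook has to work from a pre-filtered FQN list like `_param_fqns` instead.

```python
def named_parameters(module, prefix=""):
    # module is (name, [own param names], [child trees]); recursive, like
    # nn.Module.named_parameters() with recurse=True.
    name, own_params, children = module
    for p in own_params:
        yield prefix + p
    for child in children:
        yield from named_parameters(child, prefix + child[0] + ".")

# An outer wrapper with one managed param, one FSDP-ignored param, and a
# nested FSDP-wrapped child that manages its own param.
outer = ("outer", ["weight", "ignored_weight"],
         [("inner_fsdp", ["weight"], [])])

all_names = list(named_parameters(outer))

# The outer wrapper's hook should only touch the parameters it manages,
# e.g. what FSDP tracks in _param_fqns for this wrapper:
managed_fqns = {"weight"}
own_only = [n for n in all_names if n in managed_fqns]
```

The recursive walk returns three names, but only one belongs to the outer wrapper itself, which is why filtering by the wrapper's own FQN list is needed.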

…to avoid OOM"

The original implementation put the call to `_summon_full_params()` inside `state_dict()`. However, because `state_dict()` is recursive, `_summon_full_params()` also behaves recursively even when `recurse` is set to `False`. This PR moves the logic into the post hook to solve the OOM issue.

Differential Revision: [D38329396](https://our.internmc.facebook.com/intern/diff/D38329396/)

[ghstack-poisoned]
fegin added a commit that referenced this pull request Aug 2, 2022
Pull Request resolved: #82613

The original implementation put the call to `_summon_full_params()` inside `state_dict()`. However, because `state_dict()` is recursive, `_summon_full_params()` also behaves recursively even when `recurse` is set to `False`. This PR moves the logic into the post hook to solve the OOM issue.
ghstack-source-id: 163330033

Differential Revision: [D38329396](https://our.internmc.facebook.com/intern/diff/D38329396/)
@fegin
Contributor Author

fegin commented Aug 3, 2022

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here

@github-actions
Contributor

github-actions bot commented Aug 3, 2022

Hey @fegin.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Aug 4, 2022
#82613) (#82613)

Summary:
The original implementation put the call to `_summon_full_params()` inside `state_dict()`. However, because `state_dict()` is recursive, `_summon_full_params()` also behaves recursively even when `recurse` is set to `False`. This PR moves the logic into the post hook to solve the OOM issue.

Pull Request resolved: #82613
Approved by: https://github.com/rohan-varma

Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/b750c10fbe288a201e623e89473bd7ea0f485d56

Original Phabricator Test Plan:
CI

Reviewed By: rohan-varma

Differential Revision: D38329396

Pulled By: fegin

fbshipit-source-id: 2f560f9b7ba73ad515987a65a684076f605a7635
@facebook-github-bot facebook-github-bot deleted the gh/fegin/20/head branch August 7, 2022 14:18

Labels

cla signed · Merged · oncall: distributed (Add this issue/PR to distributed oncall triage queue)
