[FSDP][5/N] Add manual "wrapping" support for fully_shard #90874

Closed
awgu wants to merge 13 commits into gh/awgu/279/base from gh/awgu/279/head
@awgu awgu commented Dec 14, 2022

Stack from ghstack:

This PR adds manual "wrapping" support for fully_shard. For example, for

fully_shard(mod.sub)
fully_shard(mod)

mod.sub and mod will share the same FSDP data structures.

To have parity with wrapper FSDP, this PR only checks support for the case where each manual application of fully_shard passes policy=None. Hybrid auto / manual wrapping is not in scope for this PR since it is not supported for wrapper FSDP either. I can follow up to either add support properly or raise an error early.

This is not ready for review.

[ghstack-poisoned]
awgu pushed a commit that referenced this pull request Dec 15, 2022
  flat_param = handle.flat_param
- if flat_param not in self.flat_param_to_prefixed_param_names:
+ if flat_param not in self.param_to_fqn:
      continue
Contributor

when is a handle invalid?

Collaborator Author

I traced through the history of my PRs, and it looks like this check was added arbitrarily. However, reasoning about it now, I think this check is needed to keep execution from crashing when use_orig_params=True: self.param_to_fqn is constructed via _get_param_to_fqns(root_module), and for use_orig_params=True, _get_param_to_fqns does not include any FlatParameters since they are not registered.

I think _exec_order_utils.py needs to be revisited for use_orig_params=True.
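
The situation described in this reply can be sketched with plain-Python stand-ins. Every name below (`ToyParam`, `ToyHandle`, `build_param_to_fqn`, `record_order`) is hypothetical and only mimics the shape of the snippet under review: a mapping built from registered parameters has no entry for an unregistered flat parameter, so the membership check is what keeps the later lookup from raising `KeyError`.

```python
class ToyParam:
    """Hypothetical stand-in for a parameter object (hashed by identity)."""

    def __init__(self, name):
        self.name = name


class ToyHandle:
    """Hypothetical stand-in for a handle that owns one flat parameter."""

    def __init__(self, flat_param):
        self.flat_param = flat_param


def build_param_to_fqn(registered):
    # Only registered parameters end up as keys of the mapping.
    return {param: fqn for fqn, param in registered.items()}


def record_order(handles, param_to_fqn):
    order = []
    for handle in handles:
        flat_param = handle.flat_param
        if flat_param not in param_to_fqn:
            # Unregistered flat parameter: skip it instead of indexing the dict.
            continue
        order.append(param_to_fqn[flat_param])
    return order


weight = ToyParam("weight")
flat = ToyParam("flat")  # assumption: never registered, so absent from the map
param_to_fqn = build_param_to_fqn({"sub.weight": weight})
handles = [ToyHandle(flat), ToyHandle(weight)]
result = record_order(handles, param_to_fqn)
```

Without the `continue` branch, `param_to_fqn[flat_param]` would fail on the first handle in this toy setup.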


- states = [state] if _is_composable(state) else _get_fsdp_states(state)
- for state in states:
+ for state in _get_fsdp_states(module):
Contributor

hmm, is this intentional to abuse the var name state?

Contributor

and on line 887, is it a bug? I assume your intention is to set it on the original state var instead of the one used in the loop?

Collaborator Author

Oh man. Thanks for the great catch. Let me fix this.
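
The kind of bug raised in this thread can be reproduced in a few lines of plain Python. This is a toy sketch, not the code from the PR: `ToyState`, `all_states`, and the `is_root` attribute are hypothetical, and it assumes the statement after the loop is meant to act on the function's argument. A `for` loop rebinds its target name, so reusing the argument name as the loop variable leaves that name pointing at the last item once the loop finishes.

```python
class ToyState:
    """Hypothetical stand-in for a per-module state object."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)
        self.is_root = False


def all_states(state):
    # Hypothetical helper: the state itself followed by its children.
    return [state] + state.children


def mark_root_buggy(state):
    for state in all_states(state):  # rebinds the argument name `state`
        pass  # per-state setup would go here
    state.is_root = True  # lands on the last loop item, not the argument


def mark_root_fixed(root_state):
    for state in all_states(root_state):  # distinct loop variable
        pass  # per-state setup would go here
    root_state.is_root = True  # lands on the argument as intended


buggy_root = ToyState("root", [ToyState("child")])
mark_root_buggy(buggy_root)

fixed_root = ToyState("root", [ToyState("child")])
mark_root_fixed(fixed_root)
```

Giving the argument and the loop variable different names, as in `mark_root_fixed`, removes the ambiguity.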

Contributor

@mrshenli mrshenli left a comment

LGTM!

awgu pushed a commit to awgu/pytorch that referenced this pull request Dec 17, 2022
awgu pushed a commit to awgu/pytorch that referenced this pull request Dec 19, 2022
@awgu awgu added the ciflow/trunk label Dec 20, 2022
@facebook-github-bot

This pull request has been merged in 32fde53.

Labels: ciflow/trunk, Merged, release notes: distributed (fsdp), topic: improvements