[FSDP][5/N] Add manual "wrapping" support for fully_shard #90874
awgu wants to merge 13 commits into gh/awgu/279/base

Conversation
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90874
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 1924517.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This is not ready for review. [ghstack-poisoned]
```diff
             flat_param = handle.flat_param
-            if flat_param not in self.flat_param_to_prefixed_param_names:
+            if flat_param not in self.param_to_fqn:
                 continue
```
when is a handle invalid?
I traced through the history of my PRs, and it looks like this check was added arbitrarily. However, reasoning about it now, I think the check is important so that execution does not crash for `use_orig_params=True`: `self.param_to_fqn` is constructed via `_get_param_to_fqns(root_module)`, and for `use_orig_params=True`, `_get_param_to_fqns` does not include any `FlatParameter`s since they are not registered.
I think `_exec_order_utils.py` needs to be revisited for `use_orig_params=True`.
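To make the guard's purpose concrete, here is a minimal standalone sketch (all names are hypothetical, not the FSDP code): a lookup table built only from registered parameters will not contain an unregistered flat parameter, so without the `continue` guard the lookup would raise a `KeyError`.

```python
# Hypothetical illustration: `registered` plays the role of param_to_fqn,
# which is built only from *registered* parameters. An unregistered
# "flat_param" entry is absent, so membership must be checked first.
registered = {"weight": "layer.weight"}   # param -> fully qualified name
handles = ["weight", "flat_param"]        # "flat_param" was never registered

fqns = []
for h in handles:
    if h not in registered:
        continue  # skip unregistered params instead of crashing
    fqns.append(registered[h])
```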
```diff
-        states = [state] if _is_composable(state) else _get_fsdp_states(state)
-        for state in states:
+        for state in _get_fsdp_states(module):
```
hmm, is this intentional to abuse the var name state?
and on line 887, is it a bug? I assume your intention is to set it on the original state var instead of the one used in the loop?
Oh man. Thanks for the great catch. Let me fix this.
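The bug caught above is the classic loop-variable shadowing pattern. A minimal standalone illustration (toy names, not the FSDP code): once the loop reuses the name `state`, a later assignment mutates the last loop element instead of the original variable.

```python
class State:
    """Toy stand-in for an FSDP state object."""
    def __init__(self, name):
        self.name = name
        self.flag = False

root = State("root")
children = [State("a"), State("b")]

def buggy(state):
    # The loop rebinds the name `state`, shadowing the argument.
    for state in children:
        pass
    # This now mutates the *last child*, not the original argument.
    state.flag = True

buggy(root)
```
After the call, `root.flag` is still `False` while `children[-1].flag` is `True`, which is exactly the kind of silent misassignment the reviewer flagged.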
This PR adds manual "wrapping" support for `fully_shard`. For example, for

```
fully_shard(mod.sub)
fully_shard(mod)
```

`mod.sub` and `mod` will share the same FSDP data structures. To have parity with wrapper FSDP, this PR only checks support for when each manual application of `fully_shard` passes `policy=None`. Hybrid auto / manual wrapping is not in scope for this PR since it is not supported for wrapper FSDP either. I can follow up to either add support properly or raise an error early.

[ghstack-poisoned]
ghstack-source-id: 4e1e0b4 Pull Request resolved: pytorch#90874
ghstack-source-id: 03633b9 Pull Request resolved: pytorch#90874
This pull request has been merged in 32fde53.
ghstack-source-id: 97a14bc Pull Request resolved: pytorch#90874
Stack from ghstack:

- #91044 [FSDP][7/N] Support `replicate` in `fully_shard`
- #90959 [FSDP][6/N] Add note explaining idioms for `_FSDPState` traversal
- #90874 [FSDP][5/N] Add manual "wrapping" support for `fully_shard`

This PR adds manual "wrapping" support for `fully_shard`. For example, for

```
fully_shard(mod.sub)
fully_shard(mod)
```

`mod.sub` and `mod` will share the same FSDP data structures. To have parity with wrapper FSDP, this PR only checks support for when each manual application of `fully_shard` passes `policy=None`. Hybrid auto / manual wrapping is not in scope for this PR since it is not supported for wrapper FSDP either. I can follow up to either add support properly or raise an error early.