[FSDP][2/N] `_summon_full_params` -> `_unshard_params` #92297

awgu wants to merge 5 commits into gh/awgu/302/base.
ghstack-source-id: 998f0f5
Pull Request resolved: pytorch#92297
**Overview**
This PR stack will add support for unsharding FSDP's sharded parameters for `fully_shard`. This PR takes the first step by doing some internal refactoring.
- The existing API for wrapper FSDP is the static method `summon_full_params()`, which calls into the helper `_summon_full_params()`.
- This PR refactors:
- `summon_full_params()` core logic to `_unshard_params()`
- `_summon_full_params()` to `_unshard_params_recurse()`, which has a `recurse: bool` argument
- Previous `_unshard_params()` to `_unshard_fsdp_state_params()`, which applies to a single FSDP state
**Details**
- This PR introduces `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`, which additionally return the modules along with the FSDP states. The modules are needed for handling `FlatParameter` registration.
- We may be able to remove this if we clean up the `use_orig_params=True` vs. `False` code paths because for `True`, the `FlatParameter` is not registered, meaning that it does not need to be de-registered.
- Since `fully_shard` requires `use_orig_params=True`, we may not need `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`; however, I prefer to keep the separation of FSDP state and module explicit for now, for clarity.
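The shape of the new helper can be sketched with a toy module tree (the `Node` class and a preorder walk are invented here; the real traversal lives in FSDP's traversal utils). The point is that states are returned *paired* with their modules, since (de)registering a `FlatParameter` is a module-level operation:

```python
class Node:
    """Toy stand-in for an nn.Module that may have an FSDP state attached."""
    def __init__(self, name, fsdp_state=None, children=()):
        self.name = name
        self.fsdp_state = fsdp_state  # None if no FSDP is applied here
        self.children = list(children)

def get_fsdp_states_with_modules(root):
    # Preorder walk collecting (state, module) pairs in parallel lists.
    states, modules = [], []
    stack = [root]
    while stack:
        module = stack.pop()
        if module.fsdp_state is not None:
            states.append(module.fsdp_state)
            modules.append(module)  # kept so the caller can (de)register
        stack.extend(reversed(module.children))
    return states, modules
```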
**Follow-Ups**
- Passing `writeback=True` together with `rank0_only=True` raises an error. The previous explanation was:
> is not supported, as model parameter shapes will be different across ranks, and writing to them can lead to inconsistencies across ranks when the context is exited.
I am not exactly sure what "different model parameter shapes" refers to. However, I believe that we can support `writeback=True` with `rank0_only=True` by broadcasting the `FlatParameter` from rank 0 in the `finally`, writing back, and freeing. This should not increase the peak memory since rank 0 already holds the unsharded `FlatParameter` in GPU memory before writing back, and nonzero ranks do not hold any other unsharded `FlatParameter`s in GPU memory.
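The proposed exit path can be simulated in a single process, with plain lists standing in for `torch.distributed` tensors and the broadcast (all names here are invented; the real fix would use `dist.broadcast` on the `FlatParameter`):

```python
# Simulated fix for writeback=True + rank0_only=True: in the `finally`,
# broadcast the unsharded flat parameter from rank 0, write each rank's
# shard back, then free the full copy on every rank.
def writeback_rank0_only(world, shard_size):
    # `world[rank]` is that rank's state; only rank 0 holds the full param.
    full = world[0]["unsharded_flat_param"]
    for rank, state in enumerate(world):
        received = list(full)  # stands in for the broadcast from rank 0
        start = rank * shard_size
        state["shard"] = received[start:start + shard_size]  # write back
        state.pop("unsharded_flat_param", None)  # free the full copy
```

Peak memory is unchanged in this scheme: rank 0 already holds the full parameter before writing back, and each nonzero rank holds at most one full parameter transiently during its broadcast.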
**rohan-varma** left a comment:

> Shall we add unit tests for the `summon_full_params` composable path?
Review comment on the new TODO:

```python
        "to them can lead to inconsistencies across ranks when the "
        "context is exited."
    )
    # TODO: Rank 0 can broadcast the `FlatParameter` to allow all ranks to
```
Could we file an issue for this? Would it work for `use_orig_params=True` as well?
I think it should work for both `use_orig_params=True` and `False`. I will file an issue.
Review comment on the recursion branch:

```python
if recurse:
    with contextlib.ExitStack() as stack:
        # TODO (awgu): The traversal function does not traverse through
        # incompatible composable APIs. Verify if this is the desired
```

Could you elaborate? What's an example of this?
```python
fully_shard(
    Module(
        replicate(
            Submodule(
                fully_shard(Subsubmodule),
                Subsubmodule,
            ),
        ),
        Submodule,
    )
)
```

Because the traversal utils do not go through incompatible composable APIs (here, `replicate`), calling `_unshard_params` on the root `Module` will not unshard the parameters of the fully sharded `Subsubmodule`.
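The skipped-state behavior can be shown with a toy traversal (the `M` class and the string-tagged `api` field are invented; the semantics assumed are that the walk stops at a module managed by an incompatible composable API):

```python
class M:
    """Toy module tagged with the composable API applied to it, if any."""
    def __init__(self, name, api=None, children=()):
        self.name, self.api, self.children = name, api, list(children)

def collect_fully_sharded(module, out=None):
    out = [] if out is None else out
    if module.api == "replicate":  # incompatible API: do not descend
        return out
    if module.api == "fully_shard":
        out.append(module.name)
    for child in module.children:
        collect_fully_sharded(child, out)
    return out

subsub = M("Subsubmodule", api="fully_shard")
sub1 = M("Submodule", api="replicate", children=[subsub])
root = M("Module", api="fully_shard", children=[sub1, M("Submodule2")])
```

Here `collect_fully_sharded(root)` finds only the root's state: the walk never descends past the `replicate`-managed `Submodule`, so `Subsubmodule`'s state is missed.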
Yes, this has not been added yet. (I have a local [4/N] commit that adds a frontend for that path, but I did not open a PR for it since we have not finalized what the API should look like.) I will add tests when we include that.
Stack from ghstack:
- #92298 [FSDP][3/N] Refactor `summon_full_params` unit tests
- #92297 [FSDP][2/N] `_summon_full_params` -> `_unshard_params` (this PR)