[FSDP][7/N] Support `replicate` in `fully_shard` #91044
Closed
awgu wants to merge 17 commits into gh/awgu/283/base from
[ghstack-poisoned]
This was referenced Dec 16, 2022
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91044
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 570c2e1.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs to be rebased to get the newly landed PRs. [ghstack-poisoned]
This PR needs to be rebased to get the newly landed PRs. I will update with a proper PR summary before requesting for review. [ghstack-poisoned]
awgu commented Dec 17, 2022
    @@ -93,9 +93,9 @@ def world_size(self):
        def _dist_train(self):
            rank = self.rank
            world_size = self.world_size
Collaborator
Author
This test is disabled in CI, but I found that it is broken when running locally. This change fixes the test.
This PR supports nesting `replicate` in `fully_shard`.
- The PR achieves this by treating `replicate`-annotated modules as ignored modules. This means that all submodules in a `replicate`-annotated module's subtree are ignored, including nested `fully_shard`-annotated modules, which is the desired behavior.
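As a rough illustration of the "ignored subtree" rule, here is a minimal pure-Python sketch. It does not use the FSDP API: `ToyModule`, its `api` tag, and `ignored_names()` are hypothetical stand-ins introduced only for this example.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Set


@dataclass
class ToyModule:
    """Hypothetical stand-in for nn.Module: a name, an annotation tag, children."""
    name: str
    api: Optional[str] = None  # e.g. "fully_shard" or "replicate" (assumed tag)
    children: List["ToyModule"] = field(default_factory=list)


def subtree_names(module: ToyModule) -> List[str]:
    """Names of every module in the subtree, root first."""
    names = [module.name]
    for child in module.children:
        names.extend(subtree_names(child))
    return names


def ignored_names(root: ToyModule) -> Set[str]:
    """Names ignored when fully_shard is applied at `root`."""
    ignored: Set[str] = set()
    for child in root.children:
        if child.api == "replicate":
            # The whole subtree is ignored, nested fully_shard included.
            ignored.update(subtree_names(child))
        else:
            ignored |= ignored_names(child)
    return ignored


root = ToyModule("root", "fully_shard", [
    ToyModule("lin1"),
    ToyModule("rep", "replicate", [
        ToyModule("inner", "fully_shard"),
        ToyModule("lin2"),
    ]),
    ToyModule("lin3"),
])
ignored = ignored_names(root)
print(sorted(ignored))
```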
---
This PR reworks some tree traversal.
One end goal is for `state._handles` to follow the same order for both the wrapper and composable paths. This implies that `_get_fsdp_handles()` returns the same value for both paths.
- The helper function `_get_fully_sharded_module_to_states()` now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows `.modules()` order.
- The composable auto "wrap" initialization function `_init_param_handles_from_module()` follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the _only_ difference with respect to initialization order through the entire process.
```
mod: Module(
    submod1: Submodule(),
    submod2: Submodule(
        subsubmod: Subsubmodule(),
    ),
)
```
For left-to-right DFS, the order is `mod`, `submod1`, `submod2`, `subsubmod`. (For context, right-to-left DFS would be `mod`, `submod2`, `subsubmod`, `submod1`. In other words, left-to-right vs. right-to-left corresponds to iterating `.children()` vs. `reversed(.children())`, respectively.) The reverse left-to-right DFS order is then `subsubmod`, `submod2`, `submod1`, `mod`, which is a valid initialization order because every module comes after all of its descendants. However, the wrapper auto wrap initialization order would be `submod1`, `subsubmod`, `submod2`, `mod`, since it directly follows a left-to-right DFS and initializes as a part of the recursive DFS logic.
- At the end of `_init_param_handles_from_module()`, we reverse the newly populated `state._handles`. This gives the reverse of the reverse left-to-right DFS order, which is the left-to-right DFS order itself. Thus, `state._handles` has the same order for both paths.
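The orders discussed above can be reproduced with a small sketch over the example module tree. This is plain Python rather than FSDP code; the `TREE` dictionary and the helper names are assumptions made for illustration.

```python
from typing import Dict, List

# The example module tree: name -> children, listed left to right.
TREE: Dict[str, List[str]] = {
    "mod": ["submod1", "submod2"],
    "submod1": [],
    "submod2": ["subsubmod"],
    "subsubmod": [],
}


def dfs_preorder(root: str, right_to_left: bool = False) -> List[str]:
    """Visit a module before its children (this matches .modules() order)."""
    order = [root]
    children = TREE[root]
    for child in (reversed(children) if right_to_left else children):
        order.extend(dfs_preorder(child, right_to_left))
    return order


def dfs_postorder(root: str) -> List[str]:
    """Visit a module after its children, as a recursive auto wrap would."""
    order: List[str] = []
    for child in TREE[root]:
        order.extend(dfs_postorder(child))
    order.append(root)
    return order


def children_before_parents(order: List[str]) -> bool:
    """True if every module appears after all of its children."""
    position = {name: i for i, name in enumerate(order)}
    return all(
        position[child] < position[parent]
        for parent, children in TREE.items()
        for child in children
    )


left_to_right = dfs_preorder("mod")
right_to_left = dfs_preorder("mod", right_to_left=True)
composable_init = list(reversed(left_to_right))  # reverse left-to-right DFS
wrapper_init = dfs_postorder("mod")
handles = list(reversed(composable_init))  # reversed again at the end

print(left_to_right, right_to_left, composable_init, wrapper_init, sep="\n")
```

Both `composable_init` and `wrapper_init` place children before parents, so both are valid initialization orders even though they differ.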
Another goal is for `_get_fsdp_states()` to not traverse into any submodule that is annotated with an API that is not compatible with `fully_shard` (e.g. `replicate`). To achieve this while keeping `_get_fsdp_states()` in `.modules()` order, we again use a left-to-right DFS.
The DFSs may look strange because I implemented them non-recursively, which requires an explicit stack.
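For readers unfamiliar with the idiom, a stack-based left-to-right DFS that skips incompatible subtrees might look like the sketch below. It is not the code in `_traversal_utils.py`; the `TREE` layout, the annotation strings, and `visit_compatible()` are hypothetical.

```python
from typing import Dict, List, Optional, Tuple

# name -> (annotation or None, children listed left to right)
TREE: Dict[str, Tuple[Optional[str], List[str]]] = {
    "root": ("fully_shard", ["a", "b", "c"]),
    "a": (None, ["a1"]),
    "a1": ("fully_shard", []),
    "b": ("replicate", ["b1"]),
    "b1": ("fully_shard", []),
    "c": ("fully_shard", []),
}


def visit_compatible(root: str) -> List[str]:
    """Left-to-right DFS that never enters an incompatible subtree."""
    visited: List[str] = []
    stack = [root]
    while stack:
        name = stack.pop()
        api, children = TREE[name]
        if api is not None and api != "fully_shard":
            continue  # skip this module and everything below it
        visited.append(name)
        # Push children reversed so that the leftmost child is popped first.
        stack.extend(reversed(children))
    return visited


visited = visit_compatible("root")
sharded = [name for name in visited if TREE[name][0] == "fully_shard"]
print(visited)
print(sharded)
```

Pushing the children in reverse is what makes the pop order match a recursive left-to-right traversal.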
- `test_get_fully_sharded_module_to_states()` in `test_utils.py` checks the traversal order of `_get_fully_sharded_module_to_states()`.
- `test_policy()` in `test_fully_shard.py` checks the traversal order returned by `_get_fsdp_handles()`.
---
Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file `_traversal_utils.py`, and every usage must import the whole module, as in `import torch.distributed.fsdp._traversal_utils as traversal_utils`, instead of `from torch.distributed.fsdp._traversal_utils import ...`.
The cycle comes from the fact that the traversals require `_composable()`, which requires `_get_registry()` from `composable/contract.py`, which when imported, imports `composable/fully_shard.py`, which requires the traversals.
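The difference between the two import styles can be seen in a toy reproduction. The sketch below writes two throwaway packages (`cycle_from` and `cycle_mod`, both hypothetical) whose modules mirror the cycle described above in simplified form. It assumes Python 3.7 or later, where `import a.b as c` can resolve a partially initialized submodule.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def write_package(root: Path, name: str, fully_shard_src: str) -> None:
    """Write a toy package whose three modules import each other in a cycle."""
    pkg = root / name
    pkg.mkdir()
    (pkg / "__init__.py").write_text("")
    # traversal needs a helper from contract.
    (pkg / "traversal.py").write_text(
        f"from {name}.contract import get_registry\n\n"
        "def walk():\n"
        "    return get_registry()\n"
    )
    # contract imports fully_shard before it has defined get_registry.
    (pkg / "contract.py").write_text(
        f"import {name}.fully_shard\n\n"
        "def get_registry():\n"
        "    return {}\n"
    )
    # fully_shard needs the traversal helpers, which closes the cycle.
    (pkg / "fully_shard.py").write_text(fully_shard_src)


FROM_STYLE = (
    "from cycle_from.traversal import walk\n\n"
    "def init():\n"
    "    return walk()\n"
)
MODULE_STYLE = (
    "import cycle_mod.traversal as traversal_utils\n\n"
    "def init():\n"
    "    return traversal_utils.walk()\n"
)


def try_import(module_name: str) -> str:
    """Return 'ok' if the import succeeds, else the exception type name."""
    try:
        importlib.import_module(module_name)
        return "ok"
    except ImportError as exc:
        return type(exc).__name__


with tempfile.TemporaryDirectory() as tmp:
    write_package(Path(tmp), "cycle_from", FROM_STYLE)
    write_package(Path(tmp), "cycle_mod", MODULE_STYLE)
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        from_result = try_import("cycle_from.traversal")
        mod_result = try_import("cycle_mod.traversal")
        # The name lookup is deferred to call time, so this works.
        registry = importlib.import_module("cycle_mod.fully_shard").init()
    finally:
        sys.path.remove(tmp)

print(from_result, mod_result, registry)
```

The `from ... import walk` form needs `walk` to exist while `traversal` is still only partially initialized, whereas the module-style import only needs the module object and looks up `walk` later.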
[ghstack-poisoned]
Collaborator, Author
@pytorchbot merge
Collaborator
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This was referenced Dec 20, 2022
This was referenced Jan 5, 2023
Stack from ghstack:
- #91044 [FSDP][7/N] Support `replicate` in `fully_shard`
- #90959 [FSDP][6/N] Add note explaining idioms for `_FSDPState` traversal
- #90874 [FSDP][5/N] Add manual "wrapping" support for `fully_shard`