
[FSDP][7/N] Support replicate in fully_shard #91044

Closed
awgu wants to merge 17 commits into gh/awgu/283/base from gh/awgu/283/head

@awgu awgu commented Dec 16, 2022

Stack from ghstack:

This PR supports nesting replicate in fully_shard.

  • The PR achieves this by treating replicate-annotated modules as ignored modules. This means that all submodules in the replicate-annotated module's subtree are ignored, including nested fully_shard-annotated modules, which is the desired behavior.
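
To make the "ignored subtree" rule concrete, here is a minimal pure-Python sketch. ToyModule, its api field, and ignored_module_names() are hypothetical stand-ins for illustration only; they are not the FSDP implementation or any torch API.

```python
from dataclasses import dataclass, field
from typing import Iterator, List, Optional


@dataclass
class ToyModule:
    """Hypothetical stand-in for nn.Module with a composable-API annotation."""
    name: str
    api: Optional[str] = None  # "fully_shard", "replicate", or None
    children: List["ToyModule"] = field(default_factory=list)

    def modules(self) -> Iterator["ToyModule"]:
        # Mirrors the pre-order (left-to-right DFS) of nn.Module.modules().
        yield self
        for child in self.children:
            yield from child.modules()


def ignored_module_names(root: ToyModule) -> List[str]:
    """Collect every module in the subtree of a replicate-annotated module."""
    ignored: List[str] = []
    for module in root.modules():
        if module.api == "replicate":
            for sub in module.modules():
                if sub.name not in ignored:
                    ignored.append(sub.name)
    return ignored


# fully_shard root with a replicate child that itself nests a fully_shard module.
proj = ToyModule("proj", api="fully_shard")
head = ToyModule("head", api="replicate", children=[proj])
encoder = ToyModule("encoder")
root = ToyModule("root", api="fully_shard", children=[encoder, head])

ignored = ignored_module_names(root)
```

In this toy tree, both head and the nested fully_shard-annotated proj end up ignored by the outer fully_shard, while encoder does not.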

This PR reworks some tree traversal.

One end goal is for state._handles to follow the same order for both the wrapper and composable paths. This implies that _get_fsdp_handles() returns the same value for both paths.

  • The helper function _get_fully_sharded_module_to_states() now follows a left-to-right DFS from each fully sharded module instead of a BFS. The left-to-right DFS follows .modules() order.
  • The composable auto "wrap" initialization function _init_param_handles_from_module() follows the reverse left-to-right DFS order. As noted in the code comments, this initialization order is a valid reverse topological sort, but it differs from the wrapper path. This is the only difference with respect to initialization order through the entire process.
mod: Module(
    submod1: Submodule()
    submod2: Submodule(
        subsubmod: Subsubmodule(),
    ),
)

For left-to-right DFS, the order is mod, submod1, submod2, subsubmod. (For context, right-to-left DFS would be mod, submod2, subsubmod, submod1. In other words, left-to-right vs. right-to-left corresponds to iterating over .children() vs. reversed(.children()), respectively.) The reverse left-to-right DFS order is then subsubmod, submod2, submod1, mod, which is a valid initialization order. However, the wrapper auto wrap initialization order would be submod1, subsubmod, submod2, mod (a post-order traversal), since it follows a left-to-right DFS directly and initializes each module as part of the recursive DFS logic.

  • At the end of _init_param_handles_from_module(), we reverse the newly populated state._handles. This gives the reverse of the reverse left-to-right DFS order, which is simply the left-to-right DFS order. Thus, state._handles has the same order for both paths.
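
The orders named above can be reproduced with a small sketch on the same four-module tree. The CHILDREN dict and the preorder()/postorder() helpers are illustrative assumptions, not code from this PR; the last step mimics appending handles in initialization order and then reversing the list.

```python
from typing import Dict, List

# Toy version of the example tree: mod -> (submod1, submod2 -> subsubmod).
CHILDREN: Dict[str, List[str]] = {
    "mod": ["submod1", "submod2"],
    "submod1": [],
    "submod2": ["subsubmod"],
    "subsubmod": [],
}


def preorder(node: str, right_to_left: bool = False) -> List[str]:
    """Visit the node first, then its children (reversed if right_to_left)."""
    order = [node]
    kids = CHILDREN[node]
    for child in (reversed(kids) if right_to_left else kids):
        order.extend(preorder(child, right_to_left))
    return order


def postorder(node: str) -> List[str]:
    """Visit children left to right, then the node itself."""
    order: List[str] = []
    for child in CHILDREN[node]:
        order.extend(postorder(child))
    order.append(node)
    return order


left_to_right = preorder("mod")
right_to_left = preorder("mod", right_to_left=True)
composable_init_order = list(reversed(left_to_right))  # reverse left-to-right DFS
wrapper_init_order = postorder("mod")                  # recursive auto wrap order

# Appending one handle per module in initialization order and then reversing
# the list recovers the left-to-right DFS order.
handles = list(composable_init_order)
handles.reverse()
```

Both initialization orders visit every child before its parent, which is why either is a valid reverse topological sort even though they differ.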

Another goal is for _get_fsdp_states() to not traverse into any submodule that is annotated with an API that is not compatible with fully_shard (e.g. replicate). To achieve this while preserving that _get_fsdp_states() follows .modules() order, we again use a left-to-right DFS.

The DFSs may look strange because I implemented them non-recursively, which requires an explicit stack.
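
As a rough sketch of what such a stack-based traversal can look like (not the code in this PR), the following visits nodes in .modules() order and does not descend into subtrees annotated with an incompatible API. ToyModule, its api field, and visit_compatible() are hypothetical names for illustration.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyModule:
    """Hypothetical stand-in for nn.Module with a composable-API annotation."""
    name: str
    api: Optional[str] = None  # "fully_shard", "replicate", or None
    children: List["ToyModule"] = field(default_factory=list)


def visit_compatible(root: ToyModule) -> List[ToyModule]:
    """Non-recursive left-to-right DFS that skips incompatible subtrees."""
    visited: List[ToyModule] = []
    stack = [root]
    while stack:
        module = stack.pop()
        if module.api not in (None, "fully_shard"):
            # Incompatible API (e.g. replicate): do not visit or descend.
            continue
        visited.append(module)
        # Push children in reverse so that the leftmost child is popped first;
        # this is what keeps the visit order equal to the recursive pre-order.
        stack.extend(reversed(module.children))
    return visited


a1 = ToyModule("a1", api="fully_shard")
a = ToyModule("a", children=[a1])
b1 = ToyModule("b1", api="fully_shard")
b = ToyModule("b", api="replicate", children=[b1])
c = ToyModule("c", api="fully_shard")
root = ToyModule("root", api="fully_shard", children=[a, b, c])

visited_names = [m.name for m in visit_compatible(root)]
fully_sharded_names = [m.name for m in visit_compatible(root) if m.api == "fully_shard"]
```

The reversed push is the part that tends to look odd at first glance: a stack pops the most recently pushed item, so children must be pushed right to left to be visited left to right.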

  • test_get_fully_sharded_module_to_states() in test_utils.py checks the traversal order of _get_fully_sharded_module_to_states().
  • test_policy() in test_fully_shard.py checks the traversal order returned by _get_fsdp_handles().

Due to a circular dependency issue, we must move the graph/tree traversal helpers to their own file _traversal_utils.py, and every caller must import the module as a whole, like import torch.distributed.fsdp._traversal_utils as traversal_utils, instead of importing names from it with from torch.distributed.fsdp._traversal_utils import ....

The cycle comes from the fact that the traversals require _composable(), which requires _get_registry() from composable/contract.py, which when imported, imports composable/fully_shard.py, which requires the traversals.
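
The difference between the two import styles in a cycle can be shown with a self-contained toy. The three generated modules (toy_traversal_*, toy_contract_*, toy_fully_shard_*) are hypothetical stand-ins that only mimic the shape of the cycle described above; they are not the real torch files, and the real import chain may differ in detail.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Toy stand-ins: traversal needs contract, contract imports fully_shard,
# and fully_shard needs traversal.
TRAVERSAL_SRC = (
    "import toy_contract_{tag}\n"
    "def get_states():\n"
    "    return toy_contract_{tag}.get_registry()\n"
)
CONTRACT_SRC = (
    "import toy_fully_shard_{tag}\n"
    "def get_registry():\n"
    "    return ['state']\n"
)
# Style 1: pull a name out of the (still partially initialized) module.
FROM_IMPORT_SRC = (
    "from toy_traversal_{tag} import get_states\n"
    "def fully_shard():\n"
    "    return get_states()\n"
)
# Style 2: bind the module object and look the name up at call time.
MODULE_IMPORT_SRC = (
    "import toy_traversal_{tag} as traversal_utils\n"
    "def fully_shard():\n"
    "    return traversal_utils.get_states()\n"
)


def import_cycle(fully_shard_src: str, tag: str):
    """Write the three toy modules to a temp dir and import the traversal one."""
    with tempfile.TemporaryDirectory() as tmp:
        files = {
            f"toy_traversal_{tag}": TRAVERSAL_SRC,
            f"toy_contract_{tag}": CONTRACT_SRC,
            f"toy_fully_shard_{tag}": fully_shard_src,
        }
        for name, src in files.items():
            Path(tmp, f"{name}.py").write_text(src.format(tag=tag))
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()
        try:
            importlib.import_module(f"toy_traversal_{tag}")
            return sys.modules[f"toy_fully_shard_{tag}"].fully_shard()
        except ImportError as exc:
            return f"ImportError: {exc}"
        finally:
            sys.path.remove(tmp)


from_import_result = import_cycle(FROM_IMPORT_SRC, "a")
module_import_result = import_cycle(MODULE_IMPORT_SRC, "b")
```

In the toy, the from-import fails because get_states is not yet defined when the cycle loops back to the half-initialized traversal module, whereas the whole-module import only needs the module object to exist and defers the attribute lookup until fully_shard() is called.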

pytorch-bot bot commented Dec 16, 2022

✅ No Failures

As of commit 570c2e1:
💚 Looks good so far! There are no failures yet. 💚


@yhcharles yhcharles self-requested a review December 16, 2022 23:07
awgu pushed a commit that referenced this pull request Dec 17, 2022
ghstack-source-id: 4dc6b99
Pull Request resolved: #91044
This PR needs to be rebased to get the newly landed PRs.

[ghstack-poisoned]
This PR needs to be rebased to get the newly landed PRs.

I will update with a proper PR summary before requesting review.

[ghstack-poisoned]
@@ -93,9 +93,9 @@ def world_size(self):
def _dist_train(self):
rank = self.rank
world_size = self.world_size

This test is disabled in CI, but I found that it is broken when running locally. This change fixes the test.

awgu pushed a commit that referenced this pull request Dec 19, 2022
ghstack-source-id: 960f26e
Pull Request resolved: #91044
awgu pushed a commit that referenced this pull request Dec 19, 2022
ghstack-source-id: 0393d1a
Pull Request resolved: #91044
awgu pushed a commit that referenced this pull request Dec 20, 2022
ghstack-source-id: 4b369f3
Pull Request resolved: #91044
@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 20, 2022
awgu commented Dec 20, 2022

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: distributed (fsdp) release notes category topic: improvements topic category
