[DeviceMesh] Prefer using _layout over _mesh for all sorts of things #165554
lw wants to merge 7 commits into gh/lw/8/base
Conversation
🔗 Helpful links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/165554
Note: links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 unrelated failure.) As of commit c07f577 with merge base 5d4da26. FLAKY: the following job failed but was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/distributed/device_mesh.py (outdated)

    _device_type: str
    _mesh: torch.Tensor
    _global_rank_permutation: torch.Tensor
As I mentioned, I welcome suggestions for better names.
I personally don't like the idea of permutation. Shall we call it _global_ranks_alloc?
Why alloc (which I guess stands for allocation)? In which sense is this an allocation?
An alternative name could also be _rank_order, which has the benefit of being shorter.
Because it means rank allocation? The actual mesh tensor, e.g. [0, 3, 5, 10] (let's say a random tensor), represents a device allocation. That's why I think allocation might be better. I don't have a strong opinion on the word allocation; I just want to make sure the name here clearly represents what it means.
I thought about what could be the shortest possible name (to keep things concise), and I like _rank_map, which I find also quite self-descriptive, since this is exactly what we're using this for: to map coordinates to actual ranks. It's also the term we're already using in remap_to_tensor.
Are people ok with this name?
_rank_map sounds good to me.
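For intuition, here is a minimal, hypothetical sketch in plain Python (illustrative names only, not the actual DeviceMesh internals) of what such a rank map does: it maps a mesh coordinate to a global rank via the layout's strides.

```python
# Hypothetical illustration of the "_rank_map" idea: a flat list of
# global ranks, shared by the root mesh and all of its submeshes.
# A mesh coordinate is turned into a flat index via the layout's
# strides, then looked up in the map.
rank_map = [0, 3, 5, 10]         # global ranks, in arbitrary order
shape, strides = (2, 2), (2, 1)  # row-major 2x2 layout

def rank_at(coord):
    """Map a mesh coordinate to a global rank."""
    flat = sum(c * s for c, s in zip(coord, strides))
    return rank_map[flat]

print(rank_at((1, 0)))  # -> 5
```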
torch/distributed/device_mesh.py (outdated)

    self._global_rank_permutation = (
        _root_mesh._global_rank_permutation
        if _root_mesh is not None
        else mesh_tensor.flatten()
    )
This is a bit of a mess, but the next PR in the stack will clean this up.
@@ -258,7 +266,13 @@ def device_type(self) -> str:

    @property
    def mesh(self) -> torch.Tensor:
This function is admittedly quite complicated. I think it has been a good exercise to write it out, as it helps us understand what the contents of the mesh Tensor are. If we agree on this definition, we can work next on cleaning this up :)
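As a hedged illustration of the idea discussed here (a hypothetical helper, not the actual implementation), the mesh contents no longer need to be stored: they can be rebuilt on demand from the flat rank map and the layout's shape.

```python
# Illustrative sketch: materialize the mesh on the fly from the shared
# flat rank map and the layout's (row-major) shape, instead of storing
# an explicit mesh tensor on every DeviceMesh.
def materialize_mesh(rank_map, shape):
    """Rebuild the nested mesh from the flat rank map (row-major)."""
    rows, cols = shape
    return [[rank_map[r * cols + c] for c in range(cols)] for r in range(rows)]

print(materialize_mesh([0, 3, 5, 10], (2, 2)))  # -> [[0, 3], [5, 10]]
```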
torch/distributed/device_mesh.py (outdated)

    if full_mesh.size(0) == 1:
        return full_mesh[0]
This case here is for DeviceMeshes that do not contain the rank on which they're being instantiated.
In general, I have a hard time understanding in which cases they would be useful, and I'd like to explore the possibility of forbidding such a scenario.
Concretely, this codepath is currently being triggered by get_all_submeshes, thus I'd like to look into whether we can rework the internal usages of that private method and hopefully remove it or improve it.
Agreed, we can also make get_all_submeshes not rely on this case.
One thing is required: can you kindly run the DTensor CPU overhead benchmark here: #159169? That way we'll at least have an understanding of how much overhead this change introduces.
fduwjj left a comment
Nice, this makes sense to me. Thanks for continuing to help make DeviceMesh cleaner. Just one question each on naming and DTensor CPU overhead.
I ran the CPU benchmark. I looked at the iteration_4/rank0_trace.json files and measured the end-to-end time of the first forward pass on CPU. Before this PR it took 1.526 s; after this PR it took 1.521 s. Hence no regression.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…om_ranks (#165555) The refactoring of DeviceMesh is heavily constrained by the signature of its constructor, which is a public API which contains some "legacy" concepts which we'd love to get rid of, such as an explicit/materialized `mesh` Tensor. In other languages the solution to this would be to add a private overload of the constructor. Python doesn't natively allow this, but in this PR I managed to build something that approximates it. This new private constructor basically only takes `_layout`, `_global_rank_permutation`, and `mesh_dim_names`. With such a constructor we can effectively simplify a lot of callsites and get rid of the `_create_mesh_from_ranks` helper method. That's a good thing because it was instantiating many DeviceMeshes in a for loop, which always felt unnecessary. Pull Request resolved: #165555 Approved by: https://github.com/fduwjj, https://github.com/fegin ghstack dependencies: #165554
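The "private overload of the constructor" described in this commit message can be approximated in Python roughly as follows. This is an illustrative sketch with made-up names (`Mesh`, `_from_layout`), not the actual DeviceMesh code: a classmethod that bypasses `__init__` entirely via `cls.__new__`.

```python
# Hypothetical sketch of a private constructor "overload" in Python.
class Mesh:
    def __init__(self, device_type, mesh_tensor):
        # Public constructor: legacy signature with an explicit mesh.
        self._device_type = device_type
        self._rank_map = list(mesh_tensor)
        self._dim_names = None

    @classmethod
    def _from_layout(cls, device_type, rank_map, dim_names=None):
        # Private "overload": allocate the instance without calling
        # __init__, then populate only the internal attributes.
        self = cls.__new__(cls)
        self._device_type = device_type
        self._rank_map = rank_map
        self._dim_names = dim_names
        return self

m = Mesh._from_layout("cpu", [0, 1, 2, 3], ("dp",))
print(m._rank_map)  # -> [0, 1, 2, 3]
```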
By adding a few small helpers (e.g., a `splice` method to `_MeshLayout`, and making `_init_process_groups` static and thus stateless) we can substantially shorten the definition of the unflatten method, and help readability. Pull Request resolved: #165556 Approved by: https://github.com/fduwjj ghstack dependencies: #165554, #165555
Somehow this test failed: python test/distributed/test_serialization.py TestSerialization.test_dtensor
Starting merge as part of PR stack under #165556
…ytorch#165554) The goal of this PR is to avoid storing the explicit `mesh` Tensor inside each DeviceMesh, and instead compute it on-the-fly when the end user needs it, and try to replace all of its internal usages with `_layout` and the newly-introduced `_global_rank_permutation` Tensor. The name of this attribute is up for debate. The advantage of the `_global_rank_permutation` Tensor is that it is _the same_ Tensor for the root mesh and all its children, so it doesn't need to be copied/reallocated. Pull Request resolved: pytorch#165554 Approved by: https://github.com/fduwjj
This reverts commit 86fd4fc. Reverted pytorch#165556 on behalf of https://github.com/malfet due to Looks like it broke serialization test, see https://hud.pytorch.org/hud/pytorch/pytorch/aba8c43594a83772281a62a7961c0b6ddcff321d/1?per_page=50&name_filter=distributed%2C%201&mergeEphemeralLF=true ([comment](pytorch#165554 (comment)))
…_mesh_from_ranks (pytorch#165555)" This reverts commit 99097b6. Reverted pytorch#165555 on behalf of https://github.com/malfet due to Looks like it broke serialization test, see https://hud.pytorch.org/hud/pytorch/pytorch/aba8c43594a83772281a62a7961c0b6ddcff321d/1?per_page=50&name_filter=distributed%2C%201&mergeEphemeralLF=true ([comment](pytorch#165554 (comment)))
… things (pytorch#165554)" This reverts commit d61a9b8. Reverted pytorch#165554 on behalf of https://github.com/malfet due to Looks like it broke serialization test, see https://hud.pytorch.org/hud/pytorch/pytorch/aba8c43594a83772281a62a7961c0b6ddcff321d/1?per_page=50&name_filter=distributed%2C%201&mergeEphemeralLF=true ([comment](pytorch#165554 (comment)))
…ytorch#165554) The goal of this PR is to avoid storing the explicit `mesh` Tensor inside each DeviceMesh, and instead compute it on-the-fly when the end user needs it, and try to replace all of its internal usages with `_layout` and the newly-introduced `_global_rank_permutation` Tensor. The name of this attribute is up for debate. The advantage of the `_global_rank_permutation` Tensor is that it is _the same_ Tensor for the root mesh and all its children, so it doesn't need to be copied/reallocated. Pull Request resolved: pytorch#165554 Approved by: https://github.com/fduwjj
…om_ranks (pytorch#165555) The refactoring of DeviceMesh is heavily constrained by the signature of its constructor, which is a public API which contains some "legacy" concepts which we'd love to get rid of, such as an explicit/materialized `mesh` Tensor. In other languages the solution to this would be to add a private overload of the constructor. Python doesn't natively allow this, but in this PR I managed to build something that approximates it. This new private constructor basically only takes `_layout`, `_global_rank_permutation`, and `mesh_dim_names`. With such a constructor we can effectively simplify a lot of callsites and get rid of the `_create_mesh_from_ranks` helper method. That's a good thing because it was instantiating many DeviceMeshes in a for loop, which always felt unnecessary. Pull Request resolved: pytorch#165555 Approved by: https://github.com/fduwjj, https://github.com/fegin ghstack dependencies: pytorch#165554
By adding a few small helpers (e.g., a `splice` method to `_MeshLayout`, and making `_init_process_groups` static and thus stateless) we can substantially shorten the definition of the unflatten method, and help readability. Pull Request resolved: pytorch#165556 Approved by: https://github.com/fduwjj ghstack dependencies: pytorch#165554, pytorch#165555
Stack from ghstack (oldest at bottom):
The goal of this PR is to avoid storing the explicit `mesh` Tensor inside each DeviceMesh, and instead compute it on-the-fly when the end user needs it, and try to replace all of its internal usages with `_layout` and the newly-introduced `_global_rank_permutation` Tensor. The name of this attribute is up for debate. The advantage of the `_global_rank_permutation` Tensor is that it is _the same_ Tensor for the root mesh and all its children, so it doesn't need to be copied/reallocated.
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci
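The sharing property described above can be sketched as follows (illustrative `TinyMesh` class, not the real DeviceMesh): the child mesh holds a reference to the very same rank-map object as the root, so slicing a mesh does not copy or reallocate rank storage.

```python
# Illustrative sketch: root mesh and submesh share the same rank-map
# object, so no per-submesh copy or reallocation is needed.
class TinyMesh:
    def __init__(self, rank_map, shape):
        self._rank_map = rank_map  # shared reference, never copied
        self._shape = shape

root_rank_map = [0, 1, 2, 3, 4, 5, 6, 7]
root = TinyMesh(root_rank_map, (2, 4))
sub = TinyMesh(root._rank_map, (4,))    # child reuses the same object

print(sub._rank_map is root._rank_map)  # -> True
```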