[DeviceMesh] Simplify unflatten method (#165556)
Conversation
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/165556
Note: links to docs will display an error until the doc builds have completed. ✅ No failures as of commit 5591123 with merge base 5d4da26. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
```python
layout = complement(self, world_size)
return _MeshLayout(layout.shape, layout.stride)
```

```python
def unflatten(self, dim: int, unflatten_sizes: tuple[int, ...]) -> "_MeshLayout":
```
I know that it was me who insisted on making this a method of `_MeshLayout`, but I hadn't realized that we then needed to re-extract the sub-layout in order to create the ProcessGroups. This made for some bulky code.
My new suggestion is therefore to break this monolithic method into two: a composition one (which already exists) and a splice one. This lets DeviceMesh achieve the same result in two lines of code, while easily getting access to the intermediate value it needs.
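To illustrate the proposed split, here is a minimal self-contained sketch. The class name `MeshLayout`, the stride math in `composition`, and the exact `splice` semantics are assumptions inferred from this review thread, not the actual PyTorch implementation; the point is that `unflatten` reduces to composing the selected dim with the new shape and splicing the result back in:

```python
from dataclasses import dataclass


def suffix_product(sizes: tuple[int, ...]) -> tuple[int, ...]:
    # Row-major strides for a shape, e.g. (2, 2, 2) -> (4, 2, 1).
    strides, acc = [], 1
    for s in reversed(sizes):
        strides.append(acc)
        acc *= s
    return tuple(reversed(strides))


@dataclass(frozen=True)
class MeshLayout:
    sizes: tuple[int, ...]
    strides: tuple[int, ...]

    def __getitem__(self, dim: int) -> "MeshLayout":
        # Extract a single-dim sub-layout.
        return MeshLayout((self.sizes[dim],), (self.strides[dim],))

    def composition(self, other: "MeshLayout") -> "MeshLayout":
        # Re-interpret this single-dim layout through `other`'s shape:
        # each new dim strides within the original dim's storage.
        base = self.strides[0]
        return MeshLayout(other.sizes, tuple(base * s for s in other.strides))

    def splice(self, start: int, end: int, layout: "MeshLayout") -> "MeshLayout":
        # Replace dims [start, end) with the dims of `layout`.
        return MeshLayout(
            self.sizes[:start] + layout.sizes + self.sizes[end:],
            self.strides[:start] + layout.strides + self.strides[end:],
        )

    def unflatten(self, dim: int, sizes: tuple[int, ...]) -> "MeshLayout":
        # The two-liner the review proposes: the caller can also grab
        # `sub` (the intermediate sub-layout) to create ProcessGroups.
        sub = self[dim].composition(MeshLayout(sizes, suffix_product(sizes)))
        return self.splice(dim, dim + 1, sub)


layout = MeshLayout((2, 8), (8, 1))
unflattened = layout.unflatten(1, (2, 2, 2))
```

Under these assumptions, unflattening dim 1 of a `(2, 8)` layout into `(2, 2, 2)` yields sizes `(2, 2, 2, 2)` with strides `(8, 4, 2, 1)`.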
```python
return _get_default_group()
```

```python
@staticmethod
def _init_process_groups(
```
This could also be extracted to become a global private helper function. I didn't do it because I wanted to keep the diff small.
```python
sizes = list(self.sizes)  # type: ignore[arg-type]
strides = list(self.strides)  # type: ignore[arg-type]
unflatten_layout = self[dim].composition(
    _MeshLayout(tuple(unflatten_sizes), suffix_product(unflatten_sizes))
)
sizes[dim : dim + 1] = list(unflatten_layout.sizes)  # type: ignore[arg-type]
strides[dim : dim + 1] = list(unflatten_layout.strides)  # type: ignore[arg-type]
```
Note that the `# type: ignore[arg-type]` comments were hiding actual bugs, since `.sizes` and `.strides` could be integers, which we can't pass to `list(...)`! The new code fixes this thanks to `as_tuple`.
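As a hedged sketch of why this matters: in an IntTuple-style representation, an entry can be a bare `int` or a tuple, and `list(...)` on an `int` raises `TypeError`. The `as_tuple` below is a hypothetical normalizer in that spirit, not the actual PyTorch helper:

```python
from typing import Union

# IntTuple-style value: a bare int or a (possibly nested) tuple of ints.
IntTuple = Union[int, tuple]


def as_tuple(x: IntTuple) -> tuple:
    # Wrap a bare int in a 1-tuple so callers can always iterate,
    # instead of crashing on list(8).
    return x if isinstance(x, tuple) else (x,)
```

With this, `list(as_tuple(layout.sizes))` is safe whether `sizes` is `8` or `(2, 4)`, and the `# type: ignore` escapes are no longer needed.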
```python
dim_group_names = self._dim_group_names.copy()
dim_group_names[dim : dim + 1] = self._init_process_groups(
    partial_layout,
    root_mesh._global_rank_permutation,
    mesh_dim_names,
    backend_override,
```
I like this one, so that we can reuse it later if we only want to init a backend for some dims rather than all of them. Making `_flatten` generic, like what Tensor does, also needs this.
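The slice-assignment pattern in the hunk above is what makes the per-dim reuse work: unflattening one mesh dim turns its single process-group name into one name per new dim, in place. A tiny sketch (the group names here are made up for illustration):

```python
# One group name per mesh dim before unflattening.
dim_group_names = ["dp", "tp"]

# Hypothetical names returned for the dims produced by splitting "tp".
new_names = ["tp_a", "tp_b"]

# names[d : d + 1] = new replaces the one entry with several,
# mirroring how the dim itself is replaced in the layout.
dim_group_names[1:2] = new_names
```

The same slice assignment works unchanged whether the split produces two dims or five, which is why the static `_init_process_groups` only needs to return names for the dims it was asked about.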
```python
sizes[dim : dim + 1] = list(unflatten_layout.sizes)  # type: ignore[arg-type]
strides[dim : dim + 1] = list(unflatten_layout.strides)  # type: ignore[arg-type]
```

```python
def splice(self, start: int, end: int, layout: "_MeshLayout") -> "_MeshLayout":
    sizes = list(as_tuple(self.sizes))
```
With `as_tuple` and `flatten`, do we still want to pursue the proposal of limiting the sizes and strides to be at most 2D?

The way I see it, `as_tuple` and `flatten` are the tools we need to make the code correct, but moving from `IntTuple` to `tuple[tuple[int, ...], ...]` is what will make the type checker help us detect those bugs in the first place.
fduwjj left a comment:

This change looks reasonable to me as well. And we can build the shared_state on top of it so that we can also cache the PG for unflatten.
Shall we land this before #165555, so that I can continue working on unflatten and PG caching? I think we will need to have that PR out by the end of October.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This reverts commit 86fd4fc. Reverted #165556 on behalf of https://github.com/malfet because it looks like it broke a serialization test; see https://hud.pytorch.org/hud/pytorch/pytorch/aba8c43594a83772281a62a7961c0b6ddcff321d/1?per_page=50&name_filter=distributed%2C%201&mergeEphemeralLF=true ([comment](#165554 (comment)))
@pytorchbot merge -i

Merge started. Your change will be merged while ignoring the following 0 checks. Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
By adding a few small helpers (e.g., a `splice` method to `_MeshLayout`, and making `_init_process_groups` static and thus stateless) we can substantially shorten the definition of the unflatten method and help readability.

Pull Request resolved: pytorch#165556
Approved by: https://github.com/fduwjj
ghstack dependencies: pytorch#165554, pytorch#165555
Stack from ghstack (oldest at bottom):
By adding a few small helpers (e.g., a `splice` method to `_MeshLayout`, and making `_init_process_groups` static and thus stateless) we can substantially shorten the definition of the unflatten method and help readability.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci