
[DeviceMesh] Simplify unflatten method#165556

Closed
lw wants to merge 10 commits into gh/lw/10/base from gh/lw/10/head

Conversation


@lw lw commented Oct 15, 2025

Stack from ghstack (oldest at bottom):

By adding a few small helpers (e.g., a splice method to _MeshLayout, and making _init_process_groups static and thus stateless) we can substantially shorten the definition of the unflatten method, and help readability.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci

[ghstack-poisoned]
@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Oct 15, 2025

pytorch-bot bot commented Oct 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165556

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5591123 with merge base 5d4da26:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

layout = complement(self, world_size)
return _MeshLayout(layout.shape, layout.stride)

def unflatten(self, dim: int, unflatten_sizes: tuple[int, ...]) -> "_MeshLayout":
@lw (Contributor Author):

I know that it was me who insisted on making this a method of _MeshLayout, but I hadn't realized that we then needed to re-extract the sub-layout in order to create the ProcessGroups. This made for some bulky code.

Thus my new suggestion is that we break this monolithic method into two: a composition one (which already exists) and a splice one. This allows the DeviceMesh to achieve the same result as this method in two lines of code, while easily getting access to the intermediate value it needs.
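A minimal sketch of the two-method decomposition described here (all names, fields, and signatures below are illustrative stand-ins, not PyTorch's actual `_MeshLayout` internals):

```python
# Hypothetical sketch: a layout is a tuple of (size, stride) dims; `composition`
# refines a 1-D layout into several factors, and `splice` replaces dims
# [start, end) with another layout's dims.
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class MeshLayout:
    sizes: tuple[int, ...]
    strides: tuple[int, ...]

    def __getitem__(self, dim: int) -> "MeshLayout":
        # Extract a single dim as a 1-D sub-layout.
        return MeshLayout((self.sizes[dim],), (self.strides[dim],))

    def composition(self, inner: "MeshLayout") -> "MeshLayout":
        # Split this 1-D layout into inner.sizes factors (row-major order).
        assert len(self.sizes) == 1 and math.prod(inner.sizes) == self.sizes[0]
        stride = self.strides[0]
        strides = []
        remaining = self.sizes[0]
        for s in inner.sizes:
            remaining //= s
            strides.append(stride * remaining)
        return MeshLayout(tuple(inner.sizes), tuple(strides))

    def splice(self, start: int, end: int, layout: "MeshLayout") -> "MeshLayout":
        # Replace dims [start, end) with the dims of `layout`.
        sizes = list(self.sizes)
        strides = list(self.strides)
        sizes[start:end] = layout.sizes
        strides[start:end] = layout.strides
        return MeshLayout(tuple(sizes), tuple(strides))


# The "two lines" the caller needs for an unflatten: compute the sub-layout
# (the intermediate value needed, e.g., for ProcessGroup creation), then splice it in.
base = MeshLayout((8,), (1,))                           # one dim of 8 ranks
sub = base[0].composition(MeshLayout((2, 4), (4, 1)))   # intermediate sub-layout
unflattened = base.splice(0, 1, sub)
print(unflattened.sizes, unflattened.strides)           # (2, 4) (4, 1)
```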

return _get_default_group()

@staticmethod
def _init_process_groups(
@lw (Contributor Author):

This could also be extracted to become a global private helper function. I didn't do it as I wanted to keep the diff small.

Comment on lines -196 to -202
sizes = list(self.sizes) # type: ignore[arg-type]
strides = list(self.strides) # type: ignore[arg-type]
unflatten_layout = self[dim].composition(
_MeshLayout(tuple(unflatten_sizes), suffix_product(unflatten_sizes))
)
sizes[dim : dim + 1] = list(unflatten_layout.sizes) # type: ignore[arg-type]
strides[dim : dim + 1] = list(unflatten_layout.strides) # type: ignore[arg-type]
@lw (Contributor Author) commented Oct 15, 2025:

Note that the # type: ignore[arg-type] comments were hiding actual bugs, since .sizes and .strides could be bare integers, and we can't pass those to list(...)! The new code fixes this thanks to as_tuple.
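The pitfall can be sketched as follows (the helper name `as_tuple` comes from the diff; this implementation and the `IntTuple` alias are assumptions for illustration):

```python
# An "IntTuple" may be a bare int or a (possibly nested) tuple of IntTuples,
# so list(sizes) raises TypeError when sizes is a plain int -- the bug that
# the `# type: ignore[arg-type]` comments were hiding.
from typing import Union

IntTuple = Union[int, tuple["IntTuple", ...]]


def as_tuple(x: IntTuple) -> tuple["IntTuple", ...]:
    # Normalize a bare int into a 1-tuple so callers can always iterate.
    return x if isinstance(x, tuple) else (x,)


sizes: IntTuple = 4             # a 1-D layout may store a bare int
# list(sizes) would raise TypeError: 'int' object is not iterable
print(list(as_tuple(sizes)))    # [4]
print(list(as_tuple((2, 3))))   # [2, 3]
```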

Comment on lines +1144 to +1149
dim_group_names = self._dim_group_names.copy()
dim_group_names[dim : dim + 1] = self._init_process_groups(
partial_layout,
root_mesh._global_rank_permutation,
mesh_dim_names,
backend_override,
Contributor comment:

I like this one, since we can reuse it later if we only want to init the backend for some dims rather than all of them. Making _flatten generic, like what Tensor is doing, also needs this.

sizes[dim : dim + 1] = list(unflatten_layout.sizes) # type: ignore[arg-type]
strides[dim : dim + 1] = list(unflatten_layout.strides) # type: ignore[arg-type]
def splice(self, start: int, end: int, layout: "_MeshLayout") -> "_MeshLayout":
sizes = list(as_tuple(self.sizes))
Contributor comment:
With as_tuple and flatten, do we still want to pursue the proposal of limiting the sizes and strides to be at most 2-D?

@lw (Contributor Author):

The way I see it, as_tuple and flatten are the tools we need to make the code correct, but moving from IntTuple to tuple[tuple[int, ...], ...] is what will make the type checker help us detect those bugs in the first place.
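The type-checker point can be illustrated with a toy pair of signatures (both aliases and both functions here are hypothetical, not the actual PyTorch types):

```python
# With a recursive IntTuple union, the checker cannot rule out a bare int, so
# misuse needs runtime checks or `# type: ignore`; with an explicit
# nested-tuple type, the same misuse becomes a static error instead.
from typing import Union

IntTuple = Union[int, tuple["IntTuple", ...]]   # loose: an int can appear anywhere
NestedSizes = tuple[tuple[int, ...], ...]       # strict: always a tuple of tuples


def total_loose(sizes: IntTuple) -> int:
    # Must branch at runtime because `sizes` may be a bare int.
    if isinstance(sizes, int):
        return sizes
    return sum(total_loose(s) for s in sizes)


def total_strict(sizes: NestedSizes) -> int:
    # No ignores needed: every element is statically a tuple[int, ...].
    return sum(sum(group) for group in sizes)


print(total_loose(4))                # 4
print(total_strict(((2, 3), (4,))))  # 9
```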

@fduwjj (Contributor) left a comment:

This change looks reasonable to me as well. And we can build the shared_state on top of that, so that we can also cache the PG for unflatten.


fduwjj commented Oct 15, 2025

Shall we land this before #165555, so that I can continue working on unflatten and PG caching? I think we will need to have that PR out by the end of October.

@lw lw added the topic: not user facing topic category label Oct 16, 2025
[ghstack-poisoned]
lw added a commit that referenced this pull request Oct 16, 2025
ghstack-source-id: be7f13b
Pull-Request: #165556
[ghstack-poisoned]
lw added a commit that referenced this pull request Oct 16, 2025
ghstack-source-id: 56a0709
Pull-Request: #165556
[ghstack-poisoned]
lw added a commit that referenced this pull request Oct 16, 2025
ghstack-source-id: c83ac69
Pull-Request: #165556
[ghstack-poisoned]
lw added a commit that referenced this pull request Oct 16, 2025
ghstack-source-id: 3fe55d2
Pull-Request: #165556
[ghstack-poisoned]
lw added a commit that referenced this pull request Oct 16, 2025
ghstack-source-id: 8af73fa
Pull-Request: #165556
@lw lw added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 16, 2025

fduwjj commented Oct 16, 2025

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator):

@lw your PR has been reverted as part of the stack under #165554.

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Oct 16, 2025
[ghstack-poisoned]
lw added a commit that referenced this pull request Oct 17, 2025
ghstack-source-id: 6736782
Pull-Request: #165556
[ghstack-poisoned]
lw added a commit that referenced this pull request Oct 17, 2025
ghstack-source-id: a5f80ee
Pull-Request: #165556
[ghstack-poisoned]
lw added a commit that referenced this pull request Oct 17, 2025
ghstack-source-id: 4f51a2b
Pull-Request: #165556
[ghstack-poisoned]
lw added a commit that referenced this pull request Oct 17, 2025
ghstack-source-id: 3ab51d7
Pull-Request: #165556

lw commented Oct 17, 2025

@pytorchbot merge -i

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged while ignoring the following 0 checks:

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
By adding a few small helpers (e.g., a `splice` method to `_MeshLayout`, and making `_init_process_groups` static and thus stateless) we can substantially shorten the definition of the unflatten method, and help readability.

Pull Request resolved: pytorch#165556
Approved by: https://github.com/fduwjj
ghstack dependencies: pytorch#165554, pytorch#165555
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 22, 2025
zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 22, 2025
@github-actions github-actions bot deleted the gh/lw/10/head branch November 17, 2025 02:18

Labels

ci-no-td (Do not run TD on this PR), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, oncall: distributed (Add this issue/PR to distributed oncall triage queue), Reverted, topic: not user facing (topic category)
