
[DeviceMesh] Prefer using _layout over _mesh for all sorts of things#165554

Closed
lw wants to merge 7 commits into gh/lw/8/base from gh/lw/8/head

Conversation


@lw lw commented Oct 15, 2025

Stack from ghstack (oldest at bottom):

The goal of this PR is to avoid storing the explicit `mesh` Tensor inside each DeviceMesh, and instead compute it on-the-fly when the end user needs it, and try to replace all of its internal usages with `_layout` and the newly-introduced `_global_rank_permutation` Tensor. The name of this attribute is up for debate. The advantage of the `_global_rank_permutation` Tensor is that it is _the same_ Tensor for the root mesh and all its children, so it doesn't need to be copied/reallocated.
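As a rough illustration of the idea (a minimal pure-Python sketch; `MeshLayout`, `materialize_mesh`, and the example sizes/strides are hypothetical stand-ins, not the actual DeviceMesh internals): the mesh tensor can be rebuilt on demand from a layout plus a single flat rank map that the root mesh and all submeshes share.

```python
from dataclasses import dataclass
from itertools import product

@dataclass(frozen=True)
class MeshLayout:
    sizes: tuple    # e.g. (2, 4) for a 2x4 mesh
    strides: tuple  # strides into the shared flat rank map

    def linear_index(self, coord):
        # Map an n-dimensional mesh coordinate to a flat index.
        return sum(c * s for c, s in zip(coord, self.strides))

def materialize_mesh(layout, rank_map):
    """Build the nested mesh on the fly from a layout and a flat rank map."""
    flat = [rank_map[layout.linear_index(c)]
            for c in product(*(range(n) for n in layout.sizes))]

    def reshape(vals, sizes):
        # Reshape a flat list into nested lists of shape `sizes` (row-major).
        if len(sizes) == 1:
            return list(vals)
        step = len(vals) // sizes[0]
        return [reshape(vals[i * step:(i + 1) * step], sizes[1:])
                for i in range(sizes[0])]

    return reshape(flat, layout.sizes)

# The root mesh and its submeshes can all share the *same* rank map:
rank_map = [0, 1, 2, 3, 4, 5, 6, 7]
root = MeshLayout(sizes=(2, 4), strides=(4, 1))
column = MeshLayout(sizes=(2,), strides=(4,))  # one column submesh
```

Here `materialize_mesh(root, rank_map)` yields the full 2x4 mesh, while `column` reuses the very same `rank_map` to describe a submesh, which is the sharing property described above.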

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci

[ghstack-poisoned]
@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Oct 15, 2025

pytorch-bot bot commented Oct 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165554

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit c07f577 with merge base 5d4da26:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.


_device_type: str
_mesh: torch.Tensor
_global_rank_permutation: torch.Tensor
Contributor Author:

As I mentioned, I welcome suggestions for better names.

Contributor:

I personally don't like the idea of permutation. Shall we call it _global_ranks_alloc?

Contributor Author:

Why alloc (which I guess stands for allocation)? In what sense is this an allocation?

An alternative name could also be _rank_order, which has the benefit of being shorter.

Contributor:

Because it means rank allocation? The actual mesh tensor, e.g. [0, 3, 5, 10] (say, an arbitrary one), represents a device allocation: it refers to a device representation. That's why I think allocation might be better. I don't have a strong opinion on the word allocation; I just want to make sure the name clearly represents what it means.

Contributor Author:

I thought about what could be the shortest possible name (to keep things concise), and I like _rank_map, which I find also quite self-descriptive, since this is exactly what we're using this for: to map coordinates to actual ranks. It's also the term we're already using in remap_to_tensor.
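For illustration (hypothetical names and values, not the actual implementation), a `_rank_map` can be read exactly that way: turn a mesh coordinate into a flat index, then look up the global rank there, e.g. for the non-contiguous mesh `[0, 3, 5, 10]` mentioned earlier viewed as a 2x2 mesh:

```python
# Illustrative only: a "_rank_map" is a flat list of global ranks; a mesh
# coordinate becomes an index into it, which yields the actual rank.
rank_map = [0, 3, 5, 10]           # mesh built from ranks 0, 3, 5, 10
sizes, strides = (2, 2), (2, 1)    # viewed as a 2x2 mesh, row-major

def coord_to_rank(coord):
    # Coordinate -> flat index -> global rank.
    idx = sum(c * s for c, s in zip(coord, strides))
    return rank_map[idx]
```

So coordinate (1, 0) maps to flat index 2, hence global rank 5.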

Are people ok with this name?

Contributor:

_rank_map sounds good to me.

Comment on lines +203 to +207
self._global_rank_permutation = (
_root_mesh._global_rank_permutation
if _root_mesh is not None
else mesh_tensor.flatten()
)
Contributor Author:

This is a bit of a mess, but the next PR in the stack will clean this up.

@@ -258,7 +266,13 @@ def device_type(self) -> str:
@property
def mesh(self) -> torch.Tensor:
Contributor Author:

This function is admittedly quite complicated. I think it has been a good exercise to write it out, as it helps us understand what the contents of the mesh Tensor are. If we agree on this definition, we can work next on cleaning this up :)

Comment on lines +273 to +274
if full_mesh.size(0) == 1:
return full_mesh[0]
Contributor Author:

This case here is for DeviceMeshes that do not contain the rank on which they're being instantiated.

In general, I have a hard time understanding in which cases they would be useful, and I'd like to explore the possibility of forbidding such a scenario.

Concretely, this codepath is currently being triggered by get_all_submeshes, thus I'd like to look into whether we can rework the internal usages of that private method and hopefully remove it or improve it.

Contributor:

Agreed, we can also make get_all_submeshes not rely on this case.

Contributor:

One thing is required: could you kindly run the DTensor CPU overhead benchmark from #159169? That way we at least have an understanding of how much overhead this change introduces.

@lw lw added the topic: not user facing topic category label Oct 15, 2025
@fduwjj fduwjj left a comment

Nice, this makes sense to me. Thanks for continuing to help make DeviceMesh cleaner. Just one question about naming, plus the DTensor CPU overhead.

[ghstack-poisoned]

lw commented Oct 16, 2025

I ran the CPU benchmark. I looked at the iteration_4/rank0_trace.json files, and measured the e2e time of the first forward on CPU. Before this PR it took 1.526s, after this PR it took 1.521s. Hence no regression.
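For reference, a hedged sketch of how such an end-to-end time can be extracted from a Chrome-trace-format JSON file (the format the PyTorch profiler emits); the `traceEvents`, `ts`, and `dur` fields are standard in that format, but this helper itself is purely illustrative:

```python
import json

def e2e_seconds(trace: dict) -> float:
    """Wall-clock span covered by all timestamped events, in seconds."""
    events = [e for e in trace.get("traceEvents", []) if "ts" in e]
    start = min(e["ts"] for e in events)
    end = max(e["ts"] + e.get("dur", 0) for e in events)
    return (end - start) / 1e6  # ts/dur are in microseconds

# e.g.: e2e_seconds(json.load(open("iteration_4/rank0_trace.json")))
```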

[ghstack-poisoned]
@lw lw added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 16, 2025

lw commented Oct 16, 2025

@pytorchbot merge

@pytorchmergebot:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

pytorchmergebot pushed a commit that referenced this pull request Oct 16, 2025
…om_ranks (#165555)

The refactoring of DeviceMesh is heavily constrained by the signature of its constructor, which is a public API which contains some "legacy" concepts which we'd love to get rid of, such as an explicit/materialized `mesh` Tensor.

In other languages the solution to this would be to add a private overload of the constructor. Python doesn't natively allow this, but in this PR I managed to build something that approximates it.

This new private constructor basically only takes `_layout`, `_global_rank_permutation`, and `mesh_dim_names`.

With such a constructor we can effectively simplify a lot of callsites and get rid of the `_create_mesh_from_ranks` helper method. That's a good thing because it was instantiating many DeviceMeshes in a for loop, which always felt unnecessary.

Pull Request resolved: #165555
Approved by: https://github.com/fduwjj, https://github.com/fegin
ghstack dependencies: #165554
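The "private constructor overload" trick described above can be sketched as follows (a toy `Mesh` class with illustrative names, not the actual DeviceMesh code): a private classmethod allocates the instance with `object.__new__`, bypassing the public `__init__` and its legacy signature entirely.

```python
class Mesh:
    """Toy stand-in for a class whose public __init__ takes a legacy arg."""

    def __init__(self, mesh_tensor):
        # Public path: derive everything from an explicit, materialized list.
        self._layout = (len(mesh_tensor),)
        self._rank_map = list(mesh_tensor)
        self._dim_names = None

    @classmethod
    def _from_layout(cls, layout, rank_map, dim_names=None):
        # Private "overload": bypass __init__ entirely via object.__new__,
        # so callers can construct directly from layout + rank map.
        self = object.__new__(cls)
        self._layout = layout
        self._rank_map = rank_map
        self._dim_names = dim_names
        return self
```

Both paths produce a fully initialized object, but only the public one requires materializing the legacy mesh argument first.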
pytorchmergebot pushed a commit that referenced this pull request Oct 16, 2025
By adding a few small helpers (e.g., a `splice` method to `_MeshLayout`, and making `_init_process_groups` static and thus stateless) we can substantially shorten the definition of the unflatten method, and help readability.

Pull Request resolved: #165556
Approved by: https://github.com/fduwjj
ghstack dependencies: #165554, #165555
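What a `splice`-style helper on a layout's sizes could look like (a purely hypothetical signature, shown on plain tuples rather than the real `_MeshLayout`): replace the dims in `[start, stop)` with new sizes whose product matches, which is the core reshaping step an unflatten needs.

```python
import math

def splice(sizes: tuple, start: int, stop: int, new_sizes: tuple) -> tuple:
    """Replace sizes[start:stop] with new_sizes, preserving element count."""
    assert math.prod(sizes[start:stop]) == math.prod(new_sizes), \
        "spliced-in sizes must cover the same number of elements"
    return sizes[:start] + tuple(new_sizes) + sizes[stop:]

# Unflattening dim 1 of a (2, 8) layout into (2, 4):
# splice((2, 8), 1, 2, (2, 4)) -> (2, 2, 4)
```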

fduwjj commented Oct 16, 2025

Somehow this test failed: `python test/distributed/test_serialization.py TestSerialization.test_dtensor`

lw added 4 commits October 17, 2025 09:51
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@pytorchmergebot:

Starting merge as part of PR stack under #165556

pytorchmergebot pushed a commit that referenced this pull request Oct 17, 2025
Chao1Han pushed commits to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
zhudada0120 pushed commits to zhudada0120/pytorch that referenced this pull request Oct 22, 2025
@github-actions github-actions bot deleted the gh/lw/8/head branch November 17, 2025 02:18

Labels

ci-no-td (Do not run TD on this PR), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, oncall: distributed (Add this issue/PR to distributed oncall triage queue), Reverted, topic: not user facing (topic category)
