
[dtensor] fix flatten mesh dims arg relative to submesh #173790

Closed

IvanKobzarev wants to merge 4 commits into gh/IvanKobzarev/210/base from gh/IvanKobzarev/210/head

Conversation

@IvanKobzarev (Contributor) commented Jan 29, 2026

Stack from ghstack (oldest at bottom):

Treat the mesh_dims arg in _get_flattened_mesh_by_layout as relative to the submesh

Test:

```
python test/distributed/tensor/test_redistribute.py -k test_get_flattened_mesh_by_layout_with_submesh
```

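For context, a minimal sketch of the semantics in question (a hypothetical illustration assuming an 8-rank CUDA setup; `_get_flattened_mesh_by_layout` and `_flatten` are private DTensor/DeviceMesh helpers, so this only shows the submesh-relative dim naming via the DeviceMesh API):

```
from torch.distributed.device_mesh import init_device_mesh

# Assumed setup for illustration: 8 ranks on a CUDA backend.
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "dp", "ep"))
submesh = mesh["dp", "ep"]  # 2-D submesh of the 3-D root mesh

# "dp" is dim 1 of the root mesh but dim 0 of the submesh. After this fix,
# a mesh_dims argument passed alongside the submesh is interpreted against
# the submesh's own dims ("dp", "ep"), not the root's ("pp", "dp", "ep").
flat = submesh._flatten("dp_ep")  # flattening uses the submesh's dim names
```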
@pytorch-bot Bot commented Jan 29, 2026

This PR needs a `release notes:` label

If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`.

If not, please add the `topic: not user facing` label.

To add a label, you can comment to pytorchbot, for example:

@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@pytorch-bot Bot commented Jan 29, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/173790

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (5 Unrelated Failures)

As of commit fe9aa15 with merge base 19449aa:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@wconstab (Contributor) left a comment

LGTM - the names indeed should come from the submesh. Can you update the test case as @fegin mentioned, to:

```
root_mesh = init_device_mesh((8,), ("world",))
spmd_mesh = root_mesh.unflatten((2, 2, 2), ("pp", "dp", "ep"))["dp", "ep"]
```

@IvanKobzarev (Contributor, Author) commented:

@pytorchbot merge

@pytorch-bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Jan 29, 2026
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

IvanKobzarev added a commit that referenced this pull request Jan 29, 2026
@fegin (Contributor) left a comment

Thanks for the fix.

wconstab added a commit that referenced this pull request Jan 30, 2026
@wconstab (Contributor) commented:

I just did `ghstack checkout`, `spin fixlint`, commit, `ghstack`.

Will attempt to land ASAP.

@wconstab (Contributor) commented Jan 30, 2026

Landing internally via D91843715. Realized that the internal diff needs to be exported and run OSS CI too; trying to land OSS first instead.

@wconstab (Contributor) commented:

@pytorchbot merge -i

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged while ignoring the following 3 checks:
- pull / linux-docs / build-docs-python-false
- inductor / unit-test / inductor-test / test (inductor, 1, 2, linux.g5.4xlarge.nvidia.gpu)
- inductor / unit-test / inductor-halide-test / test (inductor-halide, 1, 1, linux.12xlarge)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

kapilsh pushed a commit to kapilsh/pytorch that referenced this pull request Feb 2, 2026
Treat the mesh_dims arg in _get_flattened_mesh_by_layout as relative to the submesh

Test:

```
python test/distributed/tensor/test_redistribute.py -k test_get_flattened_mesh_by_layout_with_submesh
```
Pull Request resolved: pytorch#173790
Approved by: https://github.com/wconstab, https://github.com/fegin, https://github.com/jathu

Co-authored-by: Will Constable <whc@meta.com>
@facebook-github-bot (Contributor) commented:

@pytorchbot revert -m="Diff reverted internally" -c="ghfirst"

This Pull Request has been reverted by a revert inside Meta. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert, and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).

@pytorchmergebot (Collaborator) commented:

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Feb 3, 2026
)"

This reverts commit c7d863a.

Reverted #173790 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](#173790 (comment)))
@pytorchmergebot (Collaborator) commented:

@IvanKobzarev your PR has been successfully reverted.

@pytorchmergebot added the Reverted and ci-no-td (Do not run TD on this PR) labels Feb 3, 2026
@izaitsevfb (Contributor) commented:

Sorry for the churn, please feel free to rebase and reland.

@wconstab (Contributor) commented Feb 4, 2026

I'll take care of this.

facebook-github-bot pushed a commit that referenced this pull request Feb 9, 2026
Summary:
Reland of #172610
- includes fixes #173873 (credit bdhirsh) and #173790 (credit IvanKobzarev)

Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms; for reduce comms in particular, this also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaur@redhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.)

Example: for a (2,2,2) mesh with dims (A,B,C), when redistributing from (Psum, Replicate, Psum) -> (Replicate, Replicate, Replicate), the original behavior was 2 separate all_reduces. After this PR, if the user flattens dims A,C, this becomes one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types ((Psum, Pmax) is not a valid placement), so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent TransformInfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes
- Won't optimize reduce_scatters when they shard an uneven sized tensor dim

Details/Limitations
- all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns
- reduce_scatter and all_gather are only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness.
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization
- DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh.  Refuses to merge any other combinations of mixed partials.

Fixes #171916

Note: the initial attempt used a stable sort with a `__lt__` method on TransformInfo comparing the comm type key, but this was not correct because sorting a local (no-comm) operation like chunking before or after a comm operation on the same mesh dim affects results.

Differential Revision: D92540256
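A minimal sketch of the merged-reduction behavior described above (illustrative only, assuming an 8-rank CUDA setup; the flattened-mesh name "AC" is hypothetical and `_flatten` is a private DeviceMesh API):

```
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate

# Assumed setup for illustration: 8 ranks on a CUDA backend.
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("A", "B", "C"))

# Flattened meshes are not auto-created (see notes above), so create one
# explicitly; DTensor can then find and use it during redistribute.
mesh["A", "C"]._flatten("AC")

x = DTensor.from_local(torch.randn(4, 4), mesh,
                       placements=[Partial(), Replicate(), Partial()])
# Without the "AC" mesh: two sequential all_reduces (over "A", then "C").
# With it: one all_reduce over the 4 ranks of the flattened "AC" mesh.
y = x.redistribute(mesh, placements=[Replicate(), Replicate(), Replicate()])
```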
@wconstab (Contributor) commented Feb 9, 2026

squashed into #174630

@wconstab closed this Feb 9, 2026
facebook-github-bot pushed a commit that referenced this pull request Feb 10, 2026
pytorchmergebot pushed a commit that referenced this pull request Feb 10, 2026
Reland of #172610: same code as previous land except:
- includes #173873 (credit @bdhirsh)
- includes #173790 (credit @IvanKobzarev)
- includes #173436
- adds disable contextmanager + test


Pull Request resolved: #174630
Approved by: https://github.com/zpcore
radeksm pushed a commit to radeksm/pytorch that referenced this pull request Feb 20, 2026
libohao1201 pushed a commit to libohao1201/pytorch that referenced this pull request Mar 2, 2026
@github-actions deleted the gh/IvanKobzarev/210/head branch March 12, 2026

Labels

- ci-no-td (Do not run TD on this PR)
- ciflow/inductor
- ciflow/trunk (Trigger trunk jobs on your pull request)
- Merged
- release notes: distributed (dtensor) (release notes category)
- Reverted
- topic: not user facing (topic category)


7 participants