
[DTensor] Optimize redistribute comms using flattened meshes #172610

Closed
wconstab wants to merge 26 commits into gh/wconstab/504/base from gh/wconstab/504/head

Conversation

@wconstab
Contributor

@wconstab wconstab commented Jan 16, 2026

Stack from ghstack (oldest at bottom):

Ensures that, when a flattened mesh exists, DTensor will find and use it to avoid more costly sequential comms. For reduce comms in particular, this also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaur@redhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.)

Example: For a (2,2,2) mesh with dims (A,B,C), when redistributing from placements (Psum, Replicate, Psum) -> (Replicate, Replicate, Replicate), the original behavior would be 2 separate all_reduces. After this PR, if the user flattens dims A,C, this becomes one larger all_reduce.
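A minimal sketch of this example, assuming 8 GPUs launched with torchrun; the mesh dim names, the "AC" flattened-mesh name, and the private `DeviceMesh._flatten()` call are illustrative assumptions rather than APIs introduced by this PR:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate

# (2,2,2) mesh over 8 ranks with dims A, B, C
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("A", "B", "C"))

# Explicitly create the flattened mesh over dims A and C up front so that
# redistribute can discover it (flattened meshes are not auto-created).
mesh["A", "C"]._flatten("AC")

# Local tensor holding this rank's partial contribution
local = torch.randn(4, 4, device="cuda")
dt = DTensor.from_local(local, mesh, [Partial(), Replicate(), Partial()])

# (Psum, Replicate, Psum) -> (Replicate, Replicate, Replicate): with the "AC"
# flattened mesh available, this can run as one all_reduce over the 4 ranks of
# A x C instead of two sequential all_reduces.
out = dt.redistribute(mesh, [Replicate(), Replicate(), Replicate()])
```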

Compared with the earlier attempt #172119, this PR:

  • includes optimizations for comms other than all_reduce
  • explicitly bans mixed partial types: (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
  • therefore uses a simpler implementation that groups adjacent TransformInfos and then merges like kinds
  • warns once per mesh shape for missing flattened meshes
  • won't optimize reduce_scatters when they shard an unevenly sized tensor dim

Details/Limitations

  • all_to_all is never merged (left for possible future work; it's not obvious how to do it in general)
  • reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh; otherwise, it warns
  • reduce_scatter and all_gather are only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness
  • groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
  • flattened device meshes are not created automatically (explicit creation is preferred and keeps torch.compile working), but warnings prompt the user to create them when doing so would enable an optimization
  • DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh (see the sketch below); refuses to merge any other combination of mixed partials
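To illustrate the last point, here is a small single-process numerical check of that scaling rule (a sketch of the math only, not the DTensor implementation; plain tensor reductions stand in for the collectives):

```python
import torch

A, C = 2, 2                      # mesh dim sizes: A carries Partial("sum"), C carries Partial("avg")
contribs = torch.randn(A, C, 4)  # per-rank partial values of a length-4 tensor

# Sequential comms: all_reduce(sum) over A, then all_reduce(avg) over C.
sequential = contribs.sum(dim=0).mean(dim=0)

# Merged comm: one all_reduce(sum) over the flattened A x C group,
# then scale by the product of the avg dim sizes (here just C).
merged = contribs.sum(dim=(0, 1)) / C

torch.testing.assert_close(sequential, merged)
```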

Fixes #171916

Note: an initial attempt used a stable sort with a __lt__ method in TransformInfo comparing the comm type key, but this was not correct because reordering a local (no-comm) operation like chunking before or after a comm operation on the same mesh dim affects results (the sketch below illustrates this).
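For example, a single-process sketch (plain tensor sums stand in for all_reduce across a 2-rank mesh dim) showing why that reordering changes the result:

```python
import torch

world = 2
ranks = [torch.arange(4.0) + 10 * r for r in range(world)]  # each rank's partial tensor

# Correct order: all_reduce(sum) first, then each rank keeps its chunk
# (Partial -> Replicate -> Shard).
summed = ranks[0] + ranks[1]
correct = [summed.chunk(world)[r] for r in range(world)]   # [10., 12.] and [14., 16.]

# Reordered schedule: chunk locally first, then "all_reduce" the kept chunks.
kept = [ranks[r].chunk(world)[r] for r in range(world)]
wrong = kept[0] + kept[1]                                   # [12., 14.] on every rank

print(correct)
print(wrong)
```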

@pytorch-bot

pytorch-bot Bot commented Jan 16, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172610

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 69c2096 with merge base 7754b55:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

wconstab added a commit that referenced this pull request Jan 16, 2026
wconstab added a commit that referenced this pull request Jan 16, 2026
wconstab added a commit that referenced this pull request Jan 20, 2026
wconstab added a commit that referenced this pull request Jan 20, 2026
wconstab added a commit that referenced this pull request Jan 20, 2026
wconstab added a commit that referenced this pull request Jan 20, 2026
wconstab added a commit that referenced this pull request Jan 20, 2026
wconstab added a commit that referenced this pull request Jan 20, 2026
wconstab added a commit that referenced this pull request Jan 21, 2026
wconstab added a commit that referenced this pull request Jan 21, 2026
pytorchmergebot pushed a commit that referenced this pull request Jan 28, 2026
bdhirsh added a commit that referenced this pull request Jan 30, 2026
…esh dims under compile"

Co-authored with claude. I noticed after #172610 that DTensor's new redistribute call that looks for flattened device meshes can crash under torch.compile/tracing. It looks like `submesh = mesh[dim_names]` tries to construct a fresh DeviceMesh and ends up calling `.item()` (full stacktrace of the error below).

I'm not 100% familiar with the `DeviceMesh` APIs, but claude seemed to find an alternative way to "look for an existing flattened device mesh" that didn't need to call `.item`

Stacktrace:
```
    output = redistribute_local_tensor(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/distributed/tensor/_redistribute.py", line 1452, in redistribute_local_tensor
    optimized_transform_infos = _optimize_transform_infos(
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/distributed/tensor/_redistribute.py", line 475, in _optimize_transform_infos
    flattened, failure_reason = try_create_flattened(group)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/distributed/tensor/_redistribute.py", line 381, in try_create_flattened
    flattened_mesh = _get_flattened_mesh_by_layout(device_mesh, sorted_mesh_dims)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/distributed/tensor/_redistribute.py", line 189, in _get_flattened_mesh_by_layout
    submesh = mesh[dim_names]
              ~~~~^^^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 669, in __getitem__
    submesh = self._create_sub_mesh(sliced_mesh_layout, mesh_dim_names)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 758, in _create_sub_mesh
    res_submesh = DeviceMesh(
                  ^^^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 258, in __init__
    if self._layout.numel() != self.mesh.numel():
                               ^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 360, in mesh
    return self._get_mesh_tensor_from_full_mesh(full_mesh)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/distributed/device_mesh.py", line 349, in _get_mesh_tensor_from_full_mesh
    return full_mesh[my_coords[0, 0]]
           ~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/fx/experimental/proxy_tensor.py", line 1625, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/_compile.py", line 54, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/_dynamo/eval_frame.py", line 1227, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/_subclasses/functional_tensor.py", line 625, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
                     ^^^^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/_compile.py", line 54, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/_dynamo/eval_frame.py", line 1227, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/utils/_stats.py", line 29, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/fx/experimental/proxy_tensor.py", line 1756, in __torch_dispatch__
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/hirsheybar/new2/pytorch/torch/fx/experimental/proxy_tensor.py", line 1139, in proxy_call
    raise RuntimeError(
torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised:
RuntimeError: It appears that you're trying to get value out of a tracing tensor with aten._local_scalar_dense.default - erroring out! It's likely that this is caused by data-dependent control flow or similar.  It may be possible to trace this with dynamic shapes; try setting tracing_mode='symbolic' in your make_fx call.
```
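As a standalone illustration of the same failure mode (a sketch only: `make_fx` over fake tensors stands in for the aot_eager tracing path, and this is not the DeviceMesh repro itself), reading a Python scalar out of a traced tensor lowers to `aten._local_scalar_dense` and errors during tracing:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    n = int(x.sum())  # .item()-style read of a traced value -> aten._local_scalar_dense
    return x + n

try:
    make_fx(f, tracing_mode="fake")(torch.randn(4))
except Exception as e:
    # Expected to complain about getting a value out of a tracing tensor,
    # similar to the stacktrace above.
    print(type(e).__name__, e)
```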




[ghstack-poisoned]
pytorch-bot Bot pushed a commit that referenced this pull request Jan 30, 2026
…der compile

Summary:
internal-first land of #173873

Test Plan: python test/distributed/tensor/test_dtensor_compile.py -k test_compile_redistribute_flattened_mesh

Differential Revision: D91852906
pytorchmergebot pushed a commit that referenced this pull request Jan 30, 2026
…der compile (#173873)

Pull Request resolved: #173873
Approved by: https://github.com/wconstab, https://github.com/fegin
kapilsh pushed a commit to kapilsh/pytorch that referenced this pull request Feb 2, 2026
…#172610)

Pull Request resolved: pytorch#172610
Approved by: https://github.com/tianyu-l, https://github.com/zpcore
ghstack dependencies: pytorch#173593
kapilsh pushed a commit to kapilsh/pytorch that referenced this pull request Feb 2, 2026
kapilsh pushed a commit to kapilsh/pytorch that referenced this pull request Feb 2, 2026
@facebook-github-bot
Contributor

@pytorchbot revert -m="Diff reverted internally" -c="ghfirst"

This Pull Request has been reverted by a revert inside Meta. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

Reverting PR 172610 failed

Reason: Command git -C /home/runner/work/pytorch/pytorch revert --no-edit 3c0bd0d117043fb6855494504934009887f6fe65 returned non-zero exit code 1

```
Auto-merging test/distributed/tensor/test_redistribute.py
CONFLICT (content): Merge conflict in test/distributed/tensor/test_redistribute.py
Auto-merging torch/distributed/tensor/_redistribute.py
error: could not revert 3c0bd0d1170... [DTensor] Optimize redistribute comms using flattened meshes (#172610)
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git revert --continue".
hint: You can instead skip this commit with "git revert --skip".
hint: To abort and get back to the state before "git revert",
hint: run "git revert --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
```

@izaitsevfb
Contributor

@pytorchbot revert -m="Diff reverted internally" -c="ghfirst"

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Feb 3, 2026
…172610)"

This reverts commit 3c0bd0d.

Reverted #172610 on behalf of https://github.com/izaitsevfb due to Diff reverted internally ([comment](#172610 (comment)))
@pytorchmergebot
Collaborator

@wconstab your PR has been successfully reverted.

@pytorchmergebot added the Reverted and ci-no-td (Do not run TD on this PR) labels Feb 3, 2026
facebook-github-bot pushed a commit that referenced this pull request Feb 9, 2026
Summary:
Reland of #172610
- includes fixes #173873 (credit bdhirsh) and #173790 (credit IvanKobzarev)

Differential Revision: D92540256
@wconstab
Contributor Author

wconstab commented Feb 9, 2026

Relanding via #174630

@wconstab closed this Feb 9, 2026
facebook-github-bot pushed a commit that referenced this pull request Feb 10, 2026
Summary:

Reland of #172610
- includes fixes #173873 (credit bdhirsh) and #173790 (credit IvanKobzarev)

Differential Revision: D92540256
pytorchmergebot pushed a commit that referenced this pull request Feb 10, 2026
Reland of #172610: same code as previous land except:
- includes #173873 (credit @bdhirsh)
- includes #173790 (credit @IvanKobzarev)
- includes #173436
- adds disable contextmanager + test

Ensures that when a flattened mesh exists, DTensor will find and use it to avoid more costly sequential comms; for reduce comms in particular, this also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaur@redhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.)

Example: for a (2,2,2) mesh with dims (A,B,C), when redistributing from placements (Psum, Replicate, Psum) -> (Replicate, Replicate, Replicate), the original behavior would be 2 separate all_reduces. After this PR, if the user flattens dims A and C, this becomes one larger all_reduce.
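As a rough illustration of the user-facing workflow (a minimal sketch, not code from this PR; the mesh dim names and the private slicing/`_flatten` calls are assumptions), pre-creating the flattened mesh could look like:

```python
# Illustrative sketch only. Run under torchrun with 8 ranks; the exact
# slicing/_flatten API and the "A"/"B"/"C" dim names are assumptions here.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate

mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("A", "B", "C"))

# Explicitly create the flattened mesh over dims A and C (ascending order),
# so the redistribute below can use one 4-rank all_reduce instead of two.
mesh["A", "C"]._flatten("AC")

x = torch.randn(8, 8, device="cuda")
dt = DTensor.from_local(x, mesh, [Partial(), Replicate(), Partial()])  # Psum on A and C
out = dt.redistribute(mesh, [Replicate(), Replicate(), Replicate()])
```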

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types: (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation that groups adjacent TransformInfos and then merges like kinds (see the sketch after this list)
- Warns once per mesh shape for missing flattened meshes
- Won't optimize reduce_scatters when they shard an unevenly sized tensor dim
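A minimal, hypothetical sketch of the "group adjacent, then merge like kinds" idea (the Step dataclass below is a stand-in for TransformInfo, not the PR's actual code):

```python
# Hypothetical sketch: group only *adjacent* transform steps that share the
# same collective kind, preserving order. A local (no-comm) step sitting
# between two all_reduces keeps them in separate groups, which is why a plain
# sort-by-kind approach was rejected.
from dataclasses import dataclass
from itertools import groupby


@dataclass
class Step:              # stand-in for the real TransformInfo
    comm_kind: str       # e.g. "all_reduce", "all_gather", "local"
    mesh_dim: int


def group_adjacent(steps: list[Step]) -> list[list[Step]]:
    return [list(run) for _, run in groupby(steps, key=lambda s: s.comm_kind)]


steps = [Step("all_reduce", 0), Step("all_reduce", 2), Step("all_gather", 1)]
# Two groups: the adjacent all_reduces (mergeable onto a flattened mesh if one
# exists), then the all_gather on its own.
print(group_adjacent(steps))
```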

Details/Limitations
- all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh size; otherwise, it warns
- reduce_scatter and all_gather are only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness.
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device-meshes are not automatically created, both to keep mesh creation explicit and to keep torch.compile working; instead, warnings prompt the user to create them when doing so would enable an optimization
- DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh (see the worked example after this list). Refuses to merge any other combination of mixed partials.
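For instance (illustrative arithmetic, not taken from the PR text): for Partial(sum) over dim A (size 2) and Partial(avg) over dim C (size 2), the merged path performs a single sum all_reduce over the flattened 4-rank A+C mesh and then divides by 2, the product of the avg dim sizes, which matches doing the sum reduction over A followed by the avg reduction over C.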

Fixes #171916

Note: an initial attempt used a stable sort with a __lt__ method on TransformInfo that compared the comm type key, but this was not correct because sorting a local (no-comm) operation such as chunking before or after a comm operation on the same mesh dim changes the results.

Differential Revision: D92540256

Pull Request resolved: #174630
Approved by: https://github.com/zpcore
radeksm pushed a commit to radeksm/pytorch that referenced this pull request Feb 20, 2026
libohao1201 pushed a commit to libohao1201/pytorch that referenced this pull request Mar 2, 2026
@github-actions bot deleted the gh/wconstab/504/head branch March 12, 2026 02:22