
[DTensor] Optimize redistribute by merging allreduces on flattened meshes #172119

Closed
wconstab wants to merge 7 commits into gh/wconstab/489/base from gh/wconstab/489/head

Conversation

@wconstab
Contributor

@wconstab wconstab commented Jan 9, 2026

Stack from ghstack (oldest at bottom):

Fixes #171916

When redistributing a DTensor with multiple Partial placements of the same
reduce_op type (e.g., Partial("sum") on dims A and C), this change detects
if a flattened DeviceMesh exists that covers those dimensions and uses a
single allreduce instead of multiple separate ones.

Key changes:

  • Add _FlattenedTransformInfo as duck-type compatible with _TransformInfo,
    allowing the redistribute loop to handle both uniformly via mesh override
  • Add _get_flattened_mesh_by_layout() to query for existing flattened meshes
    using layout comparison rather than name-based lookup
  • Add _optimize_transform_infos_for_flattened_reductions() to group same-type
    reductions and replace with flattened transforms
  • Modify redistribute_local_tensor() to use mesh_to_use pattern so flattened
    transforms flow through the same code path as regular transforms

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
(Partial("sum"), Partial("max"), Partial("sum")), if mesh["A","C"]._flatten()
was called, the A and C sums are merged into 1 allreduce (2 comms instead of 3).
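
A minimal sketch of that usage (illustrative only; assumes an 8-rank job, e.g. launched via torchrun, and uses CommDebugMode from the test plan to count collectives):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial
from torch.distributed.tensor.debug import CommDebugMode

# 8 ranks arranged as a (2, 2, 2) mesh with dims (A, B, C).
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("A", "B", "C"))

# Pre-create the flattened mesh covering dims A and C so the
# redistribute optimization can find it.
mesh["A", "C"]._flatten()

dt = DTensor.from_local(
    torch.randn(4, 4, device="cuda"),
    mesh,
    (Partial("sum"), Partial("max"), Partial("sum")),
)

with CommDebugMode() as comm_mode:
    dt.full_tensor()  # redistributes every Partial to Replicate

# Expect 2 collectives: one merged allreduce(sum) over the flattened
# A+C group plus one allreduce(max) over B, instead of 3.
print(comm_mode.get_total_counts())
```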

Test Plan:

  • Unit tests for _optimize_transform_infos_for_flattened_reductions using
    fake process group (fast, no NCCL init overhead)
  • Integration tests verifying comm counts with CommDebugMode

-- Claude

@pytorch-bot

pytorch-bot Bot commented Jan 9, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172119

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 4d73f3a with merge base d2237eb:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@wconstab added the release notes: distributed (dtensor) label on Jan 9, 2026
Comment thread torch/distributed/tensor/_redistribute.py Outdated
can_merge = True
for dim in range(last_mesh_dim + 1, next_info.mesh_dim):
    if src_placements[dim].is_partial():
        # There's a Partial on a skipped dim - can't merge
Collaborator


why would a partial on a skipped dim prevent merge? PPP -> RPR can still be merged, no?

Contributor


this is true if the partial being skipped is of the same type + same_reduce_op as the source

Contributor Author

@wconstab wconstab Jan 10, 2026


ah, good point. i think i can adjust this to allow skipping a same-type partial and only bail out if it's a different-type partial.

Updated to include this case.

Comment on lines +183 to +190
IMPORTANT: Reductions can only be merged if there are no Partial placements
on the mesh dimensions between them. A Partial placement (even if unchanged)
implies a semantic ordering that must be preserved. For example:
(Psum, Pmax, Psum) -> the correct order is sum_A, then max_B, then sum_C
If we merge the sums: sum_{A,C}(x) then max_B gives a different result.

However, non-Partial placements (Replicate, Shard) in between are safe to
skip because they don't impose ordering constraints.
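
A tiny standalone illustration of that ordering argument (pure Python, not from the PR; the numbers are chosen so the max picks a different B index per C column):

```python
# x[a][b][c]: the Partial value held at mesh coordinate (a, b, c).
x = [[[10, 0], [0, 10]],
     [[ 0, 0], [0,  0]]]

# Correct order: sum over A, then max over B, then sum over C.
sum_a = [[x[0][b][c] + x[1][b][c] for c in range(2)] for b in range(2)]
max_b = [max(sum_a[0][c], sum_a[1][c]) for c in range(2)]
correct = sum(max_b)  # -> 20

# Merged order: sum over A and C together, then max over B.
sum_ac = [sum(x[a][b][c] for a in range(2) for c in range(2)) for b in range(2)]
merged = max(sum_ac)  # -> 10

assert correct != merged  # merging the sums across the max changes the result
```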
Member

@zpcore zpcore Jan 11, 2026


Does this mean that if there are [Psum->R, S(0)->R, Psum->R] in transform_infos, we can reorder them as [Psum->R, Psum->R, S(0)->R]? If so, that means we can do a stable sort on transform_infos and place all Partial operations at the front or back.

Contributor Author


Yes, if you think it's cleaner to write it that way, I can change it to do an explicit sorting pass to group like-kind reductions, then a second pass to merge adjacent like-kind reductions.

Comment thread torch/distributed/tensor/_redistribute.py Outdated
    if flattened_mesh._layout == expected_layout:
        return flattened_mesh

return None
Collaborator


we should at the very least warn here and instruct users how to create a flattened mesh. Ideally we'd error out (because not having a flattened mesh is a significant perf hit).

Contributor Author


ok- i think i'll add a warn_once here. do you prefer warning repeatedly?

Collaborator


Tbh I prefer noisy warning here, given that it's an easy remedy

Contributor Author


ok, i put in a warning
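
For reference, the remedy the warning points at is roughly this (a sketch; the dim names and mesh shape are illustrative):

```python
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("A", "B", "C"))
# Pre-creating the flattened mesh lets redistribute merge the
# reductions over dims A and C into a single collective.
mesh["A", "C"]._flatten()
```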

        others.append((j, next_info))
        j += 1
    elif next_src == src and not has_blocking_partial(
        src, reductions[-1][1].mesh_dim, next_info.mesh_dim
Member


This is kind of weird: has_blocking_partial checks all mesh dims between A.mesh_dim and B.mesh_dim, which is not related to the ordering in transform_infos.

Contributor Author


um, maybe i misunderstood your point. but i think this is actually crucial. See my other PR to see if it convinces you:
#172277

Member


I see, the assumption is true if we follow the default left-to-right order. Do we want to consider a non-default order or graph-based redistribution here?

Contributor Author


what do you mean about non-default order? we have only defined a concept of non-default shard ordering, but that ordering specifically applies to the shard placements it names. I was concluding from that that Partials are always in fixed L to R order, which I further documented in the other PR. If you disagree, lmk and we can clarify what the behavior should be for partials.

Member


Curious about the case where the transform_infos are as below:

mesh dim 0, P(sum)->R
mesh dim 2, P(sum)->R
mesh dim 1, P(avg)->R

Based on the code, we will not merge the first two P(sum)->R, because of mesh dim 1's P(avg)->R. Is this true?

Contributor Author


I got it now. Yea I think that could be a problem- I'll think about it and add a test case for it

Contributor Author


ok- discussed offline: this sequence of transform_infos should be banned, as it is not valid to redistribute partials in arbitrary order. we should define a partial order that we always use, and then this code is OK.

dt = DTensor.from_local(
    local_tensor,
    mesh,
    (Partial("sum"), Partial("max"), Partial("sum")),
Member


Feels like there may be an issue with has_blocking_partial. We may need a device mesh with more dimensions to test. I think we can use LocalTensor; though there is no concept of an NCCL domain there, we can still test the tensor output in case the partial sequence gets messed up.

With LocalTensor, we could probably mock redistribute_local_tensor to take a long sequence of made-up transform_infos and compare the redistributed tensor against the one from merged-partial transform_infos.

Member


No need to worry. I think it's correct for this PR with the default non-graph based redistribution as in #172119 (comment). I can follow up and verify if we want to extend the support for arbitrary transform_infos sequences.

Contributor Author


I didn't understand these comments. maybe we can discuss offline to clarify.

Contributor

@tianyu-l tianyu-l left a comment


I wonder if this fix addresses my concern in #171913 (comment)

In fact, I'm not sure if this is an issue to be solved at DTensor level at all. Do you think we can first detect and error out, before merging a solution which may not be sufficient?

I think it's more urgent to solve the problem for application API, e.g. replicate (DDP) + Sequence Parallel. cc @weifengpy

@wconstab
Contributor Author

wconstab commented Jan 14, 2026

@tianyu-l i made a doc to discuss the tradeoffs: https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?usp=sharing

tldr, i think this PR is at best a stopgap solution that does improve performance; from that angle i'm ok with landing it, but i don't love it. I advocate for a bigger change to both (1) devicemesh/processgroup, and (2) the redistribution planner in DTensor

wconstab added a commit that referenced this pull request Jan 24, 2026
Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes
- Won't optimize reduce_scatters when they shard an uneven sized tensor dim

Limitations
- all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jan 24, 2026
…g flattened meshes"


Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes
- Won't optimize reduce_scatters when they shard an uneven sized tensor dim

Limitations
- all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jan 24, 2026
Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes
- Won't optimize reduce_scatters when they shard an uneven sized tensor dim

Limitations
- all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jan 24, 2026
Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

ghstack-source-id: 48bb7c0
Pull Request resolved: #172610
wconstab added a commit that referenced this pull request Jan 24, 2026
…g flattened meshes"


Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes
- Won't optimize reduce_scatters when they shard an uneven sized tensor dim

Limitations
- all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jan 24, 2026
Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes
- Won't optimize reduce_scatters when they shard an uneven sized tensor dim

Limitations
- all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jan 24, 2026
Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

ghstack-source-id: 2b47663
Pull Request resolved: #172610
wconstab added a commit that referenced this pull request Jan 24, 2026
…g flattened meshes"


Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes
- Won't optimize reduce_scatters when they shard an uneven sized tensor dim

Details/Limitations
- all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns
- reduce_scatter is only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness for merged reduce_scatter
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization
- DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh.  Refuses to merge any other combinations of mixed partials.

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jan 24, 2026
Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes
- Won't optimize reduce_scatters when they shard an uneven sized tensor dim

Details/Limitations
- all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns
- reduce_scatter is only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness for merged reduce_scatter
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization
- DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh.  Refuses to merge any other combinations of mixed partials.

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jan 24, 2026
Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

ghstack-source-id: d44f8bb
Pull Request resolved: #172610
wconstab added a commit that referenced this pull request Jan 26, 2026
…g flattened meshes"


Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes
- Won't optimize reduce_scatters when they shard an uneven sized tensor dim

Details/Limitations
- all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns
- reduce_scatter is only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness for merged reduce_scatter
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization
- DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh.  Refuses to merge any other combinations of mixed partials.

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jan 26, 2026
Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes
- Won't optimize reduce_scatters when they shard an uneven sized tensor dim

Details/Limitations
- all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns
- reduce_scatter is only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness for merged reduce_scatter
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization
- DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh.  Refuses to merge any other combinations of mixed partials.

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jan 26, 2026
Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

ghstack-source-id: 8acd9a8
Pull Request resolved: #172610
wconstab added a commit that referenced this pull request Jan 26, 2026
…g flattened meshes"


Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes
- Won't optimize reduce_scatters when they shard an uneven sized tensor dim

Details/Limitations
- all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns
- reduce_scatter and all_gather are only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness.
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization
- DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh.  Refuses to merge any other combinations of mixed partials.

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jan 26, 2026
Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes
- Won't optimize reduce_scatters when they shard an uneven sized tensor dim

Details/Limitations
- all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns
- reduce_scatter and all_gather are only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness.
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization
- DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh.  Refuses to merge any other combinations of mixed partials.

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jan 27, 2026
…g flattened meshes"


Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes
- Won't optimize reduce_scatters when they shard an uneven sized tensor dim

Details/Limitations
- all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns
- reduce_scatter and all_gather are only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness.
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization
- DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh.  Refuses to merge any other combinations of mixed partials.

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jan 27, 2026
Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes
- Won't optimize reduce_scatters when they shard an uneven sized tensor dim

Details/Limitations
- all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns
- reduce_scatter and all_gather are only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness.
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization
- DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh.  Refuses to merge any other combinations of mixed partials.

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Jan 27, 2026
…g flattened meshes"


Ensures that, when a suitable flattened mesh exists, DTensor finds and uses it to avoid costlier sequential comms; for reduce comms in particular, this also avoids the risk of different reduction orders producing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.)

Example: for a (2,2,2) mesh with dims (A,B,C), redistributing from
placements (Psum, Replicate, Psum) -> (Replicate, Replicate, Replicate)
originally issued 2 separate all_reduces. After this PR, if the user
flattens dims A and C, this becomes one larger all_reduce.
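
A minimal sketch of that scenario, assuming 8 ranks (e.g. launched via torchrun). The flattened mesh is created explicitly with the private `_flatten()` API, and `CommDebugMode` is used only to count the resulting collectives; the mesh dim names and tensor shape are illustrative:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate
from torch.distributed.tensor.debug import CommDebugMode

# (2, 2, 2) mesh over 8 ranks with named dims A, B, C
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("A", "B", "C"))

# Explicitly create the flattened mesh covering dims A and C so that
# redistribute can find and use it (private API, per this PR).
mesh["A", "C"]._flatten("AC")

local = torch.randn(4, 4, device="cuda")
dt = DTensor.from_local(
    local, mesh, placements=(Partial("sum"), Replicate(), Partial("sum"))
)

comm_mode = CommDebugMode()
with comm_mode:
    # The two partial-sum mesh dims are reduced together on the flattened
    # AC mesh: one all_reduce instead of two.
    out = dt.redistribute(placements=(Replicate(), Replicate(), Replicate()))
print(comm_mode.get_total_counts())  # expected: 1 collective, not 2
```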

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types: (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation that groups adjacent TransformInfos and then merges like kinds (see the grouping sketch after this list)
- warns once per mesh shape when a flattened mesh is missing
- won't optimize reduce_scatters when they shard an unevenly sized tensor dim
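
An illustrative sketch of that grouping step, using a simplified stand-in for the real transform records (the `FakeTransform` fields are assumptions for illustration, not the actual DTensor internals):

```python
from dataclasses import dataclass
from itertools import groupby


@dataclass
class FakeTransform:
    comm_kind: str  # e.g. "all_reduce", "all_gather", or "chunk" (local op)
    mesh_dim: int


def merge_adjacent(transforms: list[FakeTransform]) -> list[list[FakeTransform]]:
    # groupby only groups *consecutive* equal keys, so non-adjacent
    # like-kind collectives stay in separate groups, as this PR requires.
    return [list(g) for _, g in groupby(transforms, key=lambda t: t.comm_kind)]


infos = [
    FakeTransform("all_reduce", 0),
    FakeTransform("all_reduce", 2),  # adjacent like kind -> merge candidate
    FakeTransform("chunk", 1),       # local op acts as a barrier
    FakeTransform("all_reduce", 1),  # not adjacent -> left unmerged
]
for group in merge_adjacent(infos):
    print([(t.comm_kind, t.mesh_dim) for t in group])
```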

Details/Limitations
- all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh size; otherwise, it warns
- reduce_scatter and all_gather are only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device meshes are not created automatically, both to keep creation explicit and to ensure torch.compile works; warnings prompt the user to create them when doing so would enable an optimization
- DOES support merging mixed Partial (sum, avg) reductions: a sum reduction is performed on the merged mesh, then the result is divided by the product of the avg dim sizes (see the sketch after this list). Any other combination of mixed partials is refused.
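
A quick numeric check of that sum/avg scaling rule, using plain tensors rather than DTensor (the mesh group is modeled as the two leading dims of a tensor; shapes are illustrative):

```python
import torch

# Per-rank partial values across a 2-wide sum dim (A) and 3-wide avg dim (B).
parts = torch.randn(2, 3, 4)

# Sequential semantics: average over the avg dim, then sum over the sum dim.
sequential = parts.mean(dim=1).sum(dim=0)

# Merged semantics: a single sum over the flattened (A, B) group, then
# divide by the product of the avg dim sizes (here just 3).
merged = parts.sum(dim=(0, 1)) / 3

torch.testing.assert_close(sequential, merged)
```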

Fixes #171916

Note: an initial attempt used a stable sort with a __lt__
method on TransformInfo comparing the comm-type key, but this was not correct: reordering a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh dim affects results.

[ghstack-poisoned]