[DTensor] Optimize redistribute by merging allreduces on flattened meshes #172119
wconstab wants to merge 7 commits into gh/wconstab/489/base
Conversation
Fixes #171916

When redistributing a DTensor with multiple Partial placements of the same
reduce_op type (e.g., Partial("sum") on dims A and C), this change detects
whether a flattened DeviceMesh exists that covers those dimensions and, if so,
uses a single allreduce instead of multiple separate ones.
Key changes:
- Add _FlattenedTransformInfo as duck-type compatible with _TransformInfo,
allowing the redistribute loop to handle both uniformly via mesh override
- Add _get_flattened_mesh_by_layout() to query for existing flattened meshes
using layout comparison rather than name-based lookup
- Add _optimize_transform_infos_for_flattened_reductions() to group same-type
reductions (even non-consecutive) and replace with flattened transforms
- Modify redistribute_local_tensor() to use mesh_to_use pattern so flattened
transforms flow through the same code path as regular transforms
Example: For a (2,2,2) mesh with dims (A,B,C) and placements
(Partial("sum"), Partial("max"), Partial("sum")), if mesh["A","C"]._flatten()
was called, the A and C sums are merged into 1 allreduce (2 comms instead of 3).
Test Plan:
- Unit tests for _optimize_transform_infos_for_flattened_reductions using
fake process group (fast, no NCCL init overhead)
- Integration tests verifying comm counts with CommDebugMode
-- Claude
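A minimal usage sketch of the scenario above, assuming 8 ranks and an already-initialized process group. The APIs (`init_device_mesh`, `_flatten`, `CommDebugMode`) are existing DTensor/DeviceMesh ones; the final comm count assumes this PR's merging:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate
from torch.distributed.tensor.debug import CommDebugMode

mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("A", "B", "C"))
mesh["A", "C"]._flatten()  # registering the flattened mesh is what enables the merge

local = torch.randn(4, 4, device="cuda")
dt = DTensor.from_local(
    local, mesh, (Partial("sum"), Partial("max"), Partial("sum"))
)

with CommDebugMode() as comm_mode:
    out = dt.redistribute(mesh, (Replicate(), Replicate(), Replicate()))

# Without the optimization: 3 all_reduces, one per Partial dim. With it, the
# two sums share one all_reduce over the flattened ("A", "C") group -> 2 total.
funcol = torch.ops.c10d_functional
print(comm_mode.get_comm_counts()[funcol.all_reduce])
```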
```python
can_merge = True
for dim in range(last_mesh_dim + 1, next_info.mesh_dim):
    if src_placements[dim].is_partial():
        # There's a Partial on a skipped dim - can't merge
```
why would a partial on a skipped dim prevent merge? PPP -> RPR can still be merged, no?
this is true if the partial being skipped is of the same type (same reduce_op) as the source
ah, good point. i think i can adjust this to allow skipping a same-type partial and only bail out if it's a different-type partial.
Updated to include this case.
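A sketch of the adjusted rule. The later review snippet shows the real helper as `has_blocking_partial(src, start_dim, end_dim)`; passing the reduce_op explicitly here is an assumption for illustration:

```python
from torch.distributed.tensor import Partial

def has_blocking_partial(src_placements, start_mesh_dim, end_mesh_dim, reduce_op):
    # A Partial on a skipped dim only blocks merging when its reduce_op
    # differs from the reductions being merged; Replicate, Shard, and
    # same-op Partials impose no ordering constraint.
    for dim in range(start_mesh_dim + 1, end_mesh_dim):
        p = src_placements[dim]
        if p.is_partial() and p.reduce_op != reduce_op:
            return True
    return False

# (Psum, Psum, Psum): merging the sums on dims 0 and 2 is fine.
print(has_blocking_partial((Partial("sum"),) * 3, 0, 2, "sum"))  # False
# (Psum, Pmax, Psum): the max in between blocks the merge.
print(has_blocking_partial(
    (Partial("sum"), Partial("max"), Partial("sum")), 0, 2, "sum"
))  # True
```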
…lattened meshes" Fixes #171916 When redistributing a DTensor with multiple Partial placements of the same reduce_op type (e.g., Partial("sum") on dims A and C), this change detects if a flattened DeviceMesh exists that covers those dimensions and uses a single allreduce instead of multiple separate ones. Key changes: - Add _FlattenedTransformInfo as duck-type compatible with _TransformInfo, allowing the redistribute loop to handle both uniformly via mesh override - Add _get_flattened_mesh_by_layout() to query for existing flattened meshes using layout comparison rather than name-based lookup - Add _optimize_transform_infos_for_flattened_reductions() to group same-type reductions and replace with flattened transforms - Modify redistribute_local_tensor() to use mesh_to_use pattern so flattened transforms flow through the same code path as regular transforms Example: For a (2,2,2) mesh with dims (A,B,C) and placements (Partial("sum"), Partial("max"), Partial("sum")), if mesh["A","C"]._flatten() was called, the A and C sums are merged into 1 allreduce (2 comms instead of 3). Test Plan: - Unit tests for _optimize_transform_infos_for_flattened_reductions using fake process group (fast, no NCCL init overhead) - Integration tests verifying comm counts with CommDebugMode -- Claude [ghstack-poisoned]
…shes
When redistributing a DTensor with multiple Partial placements of the same
reduce_op type (e.g., Partial("sum") on dims A and C), this change detects
if a flattened DeviceMesh exists that covers those dimensions and uses a
single allreduce instead of multiple separate ones.
Key changes:
- Add _FlattenedTransformInfo as duck-type compatible with _TransformInfo,
allowing the redistribute loop to handle both uniformly via mesh override
- Add _get_flattened_mesh_by_layout() to query for existing flattened meshes
using layout comparison rather than name-based lookup
- Add _optimize_transform_infos_for_flattened_reductions() to group same-type
reductions (even non-consecutive) and replace with flattened transforms
- Modify redistribute_local_tensor() to use mesh_to_use pattern so flattened
transforms flow through the same code path as regular transforms
Example: For a (2,2,2) mesh with dims (A,B,C) and placements
(Partial("sum"), Partial("max"), Partial("sum")), if mesh["A","C"]._flatten()
was called, the A and C sums are merged into 1 allreduce (2 comms instead of 3).
Test Plan:
- Unit tests for _optimize_transform_infos_for_flattened_reductions using
fake process group (fast, no NCCL init overhead)
- Integration tests verifying comm counts with CommDebugMode
-- Claude
ghstack-source-id: 2a07b00
Pull Request resolved: #172119
```
IMPORTANT: Reductions can only be merged if there are no Partial placements
on the mesh dimensions between them. A Partial placement (even if unchanged)
implies a semantic ordering that must be preserved. For example:
(Psum, Pmax, Psum) -> the correct order is sum_A, then max_B, then sum_C
If we merge the sums: sum_{A,C}(x) then max_B gives a different result.

However, non-Partial placements (Replicate, Shard) in between are safe to
skip because they don't impose ordering constraints.
```
Does this mean that if there are [Psum->R, S(0)->R, Psum->R] in transform_infos, we can reorder them as [Psum->R, Psum->R, S(0)->R]? If so, then we could do a stable sort on transform_infos and place all Partial operations at the front or back.
Yes - if you think it's cleaner to write it that way, I can change it to do an explicit sorting pass to group like-kind reductions, then a second pass to merge adjacent like-kind reductions.
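To make the ordering constraint concrete, a small single-process check with plain torch, where tensor dims stand in for mesh dims (A, B, C); this is an illustrative example, not taken from the PR's tests:

```python
import torch

# x[a, b, c]: per-rank partial values on a hypothetical (2, 2, 2) mesh.
x = torch.tensor([[[5., 0.], [0., 5.]],
                  [[0., 0.], [0., 0.]]])

# Correct order for (Psum, Pmax, Psum): sum over A, then max over B, then sum over C.
correct = x.sum(dim=0).max(dim=0).values.sum()  # -> 10.0
# Merging the two sums first gives a different answer, since max is not linear.
merged = x.sum(dim=(0, 2)).max()                # -> 5.0
assert correct != merged
```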
…lattened meshes" Fixes #171916 When redistributing a DTensor with multiple Partial placements of the same reduce_op type (e.g., Partial("sum") on dims A and C), this change detects if a flattened DeviceMesh exists that covers those dimensions and uses a single allreduce instead of multiple separate ones. Key changes: - Add _FlattenedTransformInfo as duck-type compatible with _TransformInfo, allowing the redistribute loop to handle both uniformly via mesh override - Add _get_flattened_mesh_by_layout() to query for existing flattened meshes using layout comparison rather than name-based lookup - Add _optimize_transform_infos_for_flattened_reductions() to group same-type reductions and replace with flattened transforms - Modify redistribute_local_tensor() to use mesh_to_use pattern so flattened transforms flow through the same code path as regular transforms Example: For a (2,2,2) mesh with dims (A,B,C) and placements (Partial("sum"), Partial("max"), Partial("sum")), if mesh["A","C"]._flatten() was called, the A and C sums are merged into 1 allreduce (2 comms instead of 3). Test Plan: - Unit tests for _optimize_transform_infos_for_flattened_reductions using fake process group (fast, no NCCL init overhead) - Integration tests verifying comm counts with CommDebugMode -- Claude [ghstack-poisoned]
…shes
When redistributing a DTensor with multiple Partial placements of the same
reduce_op type (e.g., Partial("sum") on dims A and C), this change detects
if a flattened DeviceMesh exists that covers those dimensions and uses a
single allreduce instead of multiple separate ones.
Key changes:
- Add _FlattenedTransformInfo as duck-type compatible with _TransformInfo,
allowing the redistribute loop to handle both uniformly via mesh override
- Add _get_flattened_mesh_by_layout() to query for existing flattened meshes
using layout comparison rather than name-based lookup
- Add _optimize_transform_infos_for_flattened_reductions() to group same-type
reductions (even non-consecutive) and replace with flattened transforms
- Modify redistribute_local_tensor() to use mesh_to_use pattern so flattened
transforms flow through the same code path as regular transforms
Example: For a (2,2,2) mesh with dims (A,B,C) and placements
(Partial("sum"), Partial("max"), Partial("sum")), if mesh["A","C"]._flatten()
was called, the A and C sums are merged into 1 allreduce (2 comms instead of 3).
Test Plan:
- Unit tests for _optimize_transform_infos_for_flattened_reductions using
fake process group (fast, no NCCL init overhead)
- Integration tests verifying comm counts with CommDebugMode
-- Claude
ghstack-source-id: 990f175
Pull Request resolved: #172119
```python
if flattened_mesh._layout == expected_layout:
    return flattened_mesh

return None
```
we should at the very least warn here and instruct users how to create a flattened mesh. Ideally we'd error out (because not having a flattened mesh is a significant perf hit).
ok - i think i'll add a warn_once here. do you prefer warning repeatedly?
Tbh I prefer a noisy warning here, given that it's an easy remedy
ok, i put in a warning
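A sketch of what such a warning could look like; the helper name and message are illustrative, not the PR's actual code:

```python
import warnings

def _warn_missing_flattened_mesh(dim_names):
    # Hypothetical helper: called when mergeable reductions are found but no
    # flattened mesh covers their dims. Warns on every occurrence (noisy),
    # per the review discussion above.
    dims = ", ".join(repr(n) for n in dim_names)
    warnings.warn(
        f"redistribute could merge reductions over mesh dims ({dims}) into a "
        f"single collective, but no flattened DeviceMesh covers them; create "
        f"one with mesh[{dims}]._flatten() to enable this optimization."
    )

_warn_missing_flattened_mesh(("A", "C"))
```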
```python
others.append((j, next_info))
j += 1
elif next_src == src and not has_blocking_partial(
    src, reductions[-1][1].mesh_dim, next_info.mesh_dim
```
There was a problem hiding this comment.
This is kind of weird, the has_blocking_partial is checking all mesh dims between A.mesh_dim and B.mesh_dim, which is not related to ordering in transform_infos.
um, maybe i misunderstood your point. but i think this is actually crucial. See my other PR to see if it convinces you:
#172277
I see - the assumption holds if we follow the default left-to-right order. Do we want to consider a non-default order, or graph-based redistribution, here?
what do you mean about non-default order? we have only defined a concept of non-default shard ordering, and that ordering specifically applies to the shard placements it names. I was concluding from that that Partials are always in fixed left-to-right order, which I further documented in the other PR. If you disagree, lmk and we can clarify what the behavior should be for partials.
Curious about the case where transform_infos is:
mesh dim 0, P(sum)->R
mesh dim 2, P(sum)->R
mesh dim 1, P(avg)->R
Based on the code, we will not merge the first two P(sum)->R because of mesh dim 1's P(avg)->R. Is this true?
I got it now. Yeah, I think that could be a problem - I'll think about it and add a test case for it.
ok - discussed offline: this sequence of transform infos should be banned, since it is not valid to redistribute partials in arbitrary order. we should define a partial order that we always use, and then this code is OK.
```python
dt = DTensor.from_local(
    local_tensor,
    mesh,
    (Partial("sum"), Partial("max"), Partial("sum")),
```
Feels like there may be an issue with has_blocking_partial. We may need a device mesh with a higher dimension to test. I think we can use LocalTensor - though there is no concept of a NCCL domain, we can still test the tensor output in case the partial sequence gets messed up.
With LocalTensor, we can probably mock redistribute_local_tensor to take a long sequence of made-up transform_infos and compare the redistributed tensor against the one produced with merged partial transform_infos.
No need to worry - I think it's correct for this PR with the default non-graph-based redistribution, as in #172119 (comment). I can follow up and verify if we want to extend support to arbitrary transform_infos sequences.
I didn't understand these comments. maybe we can discuss offline to clarify.
…lattened meshes" Fixes #171916 When redistributing a DTensor with multiple Partial placements of the same reduce_op type (e.g., Partial("sum") on dims A and C), this change detects if a flattened DeviceMesh exists that covers those dimensions and uses a single allreduce instead of multiple separate ones. Key changes: - Add _FlattenedTransformInfo as duck-type compatible with _TransformInfo, allowing the redistribute loop to handle both uniformly via mesh override - Add _get_flattened_mesh_by_layout() to query for existing flattened meshes using layout comparison rather than name-based lookup - Add _optimize_transform_infos_for_flattened_reductions() to group same-type reductions and replace with flattened transforms - Modify redistribute_local_tensor() to use mesh_to_use pattern so flattened transforms flow through the same code path as regular transforms Example: For a (2,2,2) mesh with dims (A,B,C) and placements (Partial("sum"), Partial("max"), Partial("sum")), if mesh["A","C"]._flatten() was called, the A and C sums are merged into 1 allreduce (2 comms instead of 3). Test Plan: - Unit tests for _optimize_transform_infos_for_flattened_reductions using fake process group (fast, no NCCL init overhead) - Integration tests verifying comm counts with CommDebugMode -- Claude [ghstack-poisoned]
…shes
When redistributing a DTensor with multiple Partial placements of the same
reduce_op type (e.g., Partial("sum") on dims A and C), this change detects
if a flattened DeviceMesh exists that covers those dimensions and uses a
single allreduce instead of multiple separate ones.
Key changes:
- Add _FlattenedTransformInfo as duck-type compatible with _TransformInfo,
allowing the redistribute loop to handle both uniformly via mesh override
- Add _get_flattened_mesh_by_layout() to query for existing flattened meshes
using layout comparison rather than name-based lookup
- Add _optimize_transform_infos_for_flattened_reductions() to group same-type
reductions (even non-consecutive) and replace with flattened transforms
- Modify redistribute_local_tensor() to use mesh_to_use pattern so flattened
transforms flow through the same code path as regular transforms
Example: For a (2,2,2) mesh with dims (A,B,C) and placements
(Partial("sum"), Partial("max"), Partial("sum")), if mesh["A","C"]._flatten()
was called, the A and C sums are merged into 1 allreduce (2 comms instead of 3).
Test Plan:
- Unit tests for _optimize_transform_infos_for_flattened_reductions using
fake process group (fast, no NCCL init overhead)
- Integration tests verifying comm counts with CommDebugMode
-- Claude
ghstack-source-id: 3c178d0
Pull Request resolved: #172119
I wonder if this fix addresses my concern in #171913 (comment)
In fact, I'm not sure this is an issue to be solved at the DTensor level at all. Do you think we can first detect and error out, before merging a solution which may not be sufficient?
I think it's more urgent to solve the problem for the application API, e.g. replicate (DDP) + Sequence Parallel. cc @weifengpy
@tianyu-l i made a doc to discuss the tradeoffs: https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?usp=sharing tl;dr, i think this PR is at best a stopgap solution that does improve performance; from that angle i'm ok with landing it, but i don't love it. I advocate for a bigger change to both (1) devicemesh/processgroup, and (2) the redistribution planner in DTensor
Follow-up PR #172610:

Ensures that when possible (i.e., when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms; for reduce comms in particular, this also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?usp=sharing) for more info.)

Example: for a (2,2,2) mesh with dims (A,B,C), redistributing from (Psum, Replicate, Psum) -> (Replicate, Replicate, Replicate) originally issued 2 separate all_reduces. After this PR, if the user flattens dims A,C, this becomes one larger all_reduce.

Compared with the earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types ((Psum, Pmax) is not a valid placement), so we don't have to worry about optimizing around it
- therefore uses a simpler implementation: group adjacent transform infos, then merge like kinds
- warns once per mesh shape for missing flattened meshes
- won't optimize reduce_scatters when they shard an unevenly sized tensor dim

Limitations:
- all_to_all is never merged (left for possible future work; it is not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, it warns
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device meshes are not automatically created (explicit creation is preferred, and it keeps torch.compile working), but warnings prompt the user to create them when doing so would enable an optimization

Fixes #171916

Note: an initial attempt used a stable sort with a __lt__ method on TransformInfo comparing a comm-type key, but this was not correct because sorting a local (no-comm) operation like chunking before or after a comm operation on the same mesh dim affects results.

Pull Request resolved: #172610
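A simplified sketch of the "group adjacent transform infos, then merge like kinds" pass described above. The TransformInfo shape and the kind() key are assumptions for illustration, not the PR's actual data structures:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TransformInfo:
    mesh_dim: int
    comm: str                        # e.g. "all_reduce", "all_gather"
    reduce_op: Optional[str] = None  # set for reductions

def kind(info):
    return (info.comm, info.reduce_op)

def group_adjacent_like_kinds(infos):
    # Runs of adjacent transforms with the same comm kind become one group; a
    # group of >1 all_reduce("sum") entries can then be replaced by a single
    # all_reduce over a flattened mesh covering all of the group's mesh dims.
    groups = []
    for info in infos:
        if groups and kind(groups[-1][-1]) == kind(info):
            groups[-1].append(info)
        else:
            groups.append([info])
    return groups

infos = [
    TransformInfo(0, "all_reduce", "sum"),
    TransformInfo(2, "all_reduce", "sum"),  # adjacent like kind -> same group
    TransformInfo(1, "all_gather"),
]
print([[i.mesh_dim for i in g] for g in group_adjacent_like_kinds(infos)])
# [[0, 2], [1]] -> the two sums can share one all_reduce on a flattened mesh
```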
Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info. Example: For a (2,2,2) mesh with dims (A,B,C) and placements when redistributing from (Psum, Replicate, Psum) -> (Replicate, Replicate, Replicate) - the original behavior would be 2 separate all_reduces. After this PR, if the user flattens dims A,C, this becomes one larger all_reduce. Compared with earlier attempt #172119, this PR - includes optimization for comms other than all_reduce - explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it - therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds - Warns once per mesh shape for missing flattened meshes Fixes #171916 Note: initial attempt used stable sort with a __lt__ method in TransformInfo comparing comm type key, but this was not correct because sorting a local (no-comm) operation like chunking before or after a comm operation on the same mesh time affects results. ghstack-source-id: 2b47663 Pull Request resolved: #172610
…g flattened meshes" Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info. Example: For a (2,2,2) mesh with dims (A,B,C) and placements when redistributing from (Psum, Replicate, Psum) -> (Replicate, Replicate, Replicate) - the original behavior would be 2 separate all_reduces. After this PR, if the user flattens dims A,C, this becomes one larger all_reduce. Compared with earlier attempt #172119, this PR - includes optimization for comms other than all_reduce - explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it - therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds - Warns once per mesh shape for missing flattened meshes - Won't optimize reduce_scatters when they shard an uneven sized tensor dim Details/Limitations - all_to_all is never merged (left for possible future work, but not obvious how to do it in general) - reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns - reduce_scatter is only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness for merged reduce_scatter - groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list - flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization - DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh. Refuses to merge any other combinations of mixed partials. Fixes #171916 Note: initial attempt used stable sort with a __lt__ method in TransformInfo comparing comm type key, but this was not correct because sorting a local (no-comm) operation like chunking before or after a comm operation on the same mesh time affects results. [ghstack-poisoned]
Ensures that when a flattened mesh covering the relevant dims exists, DTensor will find and use it to avoid more costly sequential comms; for reduce comms in particular, this also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaurredhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.)

Example: For a (2,2,2) mesh with dims (A,B,C), redistributing from (Psum, Replicate, Psum) -> (Replicate, Replicate, Replicate) originally issued 2 separate all_reduces. After this PR, if the user flattens dims A,C, this becomes one larger all_reduce, as sketched below.
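Below is a minimal sketch of the usage this enables, assuming an 8-rank job with an already-initialized process group; the dim names mirror the example above, and `mesh["A", "C"]._flatten()` is the private DeviceMesh call this PR expects to have been made up front. `CommDebugMode` is used only to observe the comm count.

```python
# Sketch, not the PR's test code: assumes 8 ranks and CUDA devices.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate
from torch.distributed.tensor.debug import CommDebugMode

# (2, 2, 2) mesh with named dims A, B, C across 8 ranks.
mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("A", "B", "C"))

# Pre-create the flattened A+C mesh so redistribute can find and use it.
mesh["A", "C"]._flatten()

local = torch.randn(4, 4, device="cuda")
dt = DTensor.from_local(local, mesh, (Partial("sum"), Replicate(), Partial("sum")))

with CommDebugMode() as comm_mode:
    out = dt.redistribute(mesh, (Replicate(), Replicate(), Replicate()))

# Before this PR: two all_reduces (one per Partial dim).
# After this PR: one all_reduce on the flattened 4-rank A+C group.
print(comm_mode.get_total_counts())
```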
Compared with earlier attempt #172119, this PR:

- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types: (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation that groups adjacent transform_infos and then merges like kinds
- warns once per mesh shape for missing flattened meshes
- won't optimize reduce_scatters when they shard an unevenly sized tensor dim

Details/Limitations:

- all_to_all is never merged (left for possible future work; it is not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh; otherwise, warns
- reduce_scatter and all_gather are only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device meshes are not automatically created, both because explicit creation is preferred and to ensure torch.compile works, but warnings prompt the user to create them when doing so would enable an optimization
- DOES support merging mixed Partial (sum, avg) reductions, performing a sum reduction on the merged mesh and then scaling by the product of the avg dim sizes (see the sketch at the end of this description). Refuses to merge any other combination of mixed partials.

Note: an initial attempt used a stable sort with a __lt__ method on TransformInfo comparing a comm-type key, but this was not correct because moving a local (no-comm) operation like chunking before or after a comm operation on the same mesh dim affects results.
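The mixed sum/avg rule above can be checked numerically without any collectives: averaging over a mesh dim of size n is just summing and dividing by n, so a single merged sum followed by one division by the product of the avg dim sizes matches the sequential result. A minimal sketch with plain tensors, where the dim sizes are illustrative assumptions:

```python
import torch

# Assumed mesh dim sizes: A carries Partial("sum"), C carries Partial("avg").
A, C = 2, 4
shards = torch.randn(A, C, 3, 3)  # one local shard per (a, c) rank

# Sequential reference: sum-reduce over A, then avg-reduce over C.
reference = shards.sum(dim=0).mean(dim=0)

# Merged: one sum over the flattened A*C group, then rescale by the
# product of the avg dim sizes (here just C).
merged = shards.sum(dim=(0, 1)) / C

torch.testing.assert_close(reference, merged)
```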
Stack from ghstack (oldest at bottom):
Fixes #171916