Skip to content

[DTensor] Optimize redistribute comms using flattened meshes#174630

Closed
wconstab wants to merge 1 commit intomainfrom
export-D92540256
Closed

[DTensor] Optimize redistribute comms using flattened meshes#174630
wconstab wants to merge 1 commit intomainfrom
export-D92540256

Conversation

@wconstab
Copy link
Copy Markdown
Contributor

@wconstab wconstab commented Feb 9, 2026

Reland of #172610: same code as previous land except:

Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See this doc for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces. After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR

  • includes optimization for comms other than all_reduce
  • explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
  • therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
  • Warns once per mesh shape for missing flattened meshes
  • Won't optimize reduce_scatters when they shard an uneven sized tensor dim

Details/Limitations

  • all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
  • reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns
  • reduce_scatter and all_gather are only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness.
  • groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
  • flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization
  • DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh. Refuses to merge any other combinations of mixed partials.

Fixes #171916

Note: initial attempt used stable sort with a lt
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

Differential Revision: D92540256

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented Feb 9, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/174630

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (6 Unrelated Failures)

As of commit c170b91 with merge base 4674618 (image):

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-codesync
Copy link
Copy Markdown

meta-codesync Bot commented Feb 9, 2026

@wconstab has exported this pull request. If you are a Meta employee, you can view the originating Diff in D92540256.

Copy link
Copy Markdown
Member

@zpcore zpcore left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@pytorch-bot pytorch-bot Bot added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 9, 2026
Summary:

Reland of #172610
- includes fixes #173873 (credit bdhirsh) and #173790 (credit IvanKobzarev)

Ensures that when possible (when such a flattened mesh exists), DTensor will find and use it to avoid more costly sequential comms, and particularly for reduce comms, also avoids the risk of different reduction orders causing divergent results. (See [this doc](https://docs.google.com/document/d/1hJsnodQmHfs1QosNgR39HZNiOOzfnZ6bnALqonDpcDs/edit?userstoinvite=rrathaur@redhat.com&sharingaction=manageaccess&role=reader&tab=t.0) for more info.

Example: For a (2,2,2) mesh with dims (A,B,C) and placements
when redistributing from (Psum, Replicate, Psum) -> (Replicate,
Replicate, Replicate) - the original behavior would be 2 separate
all_reduces.  After this PR, if the user flattens dims A,C, this becomes
one larger all_reduce.

Compared with earlier attempt #172119, this PR
- includes optimization for comms other than all_reduce
- explicitly bans mixed partial types (Psum, Pmax) is not a valid placement, so we don't have to worry about optimizing around it
- therefore uses a simpler implementation involving grouping adjacent transforminfos and then merging like kinds
- Warns once per mesh shape for missing flattened meshes
- Won't optimize reduce_scatters when they shard an uneven sized tensor dim

Details/Limitations
- all_to_all is never merged (left for possible future work, but not obvious how to do it in general)
- reduce_scatter is only merged when the outermost partial shape is evenly divisible by the flattened mesh - otherwise, warns
- reduce_scatter and all_gather are only merged when the shards are in left-to-right (ascending) order, since DeviceMesh only supports flattening in ascending order and the mesh ordering impacts correctness.
- groups of like-kind collectives are NOT combined if they are not adjacent in the transform_info list
- flattened device-meshes are not automatically created due to preference of explicit creation and ensuring torch.compile works, but warnings prompt the user to create them when it would help allow an optimization
- DOES support merging mixed Partial (sum, avg) reductions, using the product of the avg dim sizes to scale after performing a sum reduction on the merged mesh.  Refuses to merge any other combinations of mixed partials.

Fixes #171916

Note: initial attempt used stable sort with a __lt__
method in TransformInfo comparing comm type key, but this was not correct because sorting a
local (no-comm) operation like chunking before or after a comm operation
on the same mesh time affects results.

Differential Revision: D92540256
@facebook-github-bot
Copy link
Copy Markdown
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions github-actions Bot deleted the export-D92540256 branch March 13, 2026 02:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

DTensor must generate flattened PGs to avoid allreduce result inconsistency across Replicate when reducing over multiple mesh dims

4 participants