
[DTensor] Optimize redistribute to use flattened mesh dims for consecutive reductions#171913

Closed
ezyang wants to merge 1 commit into gh/ezyang/3233/base from gh/ezyang/3233/head

Conversation

Contributor

@ezyang ezyang commented Jan 7, 2026

Stack from ghstack (oldest at bottom):

Authored with Claude Code

When reductions need to occur on multiple mesh dims, we currently issue a
separate collective for each mesh dim. When we want to do a reduction on
multiple contiguous mesh dims, and a flattened dim of those contiguous dims
exists (i.e., we have already paid for initializing a PG for the flattened
dim), it would be better to do the reduction all in one go on the flattened
mesh dim.

The redistribute algorithm currently operates by proposing a sequence of
collectives to perform. This change looks for runs of consecutive
reductions and greedily tests whether they have a flattened mesh dim/PG. If
they do, it replaces that part of the plan with one that does the reduction
all in one step.
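
To illustrate, here is a minimal sketch of this kind of merge pass, not the PR's actual code: `SimpleStep` and `flattened_pg_exists` are hypothetical stand-ins for DTensor's internal `_TransformInfo` and its flattened-mesh/PG lookup.

```python
from dataclasses import dataclass


@dataclass
class SimpleStep:
    mesh_dim: int        # mesh dim the collective runs on
    is_all_reduce: bool  # True for a Partial -> Replicate reduction


def flattened_pg_exists(dims: tuple[int, ...]) -> bool:
    """Hypothetical lookup: True if a PG covering the flattened dims was
    already initialized (e.g., via a prior mesh flatten)."""
    return dims == (0, 1)  # pretend only dims (0, 1) were flattened


def merge_consecutive_reductions(steps: list[SimpleStep]) -> list[SimpleStep]:
    out: list[SimpleStep] = []
    i = 0
    while i < len(steps):
        # Greedily extend a run of all-reduces on contiguous mesh dims.
        j = i
        while (
            j + 1 < len(steps)
            and steps[j].is_all_reduce
            and steps[j + 1].is_all_reduce
            and steps[j + 1].mesh_dim == steps[j].mesh_dim + 1
        ):
            j += 1
        dims = tuple(s.mesh_dim for s in steps[i : j + 1])
        if len(dims) > 1 and flattened_pg_exists(dims):
            # Replace the whole run with one collective; reuse the first
            # dim as a placeholder handle for the flattened mesh dim.
            out.append(SimpleStep(mesh_dim=dims[0], is_all_reduce=True))
        else:
            out.extend(steps[i : j + 1])
        i = j + 1
    return out
```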

[ghstack-poisoned]
pytorch-bot Bot commented Jan 7, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/171913

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures

As of commit 6590dfa with merge base 3d2e7de:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ezyang added a commit that referenced this pull request Jan 7, 2026
…utive reductions

ghstack-source-id: 7c4441a
Pull-Request: #171913
Contributor

github-actions Bot commented Jan 7, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

) -> list[_TransformInfo]:
"""
Optimize transform_infos by merging consecutive all-reduce operations on
contiguous mesh dimensions when a flattened mesh/PG exists for those dimensions.
@wconstab
Contributor

Is it only possible to merge reductions on contiguous mesh dims?

Further, if we had a mesh like [dp, pp, tp] and we sliced out spmd_mesh = parent[dp, tp], the indices of dp and tp would be 0 and 1 and appear 'contiguous' to this code, but not actually be contiguous in the parent mesh. Is that a problem?
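
To make the concern concrete, a minimal sketch (run under e.g. `torchrun --nproc-per-node 8`; the mesh shape and dim names are made up for the example, and slicing non-adjacent dims assumes a PyTorch version that supports it):

```python
from torch.distributed.device_mesh import init_device_mesh

# Parent mesh with three named dims; dp and tp are not adjacent.
mesh = init_device_mesh("cpu", (2, 2, 2), mesh_dim_names=("dp", "pp", "tp"))
spmd_mesh = mesh["dp", "tp"]

# Within spmd_mesh, dp and tp are dims 0 and 1 and look "contiguous",
# but their ranks stride across pp in the parent mesh, so a flattened PG
# built over adjacent parent dims would not cover the same ranks.
print(spmd_mesh.mesh_dim_names)  # ('dp', 'tp')
```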

Contributor

If we can leverage the coalescing of the layout, that would solve the case @wconstab mentioned here.

@tianyu-l
Contributor

Trying to find a corner case where the contiguous assumption does not provide an identity guarantee:

If we use graph-based DDP + FSDP + TP, for RMSNorm.weight we would have

  • param (Replicate, Shard, Replicate)
  • grad before reduction (Partial, Partial, Partial)

I guess the reduction will still happen in the order AR, RS, AR, and result in DDP / TP ranks not having the same results.
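
For reference, a minimal sketch of those placements (the import path is the public torch.distributed.tensor API; the Shard dim is an assumption for illustration):

```python
from torch.distributed.tensor import Partial, Replicate, Shard

param_placements = (Replicate(), Shard(0), Replicate())  # RMSNorm.weight
grad_placements = (Partial(), Partial(), Partial())      # grad before reduction

# Per mesh dim, redistributing grad -> param placements needs:
#   dim 0: Partial -> Replicate  => all-reduce     (AR)
#   dim 1: Partial -> Shard      => reduce-scatter (RS)
#   dim 2: Partial -> Replicate  => all-reduce     (AR)
# i.e., the AR, RS, AR order mentioned above.
```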

@ezyang
Contributor Author

I am pretty sure contiguous-only is sound, but not complete (as wconstab mentions above). Are you worried about unsoundness here too?

Contributor

@tianyu-l tianyu-l Jan 14, 2026


I do worry about completeness. If a moderately complicated solution doesn't solve the problem, I would prefer we error out.

@ezyang
Contributor Author

ezyang commented Jan 7, 2026

Also, I don't care AT ALL about the code here (entirely claude coded), so if someone wants to redo it from scratch or commandeer, I am not trying to lick the cookie.

@wconstab
Contributor

FYI, I took over this PR and replaced it with #172121. Closing this one.
