
[DTensor] Make RedistributionPlanner handle all partials #172479

Closed
wconstab wants to merge 10 commits into gh/wconstab/501/base from gh/wconstab/501/head

Conversation

@wconstab wconstab (Contributor) commented Jan 14, 2026

Stack from ghstack (oldest at bottom):

Previously, the planner hardcoded psum and ignored other partials, which meant that attempting to redistribute to pavg or pmax would fail.

Changes Made

  1. torch/distributed/tensor/_redistribute.py

Added partial_reduce_ops_in_target field (line 313):

  • Added a new instance variable partial_reduce_ops_in_target: set[str] = set() to track which Partial reduce ops are present in the src/dst placements.

Modified reduce op collection (lines 749-754):

  • Added code to collect Partial reduce ops from both src and dst placements when planning the redistribution. This ensures only relevant reduce ops are considered.

Updated R->P transition generation (lines 536-552):

  • Changed the hardcoded ("sum", "avg") to use self.partial_reduce_ops_in_target, which dynamically considers only the reduce ops present in the src/dst placements (see the sketch after this list).
  2. test/distributed/tensor/test_redistribute.py

Added test_replicate_to_partial_different_reduce_ops (lines 903-950):

  • Tests that R->P transitions work correctly for all reduce op types (sum, avg, min, max).
  • Verifies the local tensor content is correct based on the reduce_op semantics.

Added test_replicate_to_partial_planner_reduce_op_collection (lines 952-1054):

  • Tests that the planner correctly collects reduce ops from src/dst placements.
  • Verifies the optimization that avoids naively expanding the graph to include all reduce op types.
  • Tests three scenarios: R->P("min"), P("max")->R, and multi-dimensional meshes with multiple Partial types.
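
A minimal sketch of the planner-side change in item 1, assuming the surrounding class structure; partial_reduce_ops_in_target and the loop bodies follow the PR's diff, while the class and method names here are illustrative only:

    import itertools

    from torch.distributed.tensor import Partial


    class RedistributionPlannerSketch:
        def __init__(self) -> None:
            # Track which Partial reduce ops are present in the src/dst placements.
            self.partial_reduce_ops_in_target: set[str] = set()

        def collect_reduce_ops(self, src_placements, dst_placements) -> None:
            # Only collect reduce ops that actually appear in this redistribution,
            # so the graph search never expands transitions for unused ops.
            for placement in itertools.chain(src_placements, dst_placements):
                if isinstance(placement, Partial):
                    self.partial_reduce_ops_in_target.add(placement.reduce_op)

        def replicate_to_partial_candidates(self, placements, mesh_dim):
            # Previously hardcoded to ("sum", "avg"); now driven by the collected set.
            for reduce_op in self.partial_reduce_ops_in_target:
                new_placements = list(placements)
                new_placements[mesh_dim] = Partial(reduce_op)
                yield tuple(new_placements)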

Key Benefits

  1. Dynamic reduce op handling: The planner now considers only reduce ops present in the actual redistribution request, rather than hardcoding specific reduce ops.
  2. No unnecessary graph expansion: By only considering relevant reduce ops, the graph-based search avoids exploring paths that aren't needed.
  3. Full reduce op support: All reduce op types (sum, avg, min, max, etc.) are now supported for R->P transitions, not just sum and avg.
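
A hedged usage sketch of the new behavior; it assumes torch.distributed is already initialized (e.g. two ranks launched with torchrun) and that redistribute is called identically on every rank:

    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Partial, Replicate, distribute_tensor

    # Assumes the default process group exists, e.g. torchrun --nproc-per-node=2.
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))
    x = distribute_tensor(torch.ones(4, 4), mesh, [Replicate()])

    # Before this change the planner only generated Partial("sum")/Partial("avg")
    # transitions, so targets like Partial("max") or Partial("min") failed to plan.
    y = x.redistribute(mesh, [Partial("max")])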

@pytorch-bot

pytorch-bot Bot commented Jan 14, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172479

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 6339cf2 with merge base 7754b55:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

SergeyTyshkevich pushed a commit to SergeyTyshkevich/chart2 that referenced this pull request Jan 19, 2026
suncapitalllc007-star pushed a commit to suncapitalllc007-star/pytorch that referenced this pull request Jan 25, 2026
@wconstab wconstab changed the title from "WIP fix error on redistribute_cost for R->Pavg, not sure i want to land this" to "[DTensor] Make RedistributionPlanner handle all partials" on Jan 26, 2026
wconstab added a commit that referenced this pull request Jan 26, 2026
@wconstab wconstab requested a review from zpcore January 26, 2026 20:09
# present in the redistribution, avoiding unnecessary graph expansion.
for placement in itertools.chain(src_placements, dst_placements):
    if isinstance(placement, Partial):
        self.partial_reduce_ops_in_target.add(placement.reduce_op)
zpcore (Member) commented:

Shall we only support sum and avg for now? Otherwise the order will matter and generate wrong results. For now, we can error out if there is a Partial min, max, etc. in the src/dst placement.

I think in order to support ordering for Partial, it should be roughly something like below, where we map Partial to [(Partial type, mesh dim), ...]:

Partial: [(sum, 1), (max, 0), (sum, 2)]

and then apply the push/pop rule. This can be future work.
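
For illustration only, a rough sketch of the ordered mapping described above; the variable name and push/pop usage are hypothetical and not part of this PR:

    # Hypothetical ordered record of pending partial reductions: (reduce op, mesh dim),
    # kept in the order the Partial placements were introduced.
    pending_partials: list[tuple[str, int]] = [("sum", 1), ("max", 0), ("sum", 2)]

    # Push when a mesh dim becomes Partial; pop in LIFO order when reducing, so that
    # non-commuting ops such as sum and max are applied in a well-defined order.
    pending_partials.append(("avg", 3))
    reduce_op, mesh_dim = pending_partials.pop()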

wconstab (Contributor, Author) replied:

I'm confused. You remember that we agreed to ban mixed partials, right? #172609

Maybe your point is still valid: if we have Pmax in src and Psum in dst, I guess that is not strictly a 'mixed partial' situation. Will this situation require careful ordering in the graph search? Let's figure out a good test case to use for this.

@zpcore zpcore (Member) Jan 26, 2026

Yes, "Pmax in src and Psum in dst" will cause mixed Partial case during redistribution. Even though the src and dst doesn't have mixed Partial.

Will this situation require careful ordering in the graph search?

I think so. Maybe the easiest way is to support only sum and avg in src/dst placements for now.

wconstab (Contributor, Author) replied:

🤔 But don't we already support redistributing e.g. R->PMax in redistribute_local_tensor via the greedy path? Then this leaves an implementation gap in the planner?

@zpcore zpcore (Member) Jan 27, 2026

I think there is a gap. In the greedy-generated path, we will never have a mixed Partial during redistribution, even if the src contains Partial(sum) and the dst contains Partial(max).

OK, I have a simple way to bridge the gap for the graph-based solution: forbid entering a state whose placements contain mixed Partials (https://github.com/pytorch/pytorch/pull/172479/changes#r2729869743). Then we can support Partial(max) in the dst placement! And we can always find a path without mixed Partials, as long as the src and dst don't contain mixed Partials.

)
for reduce_op in self.partial_reduce_ops_in_target:
    new_placements = list(placements)
    new_placements[mesh_dim] = Partial(reduce_op)
@zpcore zpcore (Member) Jan 27, 2026

We can prevent mixed Partials here:

if len(set(p for p in new_placements if p.is_partial())) > 1:
    continue

In this way we can have different Partials in src and dst.
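
Putting the two snippets together, a sketch of how the suggested check might sit inside the R->P candidate generation (the enclosing planner method and what follows the guard are assumed, not shown in the PR):

    for reduce_op in self.partial_reduce_ops_in_target:
        new_placements = list(placements)
        new_placements[mesh_dim] = Partial(reduce_op)
        # Skip any candidate state that would mix different Partial reduce ops.
        if len({p for p in new_placements if p.is_partial()}) > 1:
            continue
        # ... otherwise record the transition to Partial(reduce_op) on mesh_dim ...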

@zpcore zpcore (Member) left a comment

LGTM!

@wconstab wconstab (Contributor, Author)

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) Jan 27, 2026
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Jan 28, 2026
Inplace ops for dtensor have a restriction: you're not allowed to redistribute the 'inplace' tensor. This means in some cases sharding propagation has to fail because the inplace input is not compatible with any of the possible sharding strategies.

This PR makes sure this case raises the expected informative error
rather than a confusing error about selecting min cost over an
empty sharding strategies list.
Pull Request resolved: #173572
Approved by: https://github.com/pianpwk
ghstack dependencies: #172479
pytorchmergebot pushed a commit that referenced this pull request Jan 28, 2026
kapilsh pushed a commit to kapilsh/pytorch that referenced this pull request Feb 2, 2026
@github-actions github-actions Bot deleted the gh/wconstab/501/head branch February 27, 2026 02:24
sandy-gags pushed a commit to sandy-gags/pytorch that referenced this pull request Mar 12, 2026

Labels

ciflow/inductor, ciflow/trunk, Merged, release notes: distributed (dtensor)
