
[DTensor] Single Dim Strategy infra#167677

Closed
wconstab wants to merge 28 commits into gh/wconstab/456/base from gh/wconstab/456/head

Conversation


@wconstab wconstab commented Nov 12, 2025

Stack from ghstack (oldest at bottom):

Motivation

We find it is too difficult to implement sharding strategies for DTensor; this slows progress towards full operator coverage and increases the likelihood of bugs in sharding rules. We also expect to add new Placement types (e.g. more complete support for StridedShard in the near term, but possibly others as well), and the current formulation of sharding strategies does not scale to adding new placement types.

A primary reason it's so difficult to write sharding rules today is that they combine 2 things: (a) a mathematical description of which in/out shardings are correct for the op, (b) premature runtime optimization to avoid considering too many combinations over N-D mesh.

The proposal is to remove (b) and focus just on (a).

tl;dr

  1. Write sharding prop rules in terms of 1 mesh dim, using a 'placeholder' for generic sharding types.
    e.g. the matmul rule returns:
[
   ShardPlaceholder(0), Replicate() -> ShardPlaceholder(0),
   Replicate(), ShardPlaceholder(1) -> ShardPlaceholder(1),
   ShardPlaceholder(1), ShardPlaceholder(0) -> Partial()
]
  2. After registration, each rule is automatically expanded to include the real sharding types discovered in the inputs at runtime, plus a 'full replication' rule.
    e.g. if the inputs are fully replicated, we drop the placeholder rules and use only
[
   Replicate(), Replicate() -> Replicate()
]

if 'Shard' is discovered in the inputs, we fill the placeholders like

[
   Shard(0), Replicate() -> Shard(0),
   Replicate(), Shard(1) -> Shard(1),
   Shard(1), Shard(0) -> Partial(),
   Replicate(), Replicate() -> Replicate()
]

if 'Shard' and 'StridedShard' are both discovered in the inputs, we expand to

[
   Shard(0), Replicate() -> Shard(0),
   Replicate(), Shard(1) -> Shard(1),
   Shard(1), Shard(0) -> Partial(),
   StridedShard(0), Replicate() -> StridedShard(0),
   Replicate(), StridedShard(1) -> StridedShard(1),
   StridedShard(1), StridedShard(0) -> Partial(),
   Replicate(), Replicate() -> Replicate()
]
  3. After filling the placeholders, we expand to the N-D mesh and find the minimum-cost strategy:
    (a) full enumeration via itertools.product is implemented and gives exact parity with rules like 'einsum' today
    (b) an optimized solution, starting from the input placements and iterating in order of increasing cost until reaching a min-cost solution without full enumeration - under development/prototyping
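The placeholder fill and full-mesh enumeration described in steps 2 and 3(a) can be sketched in miniature. The names here (ShardPlaceholder, fill_placeholders, expand_to_mesh) are illustrative stand-ins, not the PR's actual API, and placements are modeled as plain strings/tuples rather than DTensor Placement objects:

```python
import itertools
from dataclasses import dataclass

@dataclass(frozen=True)
class ShardPlaceholder:
    """Stands in for any 'pure sharding' placement over a tensor dim."""
    dim: int

REPLICATE, PARTIAL = "R", "P"

# The single-dim matmul-like rule set from above, written once:
# each rule is (lhs placement, rhs placement, output placement).
SINGLE_DIM_RULES = [
    (ShardPlaceholder(0), REPLICATE, ShardPlaceholder(0)),
    (REPLICATE, ShardPlaceholder(1), ShardPlaceholder(1)),
    (ShardPlaceholder(1), ShardPlaceholder(0), PARTIAL),
]

def fill_placeholders(rules, shard_kinds):
    """Substitute each sharding kind discovered in the inputs (e.g. 'S' for
    Shard, 'SS' for StridedShard) for the placeholder, then append the
    always-valid full-replication rule."""
    filled = [
        tuple((kind, p.dim) if isinstance(p, ShardPlaceholder) else p
              for p in rule)
        for kind in shard_kinds
        for rule in rules
    ]
    filled.append((REPLICATE, REPLICATE, REPLICATE))
    return filled

def expand_to_mesh(filled_rules, mesh_ndim):
    """Step 3(a): full enumeration - choose one single-dim rule per mesh dim."""
    return list(itertools.product(filled_rules, repeat=mesh_ndim))

rules = fill_placeholders(SINGLE_DIM_RULES, shard_kinds=["S", "SS"])
print(len(rules))                     # 3 rules x 2 kinds + replication = 7
print(len(expand_to_mesh(rules, 2)))  # 7 ** 2 = 49 candidates on a 2-D mesh
```

The exponential growth in the last line is exactly why step 3(b)'s cost-guided search matters for larger meshes.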

This PR

  • defines a 'single_dim strategy' function and a ShardingPlaceholder
  • adds a util for expanding a single_dim strategy into a regular strategy
  • supports StridedShard automatically via ShardingPlaceholder expansion
  • writes rules for mm and cat and uses unit tests to validate their expansion

Next Steps (PR stack)

  • Support pointwise and foreach ops in the single_dim infra
  • Hook up single-dim strategies to sharding_prop (op registration)
  • Start to use single_dim rules to replace existing rules
  • Improve the runtime of searching the fully expanded strategy
  • Explore using decomps together with single-dim rules to support more operators

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @d4l3k @pragupta @msaroufim @dcci @tianyu-l @XilunWu @SherlockNoMad @ppwwyyxx

WIP for now, not ready to land

Experimenting with the idea of decomposing op strategies into
 - a core function that proposes a single-mesh-dim strategy
 - automatic expansion to the mesh inside the registration mechanism

Also, plan to add a 'ShardingPlaceholder' to use for writing
single-mesh-dim strategies in a way that can be expanded at runtime to
any type of sharding discovered in the inputs.

For now, this relies on full enumeration of the single-dim strategy onto
the full mesh, and full enumeration of the combinations of different
sharding placements discovered in the input, but we should be able to
replace this with an algorithm to expand iteratively following the path
of lowest redistribution cost.

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Nov 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167677

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit cbef950 with merge base 1984725:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor oncall: distributed Add this issue/PR to distributed oncall triage queue labels Nov 12, 2025
wconstab added a commit that referenced this pull request Nov 12, 2025
ghstack-source-id: ed6977e
Pull Request resolved: #167677
@wconstab wconstab added module: dtensor distributed tensor tag release notes: distributed (dtensor) release notes category labels Nov 12, 2025
@ezyang
Contributor

ezyang commented Nov 13, 2025

One of the things that seemed important to me when doing this work is to make sure the testing is actually enough to detect problems. @tianyu-l also had the idea of fuzzing through input placements to try to "discover" if sharding rules were complete or not.
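The fuzzing idea above can be sketched as exhaustive enumeration over a small placement alphabet: for a binary op, generate every input-placement pair and flag the ones no rule covers directly. Everything below is a toy stand-in (strings instead of DTensor Placement objects, a hand-written rule dict), not the actual test infra:

```python
import itertools

# Toy 1-D placements and the single-dim mm rules from this PR's description.
PLACEMENTS = ["R", "S(0)", "S(1)"]

MM_RULES = {
    ("S(0)", "R"): "S(0)",
    ("R", "S(1)"): "S(1)",
    ("S(1)", "S(0)"): "P",
    ("R", "R"): "R",
}

def uncovered(rules, placements):
    """Enumerate every input-placement combination for a binary op and return
    the ones no rule matches directly; each of these would force a
    redistribution before the op can run."""
    return [combo for combo in itertools.product(placements, repeat=2)
            if combo not in rules]

print(len(uncovered(MM_RULES, PLACEMENTS)))  # 5 of the 9 combos need a redistribute
```

A completeness fuzzer could assert that every uncovered combination at least has a finite-cost redistribution path to a covered one; a correctness fuzzer would additionally run the op under each covered combination and compare against the replicated result.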

wconstab added a commit that referenced this pull request Nov 14, 2025
ghstack-source-id: ed6977e
Pull Request resolved: #167677

def expanded_registration_wrapper(
Contributor

This function seems to overlap largely with expand_to_full_mesh_op_strategy, but it sounds like you are tackling a different / harder problem?

def expand_to_full_mesh_op_strategy(

In general, I found the PR summary is a bit hard to follow (likely because I missed some meetings / discussions), do we have an RFC or meeting notes I could follow?

Contributor Author

I added some details to the PR desc. There isn't an actual RFC but this can serve as a mini-RFC for now

Contributor Author

This function seems to overlap largely with expand_to_full_mesh_op_strategy, but it sounds like you are tackling a different / harder problem?

I have addressed this. I was able to use the original expand helper.

wconstab added a commit that referenced this pull request Nov 18, 2025
ghstack-source-id: ed6977e
Pull Request resolved: #167677
wconstab added a commit that referenced this pull request Nov 18, 2025
ghstack-source-id: ed6977e
Pull Request resolved: #167677
**This diff stack is a prototype/experiment.**

### Motivation
We find it is too difficult to implement sharding strategies for DTensor, and this slows progress towards full operator coverage, and increases likelihood of bugs in sharding rules.  We also expect to add new Placement types (e.g. more complete support for StridedShard in the near term, but possibly others as well), and the current formulation of sharding strategies is not scalable to adding new placement types.

### Specific Goals
1. [**not prototyped**] Op Strategies should automatically support StridedShard and any other 'purely sharding' placement type without additional work. Use a 'sharding placeholder' to declare sharding rules (e.g. S(0), S(0) -> S(0) applies equally well to Shard(0) or StridedShard(0, sf=2)) - we can make this substitution outside of the per-op rule.
   a. Constraints such as 'must evenly shard' can be moved onto the placeholder and filtered against real tensor/mesh dims at expansion time.
2. [**prototyped**] Focus Op Strategies on a single mesh dim, and rely on standard infra to expand to the full mesh.
3. [**prototyped**] Create the possible input-output sharding specs without looking at input args and doing premature filtering, which trades code complexity for runtime optimization. Try to regain the runtime savings in another common infra layer during 'expansion' to the mesh.

### Experiment Approach
This stack implements some of the above goals, and then implements some operators (so far, mm and pointwise). I'm working on figuring out the right tests and metrics to collect before going further. If you have concerns about whether this approach would work for a particular op, I could try that.
* One major difference in this proposal is the 'premature optimization' in the strategy is removed.  How much does this matter?  Can we recover some of the time by doing filtering during mesh expansion, or by changing to expand only to some 'min cost' set of states around the input placements?
* How can we test if the new strategy is correct? Can we tell if we're missing a strategy or if we're issuing a wrong strategy?

[ghstack-poisoned]
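One way to attack the correctness question above is a numeric oracle: shard the inputs according to a candidate rule, run the op locally, and compare the reassembled result against the unsharded computation. The sketch below does this for the 'Shard(0), Replicate() -> Shard(0)' matmul rule with dependency-free list-of-lists matrices; the helper names are illustrative, not part of the PR:

```python
def matmul(a, b):
    """Plain list-of-lists matrix multiply."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

def shard_rows(m, world_size):
    """Simulate Shard(0): split a matrix into contiguous row blocks."""
    n = len(m) // world_size
    return [m[r * n:(r + 1) * n] for r in range(world_size)]

def check_rule_s0_r(a, b, world_size=2):
    """Check 'Shard(0), Replicate() -> Shard(0)' for matmul: running the op on
    row shards of the left operand (right operand replicated) and concatenating
    the local results must equal the unsharded product."""
    full = matmul(a, b)
    local = [matmul(a_shard, b) for a_shard in shard_rows(a, world_size)]
    reassembled = [row for part in local for row in part]
    return reassembled == full

a = [[1, 2], [3, 4], [5, 6], [7, 8]]
b = [[5, 6], [7, 8]]
print(check_rule_s0_r(a, b))  # True
```

Run over randomized inputs and all proposed rules, a check like this can catch a wrong rule; detecting a *missing* rule still requires the coverage-style enumeration discussed elsewhere in this thread.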
wconstab added a commit that referenced this pull request Nov 20, 2025
ghstack-source-id: 6c772b3
Pull Request resolved: #167677
**This diff stack is a prototype/experiment.**

### Motivation
We find it is too difficult to implement sharding strategies for DTensor, and this slows progress towards full operator coverage, and increases likelihood of bugs in sharding rules.  We also expect to add new Placement types (e.g. more complete support for StridedShard in the near term, but possibly others as well), and the current formulation of sharding strategies is not scalable to adding new placement types.

### Specific Goals
1. [**prototyped**] Op Strategies should automatically support StridedShard and any other 'purely sharding' placement type without additional work. Use a 'sharding placeholder' to declare sharding rules (e.g. S(0), S(0) -> S(0) applies equally well to Shard(0) or StridedShard(0, sf=2)) - we can make this substitution outside of the per-op rule.
   a. Constraints such as 'must evenly shard' can be moved onto the placeholder and filtered against real tensor/mesh dims at expansion time.
2. [**prototyped**] Focus Op Strategies on a single mesh dim, and rely on standard infra to expand to the full mesh.
3. [**prototyped**] Create the possible input-output sharding specs without looking at input args and doing premature filtering, which trades code complexity for runtime optimization. Try to regain the runtime savings in another common infra layer during 'expansion' to the mesh.
4. [**prototyped**] Utilize a min-redistribution-cost priority queue to guide the expansion of the strategy and stop expansion after finding the first (lowest cost) redistribution of the inputs that matches an expansion of the strategy.

### Experiment Approach
This stack implements some of the above goals, and then implements some operators (so far, mm and pointwise). I'm working on figuring out the right tests and metrics to collect before going further. If you have concerns about whether this approach would work for a particular op, I could try that.
* One major difference in this proposal is the 'premature optimization' in the strategy is removed.  How much does this matter?  Can we recover some of the time by doing filtering during mesh expansion, or by changing to expand only to some 'min cost' set of states around the input placements?
 -> Some preliminary benchmark results: https://gist.github.com/wconstab/5df63815696c504db2ffadfbf1675d21
* How can we test if the new strategy is correct? Can we tell if we're missing a strategy or if we're issuing a wrong strategy?

[ghstack-poisoned]
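The min-redistribution-cost priority queue in goal 4 is essentially a Dijkstra-style search over placement states. A hedged sketch, with toy states, a uniform cost function, and a hard-coded valid set standing in for "matches an expansion of the strategy" (none of these names come from the PR):

```python
import heapq

def find_min_cost(start, neighbors, cost, is_valid):
    """Pop candidate input redistributions in order of increasing accumulated
    cost and stop at the first state the strategy accepts."""
    heap = [(0, start)]
    seen = set()
    while heap:
        c, state = heapq.heappop(heap)
        if state in seen:
            continue
        seen.add(state)
        if is_valid(state):
            return c, state
        for nxt in neighbors(state):
            if nxt not in seen:
                heapq.heappush(heap, (c + cost(state, nxt), nxt))
    return None  # no valid redistribution reachable

# Toy state space: a 2-mesh-dim placement tuple over {'R', 'S'}.
PLACEMENTS = ("R", "S")

def neighbors(state):
    for i, p in enumerate(state):
        for q in PLACEMENTS:
            if q != p:
                yield state[:i] + (q,) + state[i + 1:]

cost = lambda a, b: 1                  # pretend every single-dim move costs 1
valid = {("R", "S")}.__contains__      # the strategy accepts only this input
print(find_min_cost(("S", "S"), neighbors, cost, valid))  # (1, ('R', 'S'))
```

The payoff is early termination: the search touches only states cheaper than the answer, instead of materializing all rule-per-mesh-dim combinations up front.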
wconstab added a commit that referenced this pull request Dec 18, 2025
ghstack-source-id: 10b8b09
Pull Request resolved: #167677
@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #170359

weifengpy pushed a commit that referenced this pull request Dec 19, 2025
ghstack-source-id: 671c956
Pull Request resolved: #167677
weifengpy pushed a commit that referenced this pull request Dec 19, 2025
ghstack-source-id: 671c956
Pull Request resolved: #167677
pytorchmergebot pushed a commit that referenced this pull request Dec 19, 2025
This PR adds the register_single_dim_strategy util, and hooks it up to sharding_propagator. It also tests the registration.

Notes:
* I didn't yet decide how multiple registrations should be handled.  I was planning to make it an error if you register twice for the same op for either single_dim or regular strategies.
* I took the cleanest path of integration for now in sharding_prop, reusing as much code as possible with the existing 'op_strategy' case.  I may have to change this later when integrating find_min_cost

Pull Request resolved: #170359
Approved by: https://github.com/weifengpy
ghstack dependencies: #170615, #167677

Co-authored-by: Pian Pawakapan <pianpwk@meta.com>
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Dec 19, 2025
This reverts commit 32d0782.

Reverted pytorch#170359 on behalf of https://github.com/jeanschmidt due to Required to revert pytorch#167677 that is required to revert pytorch#170615 that is required to revert pytorch#170030 ([comment](pytorch#170359 (comment)))
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Dec 19, 2025
This reverts commit c3e628e.

Reverted pytorch#167677 on behalf of https://github.com/jeanschmidt due to Required to revert pytorch#170615 that is required to revert pytorch#170030 ([comment](pytorch#167677 (comment)))
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Dec 19, 2025
### Motivation
We find it is too difficult to implement sharding strategies for DTensor, and this slows progress towards full operator coverage, and increases likelihood of bugs in sharding rules.  We also expect to add new Placement types (e.g. more complete support for StridedShard in the near term, but possibly others as well), and the current formulation of sharding strategies is not scalable to adding new placement types.

A primary reason it's so difficult to write sharding rules today is that they combine 2 things: (a) a mathematical description of which in/out shardings are correct for the op, (b) premature runtime optimization to avoid considering too many combinations over N-D mesh.

The proposal is to remove (b) and focus just on (a).

### tl;dr
1. Write sharding prop rules in terms of one mesh dim, using a 'placeholder' for generic sharding types.
E.g., the matmul rule returns:
```
[
   ShardPlaceholder(0), Replicate() -> ShardPlaceholder(0),
   Replicate(), ShardPlaceholder(1) -> ShardPlaceholder(1),
   ShardPlaceholder(1), ShardPlaceholder(0) -> Partial()
]
```
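As a rough Python sketch of what such a single-dim rule might look like (the class and rule-tuple shapes here are assumptions drawn from the description above, not the exact DTensor API):

```python
from dataclasses import dataclass

# Illustrative stand-ins for the placement types named in the rule above.
@dataclass(frozen=True)
class Replicate:
    pass

@dataclass(frozen=True)
class Partial:
    pass

@dataclass(frozen=True)
class ShardPlaceholder:
    dim: int  # tensor dim to shard; the shard *type* is filled in later

def mm_single_dim_strategy():
    # Each entry: ((lhs, rhs) input placements, output placement),
    # stated for a SINGLE mesh dim with no N-D mesh concerns.
    return [
        ((ShardPlaceholder(0), Replicate()), ShardPlaceholder(0)),
        ((Replicate(), ShardPlaceholder(1)), ShardPlaceholder(1)),
        ((ShardPlaceholder(1), ShardPlaceholder(0)), Partial()),
    ]
```

The rule author only states the math of mm on one mesh dim; everything else (placeholder filling, N-D expansion, cost search) happens in the shared infra.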
2. After registration, each rule is automatically expanded to substitute the real sharding types discovered in the inputs at runtime, and a 'full replication' rule is added.
E.g., if the inputs are fully replicated, we drop the placeholder rules and use only
```
[
   Replicate(), Replicate() -> Replicate()
]
```
if 'Shard' is discovered in inputs, we fill placeholders like
```
[
   Shard(0), Replicate() -> Shard(0),
   Replicate(), Shard(1) -> Shard(1),
   Shard(1), Shard(0) -> Partial(),
   Replicate(), Replicate() -> Replicate()
]
```
if 'Shard' and 'StridedShard' are both discovered in the inputs, we expand to
```
[
   Shard(0), Replicate() -> Shard(0),
   Replicate(), Shard(1) -> Shard(1),
   Shard(1), Shard(0) -> Partial(),
   StridedShard(0), Replicate() -> StridedShard(0),
   Replicate(), StridedShard(1) -> StridedShard(1),
   StridedShard(1), StridedShard(0) -> Partial(),
   Replicate(), Replicate() -> Replicate()
]
```
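The placeholder-filling step can be sketched as a simple cross product of rules and discovered shard types, plus the always-valid replication row. The representations below (string placements, `("placeholder", dim)` tuples) are illustrative assumptions, not the real Placement classes:

```python
PH = lambda d: ("placeholder", d)   # stand-in for ShardPlaceholder(d)
R = "Replicate()"                   # string stand-ins, not real Placements

mm_rules = [
    (PH(0), R, PH(0)),              # (lhs, rhs, out) on one mesh dim
    (R, PH(1), PH(1)),
    (PH(1), PH(0), "Partial()"),
]

def expand_placeholders(rules, shard_ctors):
    """Fill each placeholder with every concrete shard type discovered in the
    inputs (e.g. Shard, StridedShard), then append the full-replication rule."""
    expanded = []
    for ctor in shard_ctors:
        for rule in rules:
            expanded.append(tuple(
                ctor(p[1]) if isinstance(p, tuple) else p for p in rule))
    expanded.append((R,) * len(rules[0]))
    return expanded
```

With no shard types discovered this yields just the replication rule; with `Shard` it yields the 4-row list above; with `Shard` and `StridedShard` it yields the 7-row list.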
3. After filling the placeholders, we expand to the N-D mesh and find the minimum-cost strategy
(a) full enumeration via itertools.product is implemented and gives exact parity with existing rules like 'einsum'
(b) an optimized search that starts from the input placements and iterates in order of increasing cost, reaching a min-cost solution _without_ full enumeration - under development/prototyping

### This PR
* defines a 'single_dim strategy' function and a ShardingPlaceholder
* adds a util for expanding a single_dim strategy into a regular strategy
* supports StridedShard automatically via ShardingPlaceholder expansion
* adds rules for mm and cat, with unit tests validating their expansion

### Next Steps (PR stack)
* Support pointwise and foreach ops in the single_dim infra
* Hook up single-dim strategies to sharding_prop (op registration)
* Start to use single_dim rules to replace existing rules
* Improve the runtime of searching the fully expanded strategy
* Explore using decomps together with single-dim rules to support more operators

Pull Request resolved: pytorch#167677
Approved by: https://github.com/weifengpy
ghstack dependencies: pytorch#170615

Co-authored-by: Pian Pawakapan <pianpwk@meta.com>
pytorchmergebot pushed a commit that referenced this pull request Dec 19, 2025
Enforce that tensor_meta is not None for new single-dim rules.

Allow tensor_meta to continue to be None for existing rules for now. We
should consider asserting in the future that tensor_meta is required in
DTensorSpec, but for now we just try to limit the bleeding.
Pull Request resolved: #170827
Approved by: https://github.com/dolpm
ghstack dependencies: #170615, #167677, #170359
@github-actions github-actions bot deleted the gh/wconstab/456/head branch January 18, 2026 02:21
SergeyTyshkevich pushed a commit to SergeyTyshkevich/chart2 that referenced this pull request Jan 19, 2026
ghstack-source-id: fed27a4
Pull Request resolved: pytorch/pytorch#167677
pytorchmergebot pushed a commit that referenced this pull request Feb 4, 2026
Following @tianyu-l's #130887

Adds support for ops with no sharding prop strategy but a registered decomposition.

Now if sharding prop sees a decomposable op, it:
1. Runs the decomposed op under a custom TorchDispatchMode, which propagates the placements as side information (initially used a make_fx implementation, but this required a threading lock as it relies on [global state](https://github.com/pytorch/pytorch/blob/2a26c9a32661ee2b4b049e3bd1b889fc3af30880/torch/fx/_symbolic_trace.py#L1167))
2. Enumerates potential input placement combinations based on the actual input placements, on a single-dim mesh, then for each of them, propagates through torch_dispatch via sharding prop, while banning any intermediate redistributions.
3. Returns the expanded full-mesh strategy from the filtered strategies.

Some caveats:
- Since the dispatch mode runs sharding prop, the shard prop cache should kick in, both in the normal case (running the same op twice), and also when we recursively decompose (if op1 -> op2 -> some decomp, running op1 caches for op2).
- One common failure case is decompositions calling factory methods (e.g. [torch.ones, torch.arange](https://github.com/pytorch/pytorch/blob/41f42a0fc3ea1fbfdf05b4c030d7df815bdfe19d/torch/_decomp/decompositions.py#L818-L821)). The main problem seems to be assigning placements to these tensors, and it's not so obvious what their placements should be, especially when they might take in sharded sizes, and we can't completely detect when this is the case. For now, intermediate shard prop will fail (no sharding strategy; they don't take DTensor inputs), but a potential future improvement is to permit the full-Replicate case for these graphs.
- Sharding prop is currently via a `propagate_op_sharding` call, on explicit placement types. Once [single-dim strategy](#167677) coverage is broader, this should be doable on _ShardPlaceholders instead, making the enumeration & propagation process cheaper, though maybe more manual.
- (Maybe hackily) uses a fake 1-rank 1d mesh to do single-dim propagation
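Step 2 of the flow above (enumerate candidate single-dim input placements, propagate each through the decomposition, keep only combos that need no intermediate redistribution) can be sketched as follows; `try_propagate` is a hypothetical stand-in for running the decomposition under the dispatch mode, not a real DTensor API:

```python
import itertools

def filter_decomp_strategies(per_input_candidates, try_propagate):
    """per_input_candidates: for each input, its candidate single-dim placements
    (derived from the actual input placements).
    try_propagate(combo): stand-in for shard prop through the decomposed graph;
    returns the output placement, or None if any intermediate op would need a
    redistribution (those combos are banned)."""
    valid = []
    for combo in itertools.product(*per_input_candidates):
        out = try_propagate(combo)
        if out is not None:
            valid.append((combo, out))
    return valid
```

The surviving `(inputs, output)` pairs are what then get expanded to the full N-D mesh strategy in step 3.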

Removes the following xfails (+some more aten ops with decomp coverage, but still failing tests):
```
__rsub__
addmv
addr
alias_copy
all
any
count_nonzero
dist
expand_copy
fill
floor_divide
index_select
linalg.vecdot
masked_fill
mv
nn.functional.celu
nn.functional.channel_shuffle
nn.functional.elu
nn.functional.hardsigmoid
nn.functional.hardswish
nn.functional.hardtanh
nn.functional.leaky_relu
nn.functional.logsigmoid
nn.functional.margin_ranking_loss
nn.functional.mish
nn.functional.multilabel_soft_margin_loss
nn.functional.pairwise_distance
nn.functional.pixel_shuffle
nn.functional.pixel_unshuffle
nn.functional.prelu
nn.functional.relu6
nn.functional.selu
nn.functional.softplus
nn.functional.softshrink
nn.functional.triplet_margin_loss
nn.functional.triplet_margin_with_distance_loss
permute_copy
rsub
t_copy
trace
vdot
view_copy
```
Pull Request resolved: #171652
Approved by: https://github.com/wconstab
Labels: ci-no-td · ciflow/inductor · ciflow/trunk · Merged · module: dtensor · oncall: distributed · release notes: distributed (dtensor) · Reverted
