
Revert "[DTensor] Refactor strategy/rule registration into dedicated module (#168221)"#170615

Closed
wconstab wants to merge 6 commits intogh/wconstab/479/basefrom
gh/wconstab/479/head
Closed

Revert "[DTensor] Refactor strategy/rule registration into dedicated module (#168221)"#170615
wconstab wants to merge 6 commits intogh/wconstab/479/basefrom
gh/wconstab/479/head

Conversation

Contributor

@wconstab wconstab commented Dec 16, 2025

…module (#168221)"

This reverts commit cb3754f.

Reverting this change as it affects the import path of a publicly
used API.

[ghstack-poisoned]

pytorch-bot Bot commented Dec 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/170615

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1f25975 with merge base 1984725:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

wconstab added a commit that referenced this pull request Dec 16, 2025
…module (#168221)"

This reverts commit cb3754f.

Reverting this change as it affects the import path of a publicly
used API.

ghstack-source-id: d68ea51
Pull Request resolved: #170615
pytorch-bot added the ci-no-td (Do not run TD on this PR) and ciflow/inductor labels Dec 16, 2025
… dedicated module (#168221)""

This reverts commit cb3754f.

Reverting this change as it affects the import path of a publicly
used API.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Dec 16, 2025
…module (#168221)"

This reverts commit cb3754f.

Reverting this change as it affects the import path of a publicly
used API.

ghstack-source-id: 8c00738
Pull Request resolved: #170615
wconstab added the release notes: distributed (dtensor) label Dec 16, 2025
… dedicated module (#168221)""

This reverts commit cb3754f.

Reverting this change as it affects the import path of a publicly
used API.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Dec 17, 2025
…module (#168221)"

This reverts commit cb3754f.

Reverting this change as it affects the import path of a publicly
used API.

ghstack-source-id: 29c372c
Pull Request resolved: #170615
… dedicated module (#168221)""

This reverts commit cb3754f.

Reverting this change as it affects the import path of a publicly
used API.

[ghstack-poisoned]
wconstab added a commit that referenced this pull request Dec 17, 2025
…module (#168221)"

This reverts commit cb3754f.

Reverting this change as it affects the import path of a publicly
used API.

ghstack-source-id: 97ee16a
Pull Request resolved: #170615
@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #167677

3 similar comments

… dedicated module (#168221)""

This reverts commit cb3754f.

Reverting this change as it affects the import path of a publicly
used API.

[ghstack-poisoned]
@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #167677

pytorchmergebot pushed a commit that referenced this pull request Dec 17, 2025
### Motivation
Implementing sharding strategies for DTensor is currently too difficult; this slows progress toward full operator coverage and increases the likelihood of bugs in sharding rules. We also expect to add new Placement types (e.g. more complete support for StridedShard in the near term, and possibly others later), and the current formulation of sharding strategies does not scale to new placement types.

A primary reason it is so difficult to write sharding rules today is that they combine two things: (a) a mathematical description of which input/output shardings are correct for the op, and (b) premature runtime optimization to avoid considering too many combinations over an N-D mesh.

The proposal is to remove (b) and focus just on (a).

### tl;dr
1. Write sharding-prop rules in terms of a single mesh dim, using a 'placeholder' for generic sharding types,
e.g. the matmul rule returns:
```
[
   ShardPlaceholder(0), Replicate() -> ShardPlaceholder(0),
   Replicate(), ShardPlaceholder(1) -> ShardPlaceholder(1),
   ShardPlaceholder(1), ShardPlaceholder(0) -> Partial()
]
```
2. After registration, each rule is automatically expanded at runtime to cover the real sharding types discovered in the inputs, and a 'full replication' rule is added.
e.g. if the inputs are fully replicated, we drop the placeholder rules and use only
```
[
   Replicate(), Replicate() -> Replicate()
]
```
if 'Shard' is discovered in the inputs, we fill the placeholders like
```
[
   Shard(0), Replicate() -> Shard(0),
   Replicate(), Shard(1) -> Shard(1),
   Shard(1), Shard(0) -> Partial(),
   Replicate(), Replicate() -> Replicate()
]
```
if 'Shard' and 'StridedShard' are both discovered in the inputs, we expand to
```
[
   Shard(0), Replicate() -> Shard(0),
   Replicate(), Shard(1) -> Shard(1),
   Shard(1), Shard(0) -> Partial(),
   StridedShard(0), Replicate() -> StridedShard(0),
   Replicate(), StridedShard(1) -> StridedShard(1),
   StridedShard(1), StridedShard(0) -> Partial(),
   Replicate(), Replicate() -> Replicate()
]
```
3. After filling the placeholders, we expand to the N-D mesh and find the minimum-cost strategy (a toy sketch of this flow follows the list):
(a) full enumeration via itertools.product is implemented and gives exact parity with rules like 'einsum' today
(b) an optimized search, starting from the input placements and iterating in order of increasing cost until a min-cost solution is reached _without_ full enumeration - under development/prototyping
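
To make the flow above concrete, here is a minimal, self-contained sketch. It is illustrative only: the toy placement classes and the helper names (`SingleDimRule`, `fill_placeholders`, `expand_to_mesh`) are stand-ins for this writeup, not the actual DTensor APIs added by this stack.

```python
# Toy sketch of single-dim rules + expansion; not the real torch.distributed.tensor code.
import itertools
from dataclasses import dataclass


@dataclass(frozen=True)
class Replicate:
    pass


@dataclass(frozen=True)
class Partial:
    pass


@dataclass(frozen=True)
class Shard:
    dim: int


@dataclass(frozen=True)
class StridedShard:
    dim: int


@dataclass(frozen=True)
class ShardPlaceholder:
    # "some shard-like placement" along tensor dim `dim`; filled in at runtime
    dim: int


@dataclass(frozen=True)
class SingleDimRule:
    # (input placements) -> output placement, over a single mesh dim
    inputs: tuple
    output: object


def mm_single_dim_rules():
    """The matmul rule from step 1, written over one mesh dim."""
    P = ShardPlaceholder
    return [
        SingleDimRule((P(0), Replicate()), P(0)),
        SingleDimRule((Replicate(), P(1)), P(1)),
        SingleDimRule((P(1), P(0)), Partial()),
    ]


def fill_placeholders(rules, shard_types):
    """Step 2: instantiate placeholders once per shard type discovered in the inputs,
    then append the always-valid full-replication rule."""
    def subst(p, cls):
        return cls(p.dim) if isinstance(p, ShardPlaceholder) else p

    filled = [
        SingleDimRule(tuple(subst(i, cls) for i in r.inputs), subst(r.output, cls))
        for cls in shard_types
        for r in rules
    ]
    n_inputs = len(rules[0].inputs)
    filled.append(SingleDimRule((Replicate(),) * n_inputs, Replicate()))
    return filled


def expand_to_mesh(filled_rules, mesh_ndim):
    """Step 3(a): full enumeration - one single-dim rule is chosen per mesh dim."""
    return list(itertools.product(filled_rules, repeat=mesh_ndim))


if __name__ == "__main__":
    # Both Shard and StridedShard were discovered in the inputs.
    filled = fill_placeholders(mm_single_dim_rules(), shard_types=[Shard, StridedShard])
    candidates = expand_to_mesh(filled, mesh_ndim=2)
    print(len(filled), "single-dim options ->", len(candidates), "candidates on a 2-D mesh")
```

A real implementation would additionally attach a redistribution cost to each N-D candidate and pick the minimum, per step 3(b).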

### This PR
* defines a 'single_dim strategy' function and a ShardingPlaceholder
* adds a util for expanding a single_dim strategy into a regular strategy
* supports StridedShard automatically via ShardingPlaceholder expansion
* writes rules for mm and cat and uses unit tests to validate their expansion

### Next Steps (PR stack)
* Support pointwise and foreach ops in the single_dim infra
* Hook up single-dim strategies to sharding_prop (op registration)
* Start to use single_dim rules to replace existing rules
* Improve the runtime of searching the fully expanded strategy
* Explore using decomps together with single-dim rules to support more operators

Pull Request resolved: #167677
Approved by: https://github.com/weifengpy
ghstack dependencies: #170615

Co-authored-by: Pian Pawakapan <pianpwk@meta.com>
@jeanschmidt
Contributor

@pytorchbot revert -m "Required to revert #170030" -c ghfirst

@jeanschmidt
Contributor

@wconstab QQ: what is the use case for landing reverts as manual ghstack PRs instead of using the pytorchbot? Is there a particular workflow we don't fully support with pytorchbot commands?

weifengpy pushed a commit that referenced this pull request Dec 19, 2025
…module (#168221)"

This reverts commit cb3754f.

Reverting this change as it affects the import path of a publicly
used API.

ghstack-source-id: b787988
Pull Request resolved: #170615
pytorchmergebot pushed a commit that referenced this pull request Dec 19, 2025
pytorchmergebot pushed a commit that referenced this pull request Dec 19, 2025
This PR adds the register_single_dim_strategy util and hooks it up to sharding_propagator. It also tests the registration.

Notes:
* I haven't yet decided how multiple registrations should be handled. I plan to make it an error to register twice for the same op, for either single_dim or regular strategies.
* I took the cleanest integration path in sharding_prop for now, reusing as much code as possible from the existing 'op_strategy' case. I may have to change this later when integrating find_min_cost.
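
For a rough flavor of what the registration could look like, here is a hypothetical sketch; the registry, decorator signature, and duplicate-registration error shown are assumptions based on the notes above, not the code in this PR:

```python
# Hypothetical registry sketch; not the actual sharding_propagator implementation.
from typing import Callable

_single_dim_strategies: dict[str, Callable] = {}


def register_single_dim_strategy(op_name: str):
    def decorator(fn: Callable) -> Callable:
        # Per the note above, registering twice for the same op would be an error.
        if op_name in _single_dim_strategies:
            raise RuntimeError(f"single-dim strategy already registered for {op_name}")
        _single_dim_strategies[op_name] = fn
        return fn
    return decorator


@register_single_dim_strategy("aten::mm")
def mm_single_dim_strategy():
    # Would return the per-mesh-dim rule list, as in the sketch earlier in this thread.
    ...
```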

Pull Request resolved: #170359
Approved by: https://github.com/weifengpy
ghstack dependencies: #170615, #167677

Co-authored-by: Pian Pawakapan <pianpwk@meta.com>
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Dec 19, 2025
This reverts commit 32d0782.

Reverted pytorch#170359 on behalf of https://github.com/jeanschmidt due to Required to revert pytorch#167677 that is required to revert pytorch#170615 that is required to revert pytorch#170030 ([comment](pytorch#170359 (comment)))
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Dec 19, 2025
This reverts commit c3e628e.

Reverted pytorch#167677 on behalf of https://github.com/jeanschmidt due to Required to revert pytorch#170615 that is required to revert pytorch#170030 ([comment](pytorch#167677 (comment)))
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Dec 19, 2025
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Dec 19, 2025
…module (pytorch#168221)" (pytorch#170615)

This reverts commit cb3754f.

Reverting this change as it affects the import path of a publicly
used API.
Pull Request resolved: pytorch#170615
Approved by: https://github.com/wdvr, https://github.com/malfet
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Dec 19, 2025
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Dec 19, 2025
pytorchmergebot pushed a commit that referenced this pull request Dec 19, 2025
Enforce that tensor_meta is not None for new single-dim rules.

Allow tensor_meta to remain None for existing rules for now. We should consider asserting that tensor_meta is required in DTensorSpec in the future, but for now we just limit the bleeding.
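
The enforcement described amounts to a guard along these lines (a sketch only; the function name is made up, and `tensor_meta` is the DTensorSpec field named above):

```python
# Illustrative only - not the actual DTensorSpec/sharding_prop code.
def _require_tensor_meta_for_single_dim(spec) -> None:
    # New single-dim rules require tensor_meta; existing rules may still pass None for now.
    if spec.tensor_meta is None:
        raise AssertionError(
            "single-dim sharding rules require DTensorSpec.tensor_meta to be set"
        )
```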
Pull Request resolved: #170827
Approved by: https://github.com/dolpm
ghstack dependencies: #170615, #167677, #170359
xgz2 pushed a commit that referenced this pull request Dec 22, 2025
xgz2 pushed a commit that referenced this pull request Dec 22, 2025
xgz2 pushed a commit that referenced this pull request Dec 22, 2025
… module (#168221)" (#170615)

This reverts commit c65f67b.

Reverted #170615 on behalf of https://github.com/jeanschmidt due to Required to revert #170030 ([comment](#170615 (comment)))
xgz2 pushed a commit that referenced this pull request Dec 22, 2025
…module (#168221)" (#170615)

This reverts commit cb3754f.

Reverting this change as it affects the import path of a publicly
used API.
Pull Request resolved: #170615
Approved by: https://github.com/wdvr, https://github.com/malfet
xgz2 pushed a commit that referenced this pull request Dec 22, 2025
xgz2 pushed a commit that referenced this pull request Dec 22, 2025
xgz2 pushed a commit that referenced this pull request Dec 22, 2025
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026
…module (pytorch#168221)" (pytorch#170615)

This reverts commit cb3754f.

Reverting this change as it affects the import path of a publicly
used API.
Pull Request resolved: pytorch#170615
Approved by: https://github.com/wdvr, https://github.com/malfet
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026
…module (pytorch#168221)" (pytorch#170615)

This reverts commit cb3754f.

Reverting this change as it affects the import path of a publicly
used API.
Pull Request resolved: pytorch#170615
Approved by: https://github.com/wdvr, https://github.com/malfet
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026
github-actions bot deleted the gh/wconstab/479/head branch January 18, 2026 02:21