Revert "[DTensor] Refactor strategy/rule registration into dedicated module (#168221)"#170615
Closed
wconstab wants to merge 6 commits intogh/wconstab/479/basefrom
Closed
Revert "[DTensor] Refactor strategy/rule registration into dedicated module (#168221)"#170615wconstab wants to merge 6 commits intogh/wconstab/479/basefrom
wconstab wants to merge 6 commits intogh/wconstab/479/basefrom
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/170615
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 1f25975 with merge base 1984725.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
wconstab added a commit that referenced this pull request on Dec 16, 2025
wconstab added a commit that referenced this pull request on Dec 16, 2025
wdvr approved these changes on Dec 16, 2025
malfet approved these changes on Dec 16, 2025
wconstab added a commit that referenced this pull request on Dec 17, 2025
wconstab added a commit that referenced this pull request on Dec 17, 2025
Collaborator
Starting merge as part of PR stack under #167677
3 similar comments
Collaborator
Starting merge as part of PR stack under #167677
pytorchmergebot pushed a commit that referenced this pull request on Dec 17, 2025
### Motivation

We find it is too difficult to implement sharding strategies for DTensor; this slows progress towards full operator coverage and increases the likelihood of bugs in sharding rules. We also expect to add new Placement types (e.g. more complete support for StridedShard in the near term, but possibly others as well), and the current formulation of sharding strategies does not scale to new placement types.

A primary reason it's so difficult to write sharding rules today is that they combine two things: (a) a mathematical description of which in/out shardings are correct for the op, and (b) premature runtime optimization to avoid considering too many combinations over an N-D mesh. The proposal is to remove (b) and focus just on (a).

### tl;dr

1. Write sharding-prop rules in terms of one mesh dim, using a 'placeholder' for generic sharding types. E.g. the matmul rule returns:

   ```
   [
     ShardPlaceholder(0), Replicate() -> ShardPlaceholder(0),
     Replicate(), ShardPlaceholder(1) -> ShardPlaceholder(1),
     ShardPlaceholder(1), ShardPlaceholder(0) -> Partial()
   ]
   ```

2. After registration, each rule is automatically expanded to include the real sharding types discovered in the inputs at runtime, plus a 'full replication' rule. E.g. if the inputs are fully replicated, we drop the placeholder rules and only use

   ```
   [
     Replicate(), Replicate() -> Replicate()
   ]
   ```

   If 'Shard' is discovered in the inputs, we fill the placeholders like

   ```
   [
     Shard(0), Replicate() -> Shard(0),
     Replicate(), Shard(1) -> Shard(1),
     Shard(1), Shard(0) -> Partial(),
     Replicate(), Replicate() -> Replicate()
   ]
   ```

   If 'Shard' and 'StridedShard' are both discovered in the inputs, we expand to

   ```
   [
     Shard(0), Replicate() -> Shard(0),
     Replicate(), Shard(1) -> Shard(1),
     Shard(1), Shard(0) -> Partial(),
     StridedShard(0), Replicate() -> StridedShard(0),
     Replicate(), StridedShard(1) -> StridedShard(1),
     StridedShard(1), StridedShard(0) -> Partial(),
     Replicate(), Replicate() -> Replicate()
   ]
   ```

3. After filling the placeholders, we expand to the N-D mesh and find the minimum cost:
   (a) full enumeration via itertools.product is implemented and gives exact parity with rules like 'einsum' today;
   (b) an optimized solution, starting from the input placements and iterating in order of increasing cost until reaching a min-cost solution _without_ having to fully enumerate, is under development/prototyping.

### This PR

* defines a 'single_dim strategy' function and a ShardingPlaceholder
* adds a util for expanding a single_dim strategy into a regular strategy
* supports StridedShard automatically via ShardingPlaceholder expansion
* writes rules for mm and cat and uses unit tests to validate their expansion

### Next Steps (PR stack)

* Support pointwise and foreach ops in the single_dim infra
* Hook up single-dim strategies to sharding_prop (op registration)
* Start to use single_dim rules to replace existing rules
* Improve the runtime of searching the fully expanded strategy
* Explore using decomps together with single-dim rules to support more operators

Pull Request resolved: #167677
Approved by: https://github.com/weifengpy
ghstack dependencies: #170615
Co-authored-by: Pian Pawakapan <pianpwk@meta.com>
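The commit message above describes the placeholder-expansion mechanics only in prose, so here is a minimal, self-contained Python sketch of the idea. The placement classes and helper names below (`ShardPlaceholder`, `expand_single_dim_rules`, `enumerate_nd`) are simplified stand-ins for illustration, not the actual DTensor API introduced by the PR.

```python
# Minimal sketch: simplified stand-ins for the placement types and a
# hypothetical expansion helper; not the real torch.distributed.tensor API.
from dataclasses import dataclass
from itertools import product


@dataclass(frozen=True)
class Replicate:
    pass


@dataclass(frozen=True)
class Partial:
    pass


@dataclass(frozen=True)
class Shard:
    dim: int


@dataclass(frozen=True)
class StridedShard:
    dim: int


@dataclass(frozen=True)
class ShardPlaceholder:
    dim: int


# Step 1: the single-dim mm rule, written once with placeholders.
# Each entry is ((lhs_placement, rhs_placement), out_placement).
MM_RULES = [
    ((ShardPlaceholder(0), Replicate()), ShardPlaceholder(0)),
    ((Replicate(), ShardPlaceholder(1)), ShardPlaceholder(1)),
    ((ShardPlaceholder(1), ShardPlaceholder(0)), Partial()),
]


def _fill(placement, shard_cls):
    """Substitute a concrete shard type for a placeholder; pass others through."""
    if isinstance(placement, ShardPlaceholder):
        return shard_cls(placement.dim)
    return placement


def expand_single_dim_rules(rules, shard_types_seen_in_inputs):
    """Step 2: expand the placeholder rules once per shard type discovered in
    the inputs, then append the always-valid full-replication rule."""
    expanded = []
    for shard_cls in shard_types_seen_in_inputs:
        for inputs, output in rules:
            expanded.append(
                (tuple(_fill(p, shard_cls) for p in inputs), _fill(output, shard_cls))
            )
    expanded.append(((Replicate(), Replicate()), Replicate()))
    return expanded


def enumerate_nd(expanded_rules, mesh_ndim):
    """Step 3(a): full enumeration over an N-D mesh; one single-dim rule is
    chosen independently for every mesh dimension."""
    return list(product(expanded_rules, repeat=mesh_ndim))


if __name__ == "__main__":
    rules = expand_single_dim_rules(MM_RULES, [Shard, StridedShard])
    print(len(rules))                   # 7 rules: 3 Shard + 3 StridedShard + replication
    print(len(enumerate_nd(rules, 2)))  # 49 candidate strategies on a 2-D mesh
```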
This was referenced Dec 17, 2025
Contributor
@pytorchbot revert -m "Required to revert #170030" -c ghfirst
Contributor
@wconstab QQ: what is the use case for landing reverts as manual ghstack PRs rather than using the pytorchbot? Is there a particular workflow we don't fully support with pytorchbot commands?
weifengpy pushed a commit that referenced this pull request on Dec 19, 2025
pytorchmergebot pushed a commit that referenced this pull request on Dec 19, 2025
(same commit message as #167677 above)
pytorchmergebot pushed a commit that referenced this pull request on Dec 19, 2025
This PR adds the register_single_dim_strategy util and hooks it up to sharding_propagator. It also tests the registration.

Notes:
* I haven't yet decided how multiple registrations should be handled. I was planning to make it an error to register twice for the same op, for either single_dim or regular strategies.
* I took the cleanest path of integration for now in sharding_prop, reusing as much code as possible from the existing 'op_strategy' case. I may have to change this later when integrating find_min_cost.

Pull Request resolved: #170359
Approved by: https://github.com/weifengpy
ghstack dependencies: #170615, #167677
Co-authored-by: Pian Pawakapan <pianpwk@meta.com>
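Since that note only describes the registration work at a high level, here is a rough, hypothetical sketch of an error-on-double-registration decorator in the spirit described. The decorator name comes from the PR, but its signature, the op-key format, and the returned rule shape are assumptions, not the actual sharding_propagator integration.

```python
# Hypothetical sketch of a registry that rejects double registration, roughly
# matching the note above. Names and signatures are assumptions.
_SINGLE_DIM_STRATEGIES: dict[str, object] = {}


def register_single_dim_strategy(op_key: str):
    """Register a per-mesh-dim rule function for an op; error if the op
    already has a single-dim strategy registered."""
    def wrapper(rule_fn):
        if op_key in _SINGLE_DIM_STRATEGIES:
            raise RuntimeError(
                f"a single-dim strategy is already registered for {op_key}"
            )
        _SINGLE_DIM_STRATEGIES[op_key] = rule_fn
        return rule_fn
    return wrapper


@register_single_dim_strategy("aten.mm.default")
def mm_single_dim_rules():
    # Placeholder rules in the (inputs -> output) form shown earlier; the
    # sharding propagator would expand these per mesh dim at dispatch time.
    return [
        ("ShardPlaceholder(0), Replicate()", "ShardPlaceholder(0)"),
        ("Replicate(), ShardPlaceholder(1)", "ShardPlaceholder(1)"),
        ("ShardPlaceholder(1), ShardPlaceholder(0)", "Partial()"),
    ]
```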
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request on Dec 19, 2025
This reverts commit 32d0782. Reverted pytorch#170359 on behalf of https://github.com/jeanschmidt due to Required to revert pytorch#167677 that is required to revert pytorch#170615 that is required to revert pytorch#170030 ([comment](pytorch#170359 (comment)))
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request on Dec 19, 2025
This reverts commit c3e628e. Reverted pytorch#167677 on behalf of https://github.com/jeanschmidt due to Required to revert pytorch#170615 that is required to revert pytorch#170030 ([comment](pytorch#167677 (comment)))
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request on Dec 19, 2025
… module (pytorch#168221)" (pytorch#170615) This reverts commit c65f67b. Reverted pytorch#170615 on behalf of https://github.com/jeanschmidt due to Required to revert pytorch#170030 ([comment](pytorch#170615 (comment)))
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request on Dec 19, 2025
…module (pytorch#168221)" (pytorch#170615)

This reverts commit cb3754f. Reverting this change as it affects the import path of a publicly used API.

Pull Request resolved: pytorch#170615
Approved by: https://github.com/wdvr, https://github.com/malfet
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request on Dec 19, 2025
(same commit message as #167677 above)
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request on Dec 19, 2025
(same commit message as #170359 above)
pytorchmergebot pushed a commit that referenced this pull request on Dec 19, 2025
Enforce that tensor_meta is not None for new single-dim rules. Allow tensor_meta to continue to be None for existing rules for now. We should consider asserting in the future that tensor_meta is required in DTensorSpec, but for now we just try to limit the bleeding.

Pull Request resolved: #170827
Approved by: https://github.com/dolpm
ghstack dependencies: #170615, #167677, #170359
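As a loose illustration of the guard that commit message describes, a check of this shape would reject missing metadata only on the new single-dim path. The spec class and field names here are simplified stand-ins, not the real DTensorSpec API.

```python
# Loose illustration of a tensor_meta guard for the new single-dim path only;
# class and field names are simplified stand-ins for the real DTensorSpec.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SpecLike:
    placements: tuple
    tensor_meta: Optional[Any] = None  # shape/stride/dtype metadata, or None


def check_tensor_meta_for_single_dim_rule(op_name: str, input_specs: list) -> None:
    """New single-dim rules require tensor_meta on every input spec; legacy
    rules are left unchecked, mirroring the 'limit the bleeding' note above."""
    for i, spec in enumerate(input_specs):
        if spec.tensor_meta is None:
            raise AssertionError(
                f"{op_name}: input spec {i} is missing tensor_meta, "
                "which single-dim strategies require"
            )
```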
xgz2 pushed a commit that referenced this pull request on Dec 22, 2025
(same commit message as the "This reverts commit 32d0782" entry above)
xgz2 pushed a commit that referenced this pull request on Dec 22, 2025
(same commit message as the "This reverts commit c3e628e" entry above)
xgz2 pushed a commit that referenced this pull request on Dec 22, 2025
(same commit message as the "This reverts commit c65f67b" entry above)
xgz2 pushed a commit that referenced this pull request on Dec 22, 2025
(same commit message as the "This reverts commit cb3754f" entry above)
xgz2 pushed a commit that referenced this pull request on Dec 22, 2025
(same commit message as #167677 above)
xgz2 pushed a commit that referenced this pull request on Dec 22, 2025
(same commit message as #170359 above)
xgz2 pushed a commit that referenced this pull request on Dec 22, 2025
(same commit message as #170827 above)
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request on Jan 9, 2026
(same commit message as the "This reverts commit cb3754f" entry above)
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request on Jan 9, 2026
(same commit message as #167677 above)
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request on Jan 9, 2026
(same commit message as the "This reverts commit 32d0782" entry above)
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request on Jan 9, 2026
(same commit message as the "This reverts commit c3e628e" entry above)
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request on Jan 9, 2026
(same commit message as the "This reverts commit c65f67b" entry above)
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request on Jan 9, 2026
(same commit message as the "This reverts commit cb3754f" entry above)
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request on Jan 9, 2026
(same commit message as #167677 above)
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request on Jan 9, 2026
(same commit message as #170359 above)
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request on Jan 9, 2026
(same commit message as #170827 above)
SergeyTyshkevich pushed a commit to SergeyTyshkevich/chart2 that referenced this pull request on Jan 19, 2026
…module (#168221)"

This reverts commit cb3754f. Reverting this change as it affects the import path of a publicly used API.

ghstack-source-id: b787988
Pull Request resolved: pytorch/pytorch#170615
Stack from ghstack (oldest at bottom):
This reverts commit cb3754f.
Reverting this change as it affects the import path of a publicly used API.