[DTensor] Single Dim Strategy infra #167677
wconstab wants to merge 28 commits into gh/wconstab/456/base from
Conversation
WIP for now, not ready to land.

Experimenting with the idea of decomposing op strategies into
- a core function that proposes a single-mesh-dim strategy
- automatic expansion to the mesh inside the registration mechanism

Also, plan to add a 'ShardingPlaceholder' to use for writing single-mesh-dim strategies in a way that can be expanded at runtime to any type of sharding discovered in the inputs.

For now, this relies on full enumeration of the single-dim strategy onto the full mesh, and full enumeration of the combinations of different sharding placements discovered in the input, but we should be able to replace this with an algorithm that expands iteratively following the path of lowest redistribution cost.
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167677

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending as of commit cbef950 with merge base 1984725.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
One of the things that seemed important to me when doing this work is to make sure the testing is actually enough to detect problems. @tianyu-l also had the idea of fuzzing through input placements to try to "discover" whether sharding rules are complete or not.
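The placement-fuzzing idea could be sketched as follows. This is a hypothetical toy model (placements as plain strings, a hand-written rule table for a two-input op on a single mesh dim), not the actual DTensor types:

```python
# Toy sketch of fuzzing input placements to check rule coverage.
# Placements are modeled as strings ("R" = Replicate, "S(d)" = Shard(d),
# "P" = Partial) rather than real DTensor Placement objects.
from itertools import product

PLACEMENTS = ["R", "S(0)", "S(1)", "P"]

# Toy mm-like rule table: (lhs, rhs) -> output, plus a replicate fallback.
RULES = {
    ("S(0)", "R"): "S(0)",
    ("R", "S(1)"): "S(1)",
    ("S(1)", "S(0)"): "P",
    ("R", "R"): "R",
}

def fuzz_coverage(rules):
    """Return the input combinations the rule table cannot serve directly
    (i.e. combinations that would require redistribution first)."""
    uncovered = []
    for combo in product(PLACEMENTS, repeat=2):
        if combo not in rules:
            uncovered.append(combo)
    return uncovered

missing = fuzz_coverage(RULES)
# 16 total combinations, 4 covered -> 12 need redistribution.
```

A real fuzzer would additionally run the op under each covered combination and compare against the single-device result, to detect rules that are present but wrong.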
    to cover all the mesh dims present in the runtime inputs.
    """

def expanded_registration_wrapper(
This function seems to overlap largely with expand_to_full_mesh_op_strategy, but it sounds like you are tackling a different / harder problem?
In general, I found the PR summary a bit hard to follow (likely because I missed some meetings / discussions); do we have an RFC or meeting notes I could follow?
I added some details to the PR description. There isn't an actual RFC, but this can serve as a mini-RFC for now.
> This function seems to overlap largely with expand_to_full_mesh_op_strategy, but it sounds like you are tackling a different / harder problem?
I have addressed this. I was able to use the original expand helper.
**This diff stack is a prototype/experiment.**

### Motivation
We find it is too difficult to implement sharding strategies for DTensor; this slows progress towards full operator coverage and increases the likelihood of bugs in sharding rules. We also expect to add new Placement types (e.g. more complete support for StridedShard in the near term, but possibly others as well), and the current formulation of sharding strategies does not scale to adding new placement types.

### Specific Goals
1. [**not prototyped**] Op Strategies should automatically support StridedShard and any other 'purely sharding' placement type without additional work. Use a 'sharding placeholder' to declare sharding rules (e.g. `S(0), S(0) -> S(0)` applies equally well to `Shard(0)` or `StridedShard(0, sf=2)`); we can make this substitution outside of the per-op rule.
   a. Constraints such as 'must evenly shard' can be moved onto the variable, and filtered against real tensor/mesh dims at expansion time.
2. [**prototyped**] Focus Op Strategies on a single mesh dim, and rely on standard infra to expand to the full mesh.
3. [**prototyped**] Create the possible input-output sharding specs without looking at input args and doing premature filtering, which trades code complexity for runtime optimization. Try to regain the runtime savings in another common infra layer during 'expansion' to the mesh.

### Experiment Approach
This stack implements some of the above goals, and then implements some operators (so far, mm and pointwise). I'm working on figuring out what the right tests and metrics to collect are before going further. If you have concerns about whether this approach would work for a particular op, I could try that.

* One major difference in this proposal is that the 'premature optimization' in the strategy is removed. How much does this matter? Can we recover some of the time by doing filtering during mesh expansion, or by changing to expand only to some 'min cost' set of states around the input placements?
* How can we test if the new strategy is correct? Can we tell if we're missing a strategy or if we're issuing a wrong strategy?
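As a rough illustration of goal 1, here is a minimal sketch of placeholder substitution outside the per-op rule. `ShardPlaceholder`, the string placements, and `fill_placeholders` are toy stand-ins, not the actual DTensor API:

```python
# Toy sketch: a single-mesh-dim mm rule written once with a sharding
# placeholder, instantiated per concrete shard type found in the inputs.
from dataclasses import dataclass

@dataclass(frozen=True)
class ShardPlaceholder:
    dim: int

# Single-mesh-dim mm rule: each entry is (lhs, rhs, out).
MM_RULE = [
    (ShardPlaceholder(0), "Replicate", ShardPlaceholder(0)),
    ("Replicate", ShardPlaceholder(1), ShardPlaceholder(1)),
    (ShardPlaceholder(1), ShardPlaceholder(0), "Partial"),
]

def fill_placeholders(rule, shard_types):
    """Instantiate the placeholder entries once per discovered shard type,
    then append the always-valid full-replication entry."""
    out = []
    for shard in shard_types:
        for entry in rule:
            out.append(tuple(
                f"{shard}({p.dim})" if isinstance(p, ShardPlaceholder) else p
                for p in entry
            ))
    out.append(("Replicate", "Replicate", "Replicate"))
    return out

# If both Shard and StridedShard appear in the inputs:
strategies = fill_placeholders(MM_RULE, ["Shard", "StridedShard"])
# 3 rule entries x 2 shard types + 1 replicate fallback = 7 strategies
```

The key property is that the per-op rule never names a concrete shard type, so a new 'purely sharding' placement works without touching the rule.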
**This diff stack is a prototype/experiment.**

### Motivation
We find it is too difficult to implement sharding strategies for DTensor; this slows progress towards full operator coverage and increases the likelihood of bugs in sharding rules. We also expect to add new Placement types (e.g. more complete support for StridedShard in the near term, but possibly others as well), and the current formulation of sharding strategies does not scale to adding new placement types.

### Specific Goals
1. [**prototyped**] Op Strategies should automatically support StridedShard and any other 'purely sharding' placement type without additional work. Use a 'sharding placeholder' to declare sharding rules (e.g. `S(0), S(0) -> S(0)` applies equally well to `Shard(0)` or `StridedShard(0, sf=2)`); we can make this substitution outside of the per-op rule.
   a. Constraints such as 'must evenly shard' can be moved onto the variable, and filtered against real tensor/mesh dims at expansion time.
2. [**prototyped**] Focus Op Strategies on a single mesh dim, and rely on standard infra to expand to the full mesh.
3. [**prototyped**] Create the possible input-output sharding specs without looking at input args and doing premature filtering, which trades code complexity for runtime optimization. Try to regain the runtime savings in another common infra layer during 'expansion' to the mesh.
4. [**prototyped**] Utilize a min-redistribution-cost priority queue to guide the expansion of the strategy, and stop expansion after finding the first (lowest cost) redistribution of the inputs that matches an expansion of the strategy.

### Experiment Approach
This stack implements some of the above goals, and then implements some operators (so far, mm and pointwise). I'm working on figuring out what the right tests and metrics to collect are before going further. If you have concerns about whether this approach would work for a particular op, I could try that.

* One major difference in this proposal is that the 'premature optimization' in the strategy is removed. How much does this matter? Can we recover some of the time by doing filtering during mesh expansion, or by changing to expand only to some 'min cost' set of states around the input placements? -> Some preliminary benchmark results: https://gist.github.com/wconstab/5df63815696c504db2ffadfbf1675d21
* How can we test if the new strategy is correct? Can we tell if we're missing a strategy or if we're issuing a wrong strategy?
Starting merge as part of PR stack under #170359
This PR adds the register_single_dim_strategy util, and hooks it up to sharding_propagator. It also tests the registration.

Notes:
* I didn't yet decide how multiple registrations should be handled. I was planning to make it an error if you register twice for the same op, for either single_dim or regular strategies.
* I took the cleanest path of integration for now in sharding_prop, reusing as much code as possible with the existing 'op_strategy' case. I may have to change this later when integrating find_min_cost.

Pull Request resolved: #170359
Approved by: https://github.com/weifengpy
ghstack dependencies: #170615, #167677
Co-authored-by: Pian Pawakapan <pianpwk@meta.com>
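The 'error on double registration' policy described in the notes can be modeled in a few lines. This is a toy registry for illustration only, not the actual sharding_propagator code:

```python
# Toy model of a strategy registry that rejects a second registration for
# the same op, whether single-dim or regular.
class StrategyRegistry:
    def __init__(self):
        self._single_dim = {}
        self._regular = {}

    def _check_unregistered(self, op):
        if op in self._single_dim or op in self._regular:
            raise RuntimeError(f"strategy already registered for {op}")

    def register_single_dim(self, op, fn):
        self._check_unregistered(op)
        self._single_dim[op] = fn

    def register_regular(self, op, fn):
        self._check_unregistered(op)
        self._regular[op] = fn

reg = StrategyRegistry()
reg.register_single_dim("aten.mm.default", lambda: "mm rule")
try:
    # A second registration for the same op, of either kind, is an error.
    reg.register_regular("aten.mm.default", lambda: "other rule")
    duplicate_ok = True
except RuntimeError:
    duplicate_ok = False
```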
This reverts commit 32d0782. Reverted pytorch#170359 on behalf of https://github.com/jeanschmidt due to Required to revert pytorch#167677 that is required to revert pytorch#170615 that is required to revert pytorch#170030 ([comment](pytorch#170359 (comment)))
This reverts commit c3e628e. Reverted pytorch#167677 on behalf of https://github.com/jeanschmidt due to Required to revert pytorch#170615 that is required to revert pytorch#170030 ([comment](pytorch#167677 (comment)))
### Motivation
We find it is too difficult to implement sharding strategies for DTensor; this slows progress towards full operator coverage and increases the likelihood of bugs in sharding rules. We also expect to add new Placement types (e.g. more complete support for StridedShard in the near term, but possibly others as well), and the current formulation of sharding strategies is not scalable to adding new placement types.

A primary reason it's so difficult to write sharding rules today is that they combine 2 things: (a) a mathematical description of which in/out shardings are correct for the op, (b) premature runtime optimization to avoid considering too many combinations over an N-D mesh. The proposal is to remove (b) and focus just on (a).

### tl;dr
1. Write sharding prop rules in terms of 1 mesh dim, using a 'placeholder' for generic sharding types. e.g. the matmul rule returns:
```
[
  ShardPlaceholder(0), Replicate()         -> ShardPlaceholder(0),
  Replicate(),         ShardPlaceholder(1) -> ShardPlaceholder(1),
  ShardPlaceholder(1), ShardPlaceholder(0) -> Partial()
]
```
2. After registration, each rule gets automatically expanded to include the real sharding types discovered in the inputs at runtime, plus a 'full replication' rule. e.g. if the inputs are fully replicated, we drop the placeholder rules and only use
```
[
  Replicate(), Replicate() -> Replicate()
]
```
If 'Shard' is discovered in the inputs, we fill placeholders like
```
[
  Shard(0),    Replicate() -> Shard(0),
  Replicate(), Shard(1)    -> Shard(1),
  Shard(1),    Shard(0)    -> Partial(),
  Replicate(), Replicate() -> Replicate()
]
```
If 'Shard' and 'StridedShard' are both discovered in the inputs, we expand to
```
[
  Shard(0),        Replicate()     -> Shard(0),
  Replicate(),     Shard(1)        -> Shard(1),
  Shard(1),        Shard(0)        -> Partial(),
  StridedShard(0), Replicate()     -> StridedShard(0),
  Replicate(),     StridedShard(1) -> StridedShard(1),
  StridedShard(1), StridedShard(0) -> Partial(),
  Replicate(),     Replicate()     -> Replicate()
]
```
3. After filling the placeholders, we expand to the N-D mesh and find the minimum cost:
   (a) full enumeration via itertools.product is implemented and gives exact parity with rules like 'einsum' today
   (b) an optimized solution, starting from the input placements and iterating in order of increasing cost until reaching a min-cost solution _without_ having to fully enumerate, is under development/prototyping

### This PR
* defines a 'single_dim strategy' function and a ShardingPlaceholder
* adds a util for expanding a single_dim strategy into a regular strategy
* supports StridedShard automatically via ShardingPlaceholder expansion
* writes rules for mm and cat and uses unit tests to validate their expansion

### Next Steps (PR stack)
* Support pointwise and foreach ops in the single_dim infra
* Hook up single-dim strategies to sharding_prop (op registration)
* Start to use single_dim rules to replace existing rules
* Improve the runtime of searching the fully expanded strategy
* Explore using decomps together with single-dim rules to support more operators

Pull Request resolved: pytorch#167677
Approved by: https://github.com/weifengpy
ghstack dependencies: pytorch#170615
Co-authored-by: Pian Pawakapan <pianpwk@meta.com>
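Step 3(a)'s full enumeration via itertools.product can be sketched like this; the single-dim rule table uses toy string placements rather than real Placement objects:

```python
# Toy sketch: expand a single-mesh-dim strategy to an N-D mesh by taking
# the product of the single-dim entries, one choice per mesh dim, then
# transposing so each tensor carries a tuple of per-mesh-dim placements.
from itertools import product

SINGLE_DIM = [  # (lhs, rhs, out) for one mesh dim
    ("S(0)", "R", "S(0)"),
    ("R", "S(1)", "S(1)"),
    ("S(1)", "S(0)", "P"),
    ("R", "R", "R"),
]

def expand_to_mesh(single_dim, mesh_ndim):
    full = []
    for combo in product(single_dim, repeat=mesh_ndim):
        # combo[d] is the entry chosen for mesh dim d; zip(*combo)
        # transposes to per-tensor tuples of placements.
        lhs, rhs, out = zip(*combo)
        full.append((lhs, rhs, out))
    return full

two_d = expand_to_mesh(SINGLE_DIM, 2)
# 4 single-dim entries -> 4**2 = 16 full-mesh strategies on a 2-D mesh,
# which is the exponential blowup that step 3(b) aims to avoid.
```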
Enforce that tensor_meta is not None for new single-dim rules. Allow tensor_meta to continue to be None for existing rules for now. We should consider requiring tensor_meta in DTensorSpec in the future, but for now we just try to limit the bleeding.

Pull Request resolved: #170827
Approved by: https://github.com/dolpm
ghstack dependencies: #170615, #167677, #170359
This reverts commit 32d0782. Reverted #170359 on behalf of https://github.com/jeanschmidt due to Required to revert #167677 that is required to revert #170615 that is required to revert #170030 ([comment](#170359 (comment)))
This reverts commit c3e628e. Reverted #167677 on behalf of https://github.com/jeanschmidt due to Required to revert #170615 that is required to revert #170030 ([comment](#167677 (comment)))
### Motivation

We find it is too difficult to implement sharding strategies for DTensor, which slows progress towards full operator coverage and increases the likelihood of bugs in sharding rules. We also expect to add new Placement types (e.g. more complete support for StridedShard in the near term, but possibly others as well), and the current formulation of sharding strategies is not scalable to adding new placement types.

A primary reason it's so difficult to write sharding rules today is that they combine two things: (a) a mathematical description of which in/out shardings are correct for the op, and (b) premature runtime optimization to avoid considering too many combinations over an N-D mesh. The proposal is to remove (b) and focus just on (a).

### tl;dr

1. Write sharding prop rules in terms of one mesh dim, using a 'placeholder' for generic sharding types. e.g. the matmul rule returns:
```
[
  ShardPlaceholder(0), Replicate()         -> ShardPlaceholder(0),
  Replicate(),         ShardPlaceholder(1) -> ShardPlaceholder(1),
  ShardPlaceholder(1), ShardPlaceholder(0) -> Partial()
]
```
2. After registration, each rule gets automatically expanded to use the real sharding types discovered in the inputs at runtime, plus a 'full replication' rule. e.g. if the inputs are fully replicated, we drop the placeholder rules and only use
```
[
  Replicate(), Replicate() -> Replicate()
]
```
if 'Shard' is discovered in the inputs, we fill placeholders like
```
[
  Shard(0),    Replicate() -> Shard(0),
  Replicate(), Shard(1)    -> Shard(1),
  Shard(1),    Shard(0)    -> Partial(),
  Replicate(), Replicate() -> Replicate()
]
```
if 'Shard' and 'StridedShard' are both discovered in the inputs, we expand to
```
[
  Shard(0),        Replicate()     -> Shard(0),
  Replicate(),     Shard(1)        -> Shard(1),
  Shard(1),        Shard(0)        -> Partial(),
  StridedShard(0), Replicate()     -> StridedShard(0),
  Replicate(),     StridedShard(1) -> StridedShard(1),
  StridedShard(1), StridedShard(0) -> Partial(),
  Replicate(),     Replicate()     -> Replicate()
]
```
3. After filling the placeholders, we expand to the N-D mesh and find the minimum cost:
   (a) full enumeration via itertools.product is implemented and gives exact parity with rules like 'einsum' today
   (b) an optimized solution, starting from the input placements and iterating in order of increasing cost until reaching a min-cost solution _without_ having to fully enumerate, is under development/prototyping

### This PR

* defines a 'single_dim strategy' function and a ShardingPlaceholder
* adds a util for expanding a single_dim strategy into a regular strategy
* supports StridedShard automatically via ShardingPlaceholder expansion
* writes rules for mm and cat and uses unit tests to validate their expansion

### Next Steps (PR stack)

* Support pointwise and foreach ops in the single_dim infra
* Hook up single-dim strategies to sharding_prop (op registration)
* Start to use single_dim rules to replace existing rules
* Improve the runtime of searching the fully expanded strategy
* Explore using decomps together with single-dim rules to support more operators

Pull Request resolved: #167677
Approved by: https://github.com/weifengpy
ghstack dependencies: #170615

Co-authored-by: Pian Pawakapan <pianpwk@meta.com>
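The placeholder filling (step 2) and full enumeration (step 3a) can be sketched in plain Python (the tuple-based rule encoding and the names `fill_placeholders` / `expand_to_mesh` are illustrative, not the PR's actual API):

```python
import itertools
from dataclasses import dataclass

@dataclass(frozen=True)
class Replicate:
    pass

@dataclass(frozen=True)
class Partial:
    pass

@dataclass(frozen=True)
class Shard:
    dim: int

@dataclass(frozen=True)
class ShardPlaceholder:
    dim: int

def fill_placeholders(single_dim_rules, discovered_shard_types):
    """Step 2: substitute each ShardPlaceholder with every concrete
    sharding type discovered in the inputs, then append the
    full-replication rule."""
    filled = []
    for shard_cls in discovered_shard_types:  # e.g. [Shard, StridedShard]
        for rule in single_dim_rules:
            filled.append(tuple(
                shard_cls(p.dim) if isinstance(p, ShardPlaceholder) else p
                for p in rule
            ))
    arity = len(single_dim_rules[0])
    filled.append(tuple(Replicate() for _ in range(arity)))
    return filled

def expand_to_mesh(filled_rules, mesh_ndim):
    """Step 3(a): full enumeration -- pick one single-dim rule per mesh dim."""
    return list(itertools.product(filled_rules, repeat=mesh_ndim))

# matmul single-dim rules as (lhs, rhs, out) tuples, mirroring the tl;dr above
mm_rules = [
    (ShardPlaceholder(0), Replicate(), ShardPlaceholder(0)),
    (Replicate(), ShardPlaceholder(1), ShardPlaceholder(1)),
    (ShardPlaceholder(1), ShardPlaceholder(0), Partial()),
]
filled = fill_placeholders(mm_rules, [Shard])  # 3 filled rules + 1 replicate rule
combos = expand_to_mesh(filled, mesh_ndim=2)   # 4**2 = 16 candidates on a 2-D mesh
```

This is exactly the combinatorial blow-up that motivates the iterative, cost-ordered search in step 3(b): a 3-input op on a 3-D mesh already enumerates hundreds of candidates.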
ghstack-source-id: fed27a4 Pull Request resolved: pytorch/pytorch#167677
Following @tianyu-l's #130887

Adds support for ops with no sharding prop strategy but a registered decomposition. Now if sharding prop sees a decomposable op, it:

1. Runs the decomposed op under a custom TorchDispatchMode, which propagates the placements as side information (initially used a make_fx implementation, but this required a threading lock as it relies on [global state](https://github.com/pytorch/pytorch/blob/2a26c9a32661ee2b4b049e3bd1b889fc3af30880/torch/fx/_symbolic_trace.py#L1167))
2. Enumerates potential input placement combinations based on the actual input placements, on a single-dim mesh, then for each of them, propagates through torch_dispatch via sharding prop, while banning any intermediate redistributions.
3. Returns the expanded full-mesh strategy from the filtered strategies.

Some caveats:

- Since the dispatch mode runs sharding prop, the shard prop cache should kick in, both in the normal case (running the same op twice), and also when we recursively decompose (if op1 -> op2 -> some decomp, running op1 caches for op2).
- One common failure case is decompositions calling factory methods (e.g. [torch.ones, torch.arange](https://github.com/pytorch/pytorch/blob/41f42a0fc3ea1fbfdf05b4c030d7df815bdfe19d/torch/_decomp/decompositions.py#L818-L821)). The main problem seems to be assigning placements to these tensors, and it's not so obvious what their placements should be, especially when they might take in sharded sizes, and we can't completely detect when this is the case. For now, intermediate shard prop will fail (no sharding strategy; they don't take DTensor inputs), but a potential future improvement is to permit the full-Replicate case for these graphs.
- Sharding prop is currently via a `propagate_op_sharding` call, on explicit placement types. Once [single-dim strategy](#167677) coverage is broader, this should be doable on _ShardPlaceholders instead, making the enumeration & propagation process cheaper, though maybe more manual.
- (Maybe hackily) uses a fake 1-rank 1d mesh to do single-dim propagation

Removes the following xfails (+ some more aten ops with decomp coverage, but still failing tests):
```
__rsub__
addmv
addr
alias_copy
all
any
count_nonzero
dist
expand_copy
fill
floor_divide
index_select
linalg.vecdot
masked_fill
mv
nn.functional.celu
nn.functional.channel_shuffle
nn.functional.elu
nn.functional.hardsigmoid
nn.functional.hardswish
nn.functional.hardtanh
nn.functional.leaky_relu
nn.functional.logsigmoid
nn.functional.margin_ranking_loss
nn.functional.mish
nn.functional.multilabel_soft_margin_loss
nn.functional.pairwise_distance
nn.functional.pixel_shuffle
nn.functional.pixel_unshuffle
nn.functional.prelu
nn.functional.relu6
nn.functional.selu
nn.functional.softplus
nn.functional.softshrink
nn.functional.triplet_margin_loss
nn.functional.triplet_margin_with_distance_loss
permute_copy
rsub
t_copy
trace
vdot
view_copy
```
Pull Request resolved: #171652
Approved by: https://github.com/wconstab
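The enumerate-and-filter step (2) above can be illustrated with a toy stand-in (string placements and the `propagate_fn` callback are hypothetical; the real path runs sharding prop under a TorchDispatchMode on DTensorSpecs):

```python
import itertools

def enumerate_decomp_strategies(input_placements, propagate_fn):
    """Enumerate single-dim placement combos drawn from the actual input
    placements (plus replication), keeping only combos that propagate
    through the decomposed op without an intermediate redistribution."""
    candidates = [sorted(set(ps) | {"R"}) for ps in input_placements]
    valid = []
    for combo in itertools.product(*candidates):
        try:
            out = propagate_fn(combo)  # expected to raise if redistribution needed
        except RuntimeError:
            continue
        valid.append((combo, out))
    return valid

def add_propagate(combo):
    """Toy elementwise-add rule: inputs must already match; a mismatch
    would require a redistribution, which the decomposition path bans."""
    a, b = combo
    if a != b:
        raise RuntimeError("would require redistribution")
    return a

# first input is Shard(0) ('S0'), second is replicated ('R'):
# only the all-replicate combo survives the no-redistribution filter
strategies = enumerate_decomp_strategies([["S0"], ["R"]], add_propagate)
```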
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @d4l3k @pragupta @msaroufim @dcci @tianyu-l @XilunWu @SherlockNoMad @ppwwyyxx