[DTensor] decomposed sharding propagation #130887
tianyu-l wants to merge 2 commits into gh/tianyu-l/2/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130887
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 New Failures, 1 Cancelled Job as of commit 7446907 with merge base df59193. NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR adds the feature of sharding propagation via op decomposition.

#TODO: summary to be added

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o
wanchaol left a comment:

Nice work! This looks reasonably good already; I only have some minor comments.
@@ -0,0 +1,26 @@
# mypy: allow-untyped-defs

Please rebase and make this a private module.
LINEAR_REDUCTION_OP_MAP = {
    aten.all.default: "sum",
    aten.all.dim: "sum",
    aten.amax.default: "max",

There should be some tests that can be enabled in test_dtensor_ops.py, given that we enabled additional ops here?
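As a hedged, illustrative-only sketch of what such a map enables (op names are strings here for readability; the real table keys are `aten` op overloads, as in the diff above): each decomposed reduction op is tagged with the linear reduction it performs, so one shared reduction strategy can serve all of them.

```python
# Illustrative sketch mirroring the diff above: map each reduction op to the
# linear reduction kind that drives its sharding strategy. The real map keys
# are aten op overloads, not strings.
LINEAR_REDUCTION_OP_MAP = {
    "aten.all.default": "sum",
    "aten.all.dim": "sum",
    "aten.amax.default": "max",
}


def reduction_kind(op_name: str) -> str:
    # Unknown ops raise KeyError rather than silently guessing a reduction.
    return LINEAR_REDUCTION_OP_MAP[op_name]
```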
for strtg in node_output_strategy.strategies:
    if strtg.input_specs is None:
        assert isinstance(strtg.output_specs, DTensorSpec)
    for idx, input_strtg in enumerate(

Let's just name this input_strategy, as it's not getting shortened that much.
node_to_spec[node] = node_output_spec
elif node.op == "output":
    output_node = node.args[0]
    graph_output_specs = [node_to_spec[node] for node in output_node]

Hmm, I think here you only handled the case where the output is a list of tensors; we should probably also handle the cases where the output is a single tensor or a tuple of tensors. You can refer to the wrap_output_spec/wrap logic to see how to handle those cases.

IIRC this handles both cases of a single tensor and a tuple of tensors. See tests on both aten.aminmax (tuple of two tensors) and aten._log_softmax (single tensor). In other words, output_node would be a list of results (possibly a singleton) regardless of the designated output type of the function.
)
all_possible_schema.append(possible_arg_specs)
else:
    all_possible_schema.append((arg_spec,))

I guess the reason it appends a tuple here for a non-tensor arg is to allow product later?
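That reading seems right, and can be made concrete with a minimal, hypothetical sketch (placement strings stand in for DTensorSpecs): wrapping a non-tensor arg in a singleton tuple keeps every entry of `all_possible_schema` iterable, so `itertools.product` can enumerate full schemas uniformly.

```python
import itertools

# Hypothetical sketch: each tensor arg contributes several candidate specs,
# while a non-tensor arg contributes exactly one fixed "choice". Wrapping the
# latter in a singleton tuple makes every entry iterable, so product() can
# combine them without special-casing.
all_possible_schema = [
    ("Shard(0)", "Replicate()"),  # tensor arg: two candidate placements
    (42,),                        # non-tensor arg: single fixed choice
]
schemas = list(itertools.product(*all_possible_schema))
# -> [("Shard(0)", 42), ("Replicate()", 42)]
```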
y_dt = torch.nn.functional.log_softmax(x_dt, dim=softmax_dim)

self.assertTrue(y_dt.placements[0].is_replicate())
# TODO(lty): numerical test doesn't work -- similar to the complex mul bug

Hmm, I wonder why? IIRC the complex mul bug is specific to handling complex numbers, but softmax/log_softmax does not involve complex numbers?

Looks like you are comparing numerics for log_softmax and regular softmax -- if they are both log, this seems fine.
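On the numerics point: `log_softmax(x)` and `log(softmax(x))` should indeed agree up to floating-point tolerance. A stdlib-only sketch (hypothetical helper names, not the DTensor test code) illustrating the identity:

```python
import math


def softmax(xs):
    # subtract the max for numerical stability
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def log_softmax(xs):
    # stable form: x - (max + log(sum(exp(x - max))))
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


xs = [1.0, 2.0, 3.0]
direct = log_softmax(xs)
via_log = [math.log(p) for p in softmax(xs)]
```

Both paths should agree to within ~1e-9 on well-scaled inputs, so comparing them in a test is reasonable.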
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as
Following @tianyu-l's #130887, this adds support for ops with no sharding prop strategy but a registered decomposition. Now if sharding prop sees a decomposable op, it:

1. Runs the decomposed op under a custom TorchDispatchMode, which propagates the placements as side information (initially used a make_fx implementation, but this required a threading lock as it relies on [global state](https://github.com/pytorch/pytorch/blob/2a26c9a32661ee2b4b049e3bd1b889fc3af30880/torch/fx/_symbolic_trace.py#L1167)).
2. Enumerates potential input placement combinations based on the actual input placements, on a single-dim mesh; then, for each of them, propagates through torch_dispatch via sharding prop, while banning any intermediate redistributions.
3. Returns the expanded full-mesh strategy from the filtered strategies.

Some caveats:

- Since the dispatch mode runs sharding prop, the shard prop cache should kick in, both in the normal case (running the same op twice) and when we recursively decompose (if op1 -> op2 -> some decomp, running op1 caches for op2).
- One common failure case is decompositions calling factory methods (e.g. [torch.ones, torch.arange](https://github.com/pytorch/pytorch/blob/41f42a0fc3ea1fbfdf05b4c030d7df815bdfe19d/torch/_decomp/decompositions.py#L818-L821)). The main problem is assigning placements to these tensors: it's not obvious what their placements should be, especially when they might take in sharded sizes, and we can't completely detect when this is the case. For now, intermediate shard prop will fail (no sharding strategy; they don't take DTensor inputs), but a potential future improvement is to permit the full-Replicate case for these graphs.
- Sharding prop is currently via a `propagate_op_sharding` call, on explicit placement types. Once [single-dim strategy](#167677) coverage is broader, this should be doable on _ShardPlaceholders instead, making the enumeration & propagation process cheaper, though maybe more manual.
- (Maybe hackily) uses a fake 1-rank 1d mesh to do single-dim propagation.

Removes the following xfails (plus some more aten ops with decomp coverage, but still-failing tests):

```
__rsub__ addmv addr alias_copy all any count_nonzero dist expand_copy fill floor_divide index_select linalg.vecdot masked_fill mv nn.functional.celu nn.functional.channel_shuffle nn.functional.elu nn.functional.hardsigmoid nn.functional.hardswish nn.functional.hardtanh nn.functional.leaky_relu nn.functional.logsigmoid nn.functional.margin_ranking_loss nn.functional.mish nn.functional.multilabel_soft_margin_loss nn.functional.pairwise_distance nn.functional.pixel_shuffle nn.functional.pixel_unshuffle nn.functional.prelu nn.functional.relu6 nn.functional.selu nn.functional.softplus nn.functional.softshrink nn.functional.triplet_margin_loss nn.functional.triplet_margin_with_distance_loss permute_copy rsub t_copy trace vdot view_copy
```

Pull Request resolved: #171652
Approved by: https://github.com/wconstab
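The enumerate-then-filter shape of step 2 above can be sketched roughly as follows (hypothetical names; the real code works with DTensorSpecs and the op's actual input placements, not strings, and the filter runs the decomposed op under the dispatch mode with redistribution banned):

```python
import itertools

# Candidate single-mesh-dim placements to try for each DTensor input
# (stand-ins for Replicate/Shard placement objects).
CANDIDATES = ["Replicate", "Shard(0)", "Shard(1)"]


def enumerate_schemas(num_tensor_inputs):
    # Cross product of per-input candidates: one schema per combination.
    return list(itertools.product(CANDIDATES, repeat=num_tensor_inputs))


def filter_schemas(schemas, accepts):
    # `accepts` stands in for propagating the decomposed op under the
    # dispatch mode and checking it succeeds without redistribution.
    return [s for s in schemas if accepts(s)]


# e.g. a pointwise-like op might only accept matching input placements:
ok = filter_schemas(enumerate_schemas(2), lambda s: s[0] == s[1])
```

The surviving schemas are then what gets expanded back into a full-mesh strategy in step 3.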