
[DTensor] decomposed sharding propagation#130887

Draft
tianyu-l wants to merge 2 commits into gh/tianyu-l/2/base from gh/tianyu-l/2/head

Conversation

@tianyu-l
Contributor

@tianyu-l tianyu-l commented Jul 17, 2024

Stack from ghstack (oldest at bottom):

This PR adds the feature of sharding propagation via op decomposition.

#TODO: summary to be added

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot

pytorch-bot bot commented Jul 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130887

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Cancelled Job

As of commit 7446907 with merge base df59193:


This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor oncall: distributed Add this issue/PR to distributed oncall triage queue labels Jul 17, 2024
tianyu-l added a commit that referenced this pull request Jul 17, 2024
ghstack-source-id: 668ea73
Pull Request resolved: #130887
@tianyu-l tianyu-l requested a review from wanchaol July 17, 2024 04:37
@tianyu-l tianyu-l marked this pull request as draft July 17, 2024 04:41
tianyu-l added a commit that referenced this pull request Jul 19, 2024
ghstack-source-id: a7ea252
Pull Request resolved: #130887
Collaborator

@wanchaol wanchaol left a comment


Nice work! This looks reasonably good already, only have some minor comments

@@ -0,0 +1,26 @@
# mypy: allow-untyped-defs
Collaborator


Please rebase and make this a private module

LINEAR_REDUCTION_OP_MAP = {
aten.all.default: "sum",
aten.all.dim: "sum",
aten.amax.default: "max",
Collaborator


There should be some tests that can be enabled in test_dtensor_ops.py, given that we enabled additional ops here?

for strtg in node_output_strategy.strategies:
if strtg.input_specs is None:
assert isinstance(strtg.output_specs, DTensorSpec)
for idx, input_strtg in enumerate(
Collaborator


Let's just name this input_strategy, as it's not getting shortened that much.

node_to_spec[node] = node_output_spec
elif node.op == "output":
output_node = node.args[0]
graph_output_specs = [node_to_spec[node] for node in output_node]
Collaborator


Hmm, I think here you only handled the case where the output is a list of tensors. We should probably also handle the cases where the output is a single tensor or a tuple of tensors; you can refer to the wrap_output_spec/wrap logic to see how to handle those cases.

Contributor Author


IIRC this handles both cases of a single tensor and a tuple of tensors. See tests on both aten.aminmax (a tuple of two tensors) and aten._log_softmax (a single tensor). In other words, output_node would be a list of results (possibly a singleton) regardless of the designated output type of the function.
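A toy sketch of the point above (hypothetical names, not the real DTensor internals): because the graph's output node holds a list of results even for single-tensor ops, one code path covers both a single tensor and a tuple of tensors.

```python
# Toy sketch: output_node is always iterable -- a singleton list for
# single-tensor ops, a longer list for tuple-returning ops -- so one
# lookup loop handles both cases uniformly.
def collect_output_specs(output_node, node_to_spec):
    return [node_to_spec[n] for n in output_node]

node_to_spec = {
    "softmax_out": "Replicate",   # single-tensor op result
    "min_out": "Shard(0)",        # first element of a tuple-returning op
    "max_out": "Shard(0)",        # second element
}

print(collect_output_specs(["softmax_out"], node_to_spec))         # ['Replicate']
print(collect_output_specs(["min_out", "max_out"], node_to_spec))  # ['Shard(0)', 'Shard(0)']
```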

)
all_possible_schema.append(possible_arg_specs)
else:
all_possible_schema.append((arg_spec,))
Collaborator


I guess the reason it appends a tuple here for a non-tensor arg is to allow the product later?

Contributor Author


yes
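A minimal illustration of that design choice (the placement strings are hypothetical stand-ins): wrapping a non-tensor arg in a singleton tuple lets `itertools.product` enumerate full candidate schemas without special-casing it.

```python
from itertools import product

# Tensor args contribute several candidate placements; a non-tensor arg
# (e.g. a reduction dim) is wrapped in a singleton tuple so it passes
# through the product unchanged.
all_possible_schema = [
    ("Shard(0)", "Replicate"),  # tensor arg with two candidate placements
    (3,),                       # non-tensor arg: singleton tuple
    ("Replicate",),             # tensor arg with one candidate placement
]

combos = list(product(*all_possible_schema))
print(combos)  # [('Shard(0)', 3, 'Replicate'), ('Replicate', 3, 'Replicate')]
```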

y_dt = torch.nn.functional.log_softmax(x_dt, dim=softmax_dim)

self.assertTrue(y_dt.placements[0].is_replicate())
# TODO(lty): numerical test doesn't work -- similar to the complex mul bug
Collaborator


Hmm, I wonder why? IIRC the complex mul bug is specific to handling complex numbers, but softmax/log_softmax does not involve complex numbers?

Contributor


looks like you are comparing numerics for log_softmax and regular softmax -- if they are both log this seems fine.
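A pure-Python sketch (no torch) of why that comparison is fair: log_softmax(x) should match log(softmax(x)) up to floating-point error, as long as both sides of the comparison are in log space.

```python
import math

def log_softmax(xs):
    # Numerically stable log-softmax: subtract max before the log-sum-exp.
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

xs = [0.5, -1.2, 3.0]
a = log_softmax(xs)
b = [math.log(p) for p in softmax(xs)]
# Both are log-space values of the same distribution.
assert all(abs(u - v) < 1e-9 for u, v in zip(a, b))
```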

@github-actions
Contributor

github-actions bot commented Oct 5, 2024

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Oct 5, 2024
@tianyu-l tianyu-l removed the Stale label Oct 7, 2024
@github-actions
Contributor

github-actions bot commented Dec 6, 2024


@github-actions github-actions bot added the Stale label Dec 6, 2024
@tianyu-l tianyu-l removed the Stale label Dec 6, 2024
@github-actions
Contributor

github-actions bot commented Feb 4, 2025


@github-actions
Contributor

github-actions bot commented Apr 6, 2025


@github-actions github-actions bot added the Stale label Apr 6, 2025
@tianyu-l tianyu-l removed the Stale label Apr 6, 2025
@github-actions
Contributor

github-actions bot commented Jun 5, 2025


pytorchmergebot pushed a commit that referenced this pull request Feb 4, 2026
Following @tianyu-l's #130887

Adds support for ops with no sharding prop strategy, but a registered decomposition.

Now if sharding prop sees a decomposable op, it:
1. Runs the decomposed op under a custom TorchDispatchMode, which propagates the placements as side information (initially used a make_fx implementation, but this required a threading lock as it relies on [global state](https://github.com/pytorch/pytorch/blob/2a26c9a32661ee2b4b049e3bd1b889fc3af30880/torch/fx/_symbolic_trace.py#L1167))
2. Enumerates potential input placement combinations based on the actual input placements, on a single-dim mesh, then for each of them, propagates through torch_dispatch via sharding prop, while banning any intermediate redistributions.
3. Returns the expanded full-mesh strategy from the filtered strategies.
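The three steps above can be sketched in miniature (toy placements and rules, not the real PyTorch API): enumerate candidate single-dim input placements, propagate each through the decomposition, and keep only combinations that need no intermediate redistribution.

```python
from itertools import product

class RedistributionNeeded(Exception):
    """Raised when a propagation step would require a redistribution."""

def prop_sum_dim0(placement):
    # Toy rule: reducing over dim 0 of a Shard(0) input would need a
    # redistribution; other placements pass through (a simplification).
    if placement == "Shard(0)":
        raise RedistributionNeeded
    return placement

def prop_div_scalar(placement):
    return placement  # elementwise op: placement passes through

def propagate_mean_dim0(input_placement):
    # mean(dim=0) decomposes into sum(dim=0) followed by a scalar div.
    return prop_div_scalar(prop_sum_dim0(input_placement))

candidates = ["Shard(0)", "Shard(1)", "Replicate"]
valid = []
for (p,) in product(candidates, repeat=1):  # single tensor input here
    try:
        valid.append((p, propagate_mean_dim0(p)))
    except RedistributionNeeded:
        continue  # ban combos needing intermediate redistributions

print(valid)  # [('Shard(1)', 'Shard(1)'), ('Replicate', 'Replicate')]
```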

Some caveats:
- Since the dispatch mode runs sharding prop, the shard prop cache should kick in, both in the normal case (running the same op twice), and also when we recursively decompose (if op1 -> op2 -> some decomp, running op1 caches for op2).
- One common failure case is decompositions calling factory methods (e.g. [torch.ones, torch.arange](https://github.com/pytorch/pytorch/blob/41f42a0fc3ea1fbfdf05b4c030d7df815bdfe19d/torch/_decomp/decompositions.py#L818-L821)). The main problem seems to be assigning placements to these tensors, and it's not so obvious what their placements should be, especially when they might take in sharded sizes, and we can't completely detect when this is the case. For now, intermediate shard prop will fail (no sharding strategy; they don't take DTensor inputs), but a potential future improvement is to permit the full-Replicate case for these graphs.
- Sharding prop is currently via a `propagate_op_sharding` call, on explicit placement types. Once [single-dim strategy](#167677) coverage is broader, this should be doable on _ShardPlaceholders instead, making the enumeration & propagation process cheaper, though maybe more manual.
- (Maybe hackily) uses a fake 1-rank 1d mesh to do single-dim propagation
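The caching caveat above can be sketched with a memoized function (names are illustrative, not the real sharding-prop cache): keying on (op, input placements) means a recursive decomposition warms the cache for its inner ops.

```python
from functools import lru_cache

calls = []  # tracks which ops actually ran propagation work (toy)

@lru_cache(maxsize=None)
def shard_prop(op, input_placements):
    calls.append(op)
    if op == "op1":
        # op1 decomposes into op2; propagating op1 caches op2's entry too.
        return shard_prop("op2", input_placements)
    return input_placements

shard_prop("op1", ("Shard(0)",))  # propagates op1 and, inside it, op2
shard_prop("op2", ("Shard(0)",))  # cache hit: no new propagation work
print(calls)  # ['op1', 'op2']
```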

Removes the following xfails (+some more aten ops with decomp coverage, but still failing tests):
```
__rsub__
addmv
addr
alias_copy
all
any
count_nonzero
dist
expand_copy
fill
floor_divide
index_select
linalg.vecdot
masked_fill
mv
nn.functional.celu
nn.functional.channel_shuffle
nn.functional.elu
nn.functional.hardsigmoid
nn.functional.hardswish
nn.functional.hardtanh
nn.functional.leaky_relu
nn.functional.logsigmoid
nn.functional.margin_ranking_loss
nn.functional.mish
nn.functional.multilabel_soft_margin_loss
nn.functional.pairwise_distance
nn.functional.pixel_shuffle
nn.functional.pixel_unshuffle
nn.functional.prelu
nn.functional.relu6
nn.functional.selu
nn.functional.softplus
nn.functional.softshrink
nn.functional.triplet_margin_loss
nn.functional.triplet_margin_with_distance_loss
permute_copy
rsub
t_copy
trace
vdot
view_copy
```
Pull Request resolved: #171652
Approved by: https://github.com/wconstab
radeksm pushed a commit to radeksm/pytorch that referenced this pull request Feb 20, 2026
libohao1201 pushed a commit to libohao1201/pytorch that referenced this pull request Mar 2, 2026

Labels

ciflow/inductor no-stale oncall: distributed Add this issue/PR to distributed oncall triage queue


3 participants