[DTensor] fix max.dim/min.dim strategy #175776
pianpwk wants to merge 5 commits into gh/pianpwk/102/base from
Conversation
The previous strategy allowed Partial("max"/"min") on values when the
input was sharded on the reduction dim. While Partial is valid for
values, the indices are local to each shard and cannot be combined
across ranks, producing incorrect global indices.
Rewrite as a single_dim_strategy that only allows sharding on
non-reduction dims, forcing Replicate on the reduction dim so both
values and indices are computed correctly.
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/175776
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (4 Unrelated Failures)
As of commit feea529 with merge base ea9fce2:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
@register_single_dim_strategy(
    [aten.max.dim, aten.min.dim], schema_info=RuntimeSchemaInfo(1)
)
def max_min_dim_single_dim_strategy(
```
@anshul-si i'm deferring to you on this, i think your PRs do not touch max.dim and min.dim so it is OK to land this first? Not sure if you were planning to work on the ops in _math_ops at some point?
this can be landed first. i was planning on working on the ops in _math_ops after pointwise_ops, but this can be used to help me
```python
                _ShardingPlaceholder(d),
            ]
        )
    return strategies
```
aren't we forgetting some Partial prop from this rule?
I thought so too, but max(P(max)) -> P(max) only holds for the values; the indices come out invalid, so we can't, regardless of the reduction dimension
i see, we definitely can't return partial indices. i missed that indices was a return value despite your comment above. makes sense.
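For readers following the thread, here is a self-contained sketch of the rule's shape, reconstructed around the diff fragments above. Only `_ShardingPlaceholder` and the overall loop/return structure are taken from the diff; the stand-in dataclass, the simplified signature, and the meaning of the list entries are assumptions for illustration:

```python
from dataclasses import dataclass

@dataclass
class _ShardingPlaceholder:
    dim: int  # tensor dim that a single mesh dim shards

def max_min_dim_single_dim_strategy(input_ndim: int, reduction_dim: int):
    # Allow Shard(d) only on non-reduction dims; the reduction dim stays
    # Replicate, so both values and indices are computed exactly.
    strategies = []
    for d in range(input_ndim):
        if d == reduction_dim:
            continue  # Shard(reduction_dim) would yield shard-local indices
        strategies.append(
            [
                _ShardingPlaceholder(d),  # values output
                _ShardingPlaceholder(d),  # indices output
                _ShardingPlaceholder(d),  # input
            ]
        )
    return strategies

# e.g. a rank-3 input reduced over dim 1 may shard only dims 0 and 2:
print(max_min_dim_single_dim_strategy(3, 1))
```

Note there are deliberately no Partial entries anywhere, matching the conclusion above.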
wconstab left a comment:
does the sharding validator run on this op + PR?
@wconstab I'm not sure about the result, but these seem to fall under the reduction_with_dim variant, and I don't see any missing rules for max/min.dim:
wconstab left a comment:
LGTM, i was confused but your impl is right. no partials should be supported for max.dim variant.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Merge failed. Reason: 1 job has failed, first few of them are: trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m1-14). Details for Dev Infra team: Raised by workflow job
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 2 checks: inductor / unit-test / inductor-test / test (inductor, 2, 2, linux.g5.4xlarge.nvidia.gpu), trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m1-14). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
@pytorchbot revert -m "Looks like introduced some new distributed breakages, see https://hud.pytorch.org/hud/pytorch/pytorch/1b9046a794cd2f8d882adf47d5612407cf43c1d2/1?per_page=50&name_filter=test%20(distr&mergeEphemeralLF=true" -c nosignal
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
@pytorchbot revert -m "I'm not sure what's going on, but it breaks lint this time around, see https://hud.pytorch.org/hud/pytorch/pytorch/7c8edff72bcf501f2fb70a4b3149718a905c2471/1?per_page=50&name_filter=lint&mergeEphemeralLF=true" -c nosignal
@pytorchbot successfully started a revert job. Check the current status here.
This reverts commit 2f0a6bd. Reverted #175776 on behalf of https://github.com/malfet due to I'm not sure what's going on, but it breaks lint this time around, see https://hud.pytorch.org/hud/pytorch/pytorch/7c8edff72bcf501f2fb70a4b3149718a905c2471/1?per_page=50&name_filter=lint&mergeEphemeralLF=true ([comment](#175776 (comment)))
@pianpwk your PR has been successfully reverted.
It's because of a land race with the PR that enables lint for plain assert.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
aten.max/min.dim returns (values, indices), and strategies currently allow S(reduction_dim) -> P(max/min), P(max/min). This is invalid for indices, and we should ban sharding on the reduction dim. Rewrites as a single-dim strategy.
Pull Request resolved: pytorch#175776
Approved by: https://github.com/wconstab
Upstreaming from autoparallel: https://github.com/meta-pytorch/autoparallel/blob/454780d2a27456a380c0d8e997c8fc2cf82ef5d8/autoparallel/shardings/propagation_rules.py#L630
The previous strategy required full-Replicate: we can passthrough on non-padded dims, and allow Partial inputs when pad value = 0 (arguable if we should fix this). Rewritten as a single-dim strategy.
Pull Request resolved: #175656
Approved by: https://github.com/wconstab
ghstack dependencies: #175776
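A single-process sanity check of the "Partial inputs are OK when pad value = 0" claim from that PR, with plain torch tensors standing in for per-rank Partial("sum") contributions (a sketch of the algebra, not the DTensor rule itself):

```python
import torch
import torch.nn.functional as F

a = torch.tensor([1.0, 2.0])  # rank 0's partial-sum contribution
b = torch.tensor([3.0, 4.0])  # rank 1's partial-sum contribution

pad = (1, 1)  # constant pad, one element on each side, pad value 0
lhs = F.pad(a, pad) + F.pad(b, pad)  # pad each partial, then all-reduce
rhs = F.pad(a + b, pad)              # all-reduce, then pad
assert torch.equal(lhs, rhs)  # commutes only because the pad value is 0
```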
This reverts commit 07b82e8. Reverted pytorch#175776 on behalf of https://github.com/malfet due to Looks like introduced some new distributed breakages, see https://hud.pytorch.org/hud/pytorch/pytorch/1b9046a794cd2f8d882adf47d5612407cf43c1d2/1?per_page=50&name_filter=test%20(distr&mergeEphemeralLF=true ([comment](pytorch#175776 (comment)))
Stack from ghstack (oldest at bottom):
aten.max/min.dim returns (values, indices), and strategies currently allow S(reduction_dim) -> P(max/min), P(max/min). This is invalid for indices, and we should ban sharding on the reduction dim. Rewrites as a single-dim strategy.