
[DTensor] fix max.dim/min.dim strategy #175776

Closed
pianpwk wants to merge 5 commits into gh/pianpwk/102/base from gh/pianpwk/102/head

Conversation

@pianpwk
Contributor

@pianpwk pianpwk commented Feb 25, 2026

Stack from ghstack (oldest at bottom):

aten.max.dim/aten.min.dim return (values, indices), and the current strategies allow S(reduction_dim) -> P(max/min), P(max/min). This is invalid for the indices, so we should ban sharding on the reduction dim. Rewritten as a single-dim strategy.

The previous strategy allowed Partial("max"/"min") on values when the
input was sharded on the reduction dim. While Partial is valid for
values, the indices are local to each shard and cannot be combined
across ranks — producing incorrect global indices.

Rewrite as a single_dim_strategy that only allows sharding on
non-reduction dims, forcing Replicate on the reduction dim so both
values and indices are computed correctly.
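A minimal plain-Python sketch of the rule the rewrite encodes (`allowed_single_dim_placements` is a hypothetical illustration, not the real DTensor `register_single_dim_strategy` API): Shard is offered on every dim except the reduction dim, and Partial is never offered.

```python
def allowed_single_dim_placements(ndim: int, reduction_dim: int):
    """Enumerate the per-dim input placements the fixed strategy allows.

    Replicate is always valid; Shard(d) is valid only for d != reduction_dim,
    so both values and indices are computed locally and correctly. No Partial
    placement is emitted, since indices cannot be reduced across ranks.
    """
    placements = [("Replicate",)]  # always a valid fallback
    for d in range(ndim):
        if d != reduction_dim:
            placements.append(("Shard", d))
    return placements

print(allowed_single_dim_placements(3, reduction_dim=1))
# -> [('Replicate',), ('Shard', 0), ('Shard', 2)]
```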

[ghstack-poisoned]
@pytorch-bot

pytorch-bot Bot commented Feb 25, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/175776

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit feea529 with merge base ea9fce2:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pianpwk pianpwk changed the title [DTensor] Fix max.dim/min.dim strategy for correct indices [DTensor] fix max.dim/min.dim strategy Feb 25, 2026
@register_single_dim_strategy(
[aten.max.dim, aten.min.dim], schema_info=RuntimeSchemaInfo(1)
)
def max_min_dim_single_dim_strategy(
Contributor


@anshul-si i'm deferring to you on this, i think your PRs do not touch max.dim and min.dim so it is OK to land this first? Not sure if you were planning to work on the ops in _math_ops at some point?

Contributor


this can be landed first. i was planning on working on the ops in _math_ops after pointwise_ops, but this can be used to help me

_ShardingPlaceholder(d),
]
)
return strategies
Contributor


aren't we forgetting some Partial prop from this rule?

Contributor Author


I thought so too, but max(P(max)) -> P(max) is valid for the values while producing invalid indices, so we can't allow it, regardless of the reduction dim
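The failure mode can be made concrete in plain Python (no torch; two "ranks" simulated as list slices, all names illustrative):

```python
# A 1-D tensor sharded on the reduction dim: each rank holds one contiguous slice.
full = [3, 1, 4, 1, 5, 9, 2, 6]
shards = [full[:4], full[4:]]

# Each rank computes a local (value, index) pair, as max.dim would.
local = [(max(s), s.index(max(s))) for s in shards]  # [(4, 2), (9, 1)]

# Reducing the values with max() across ranks is fine -- this is exactly
# what a Partial("max") placement would do for the values output:
best_value, best_local_index = max(local)
assert best_value == max(full)

# But the winning index is local to rank 1: it says 1, while the true
# global index of the max is 5. No max-reduction over the index tensors
# can recover 5, so Partial is invalid for the indices output.
assert best_local_index == 1
assert full.index(max(full)) == 5
```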

Contributor


i see, we definitely can't return partial indices. i missed that indices was a return value despite your comment above. makes sense.

Contributor

@wconstab wconstab left a comment


does the sharding validator run on this op + PR?

[ghstack-poisoned]
@pianpwk
Contributor Author

pianpwk commented Feb 26, 2026

@wconstab I'm not sure about the result, but these seem to fall under the reduction_with_dim variant, and I don't see any missing rules for max/min.dim:

(pytorch-3048) [pianpwk@devvm3048.dkl0 /data/users/pianpwk/pytorch (d63a1ff5)]$ python -m torch.distributed.tensor._ops.strategy_validation --op max
Testing ops: aten.max
Device: cuda, Dtype: torch.float32, World size: 2

  OpInfo variant: reduction_with_dim

  OpInfo variant: reduction_no_dim

  OpInfo variant: binary

[1/1] aten.max — Samples: 14 (1 skipped), Combinations: 2907
----------------------------------------------------------------------

Possibly missing (valid in ground truth but no DTensor rule)

  [aten.max.default]
    P(max) -> P(max)

  [aten.maximum.default]
    P(max), R -> P(max)
    P(min), P(min) -> P(min)
    P(min), R -> P(min)
    R, P(avg) -> P(avg)
    R, P(max) -> P(max)
    R, P(min) -> P(min)
    R, P(sum) -> P(sum)

======================================================================
Summary
======================================================================
Op        Correct  Incorrect  Missing    Time
---------------------------------------------
aten.max       28          0        8    45.5s
---------------------------------------------
Total          28          0        8    45.5s

(pytorch-3048) [pianpwk@devvm3048.dkl0 /data/users/pianpwk/pytorch (381ee342)]$ python -m torch.distributed.tensor._ops.strategy_validation --op min
Testing ops: aten.min
Device: cuda, Dtype: torch.float32, World size: 2

  OpInfo variant: reduction_with_dim

  OpInfo variant: reduction_no_dim

  OpInfo variant: binary

[1/1] aten.min — Samples: 14 (1 skipped), Combinations: 2907
----------------------------------------------------------------------

Possibly missing (valid in ground truth but no DTensor rule)

  [aten.min.default]
    P(min) -> P(min)

  [aten.minimum.default]
    P(max), P(max) -> P(max)
    P(max), R -> P(max)
    P(min), R -> P(min)
    R, P(avg) -> P(avg)
    R, P(avg) -> R
    R, P(max) -> P(max)
    R, P(min) -> P(min)
    R, P(sum) -> P(avg)
    R, P(sum) -> R

======================================================================
Summary
======================================================================
Op        Correct  Incorrect  Missing    Time
---------------------------------------------
aten.min       28          0       10    45.4s
---------------------------------------------
Total          28          0       10    45.4s

Contributor

@wconstab wconstab left a comment


LGTM, i was confused but your impl is right. no partials should be supported for max.dim variant.

@pianpwk
Contributor Author

pianpwk commented Feb 26, 2026

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 26, 2026
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m1-14)

Details for Dev Infra team: raised by workflow job

@pianpwk
Contributor Author

pianpwk commented Feb 26, 2026

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 2 checks: inductor / unit-test / inductor-test / test (inductor, 2, 2, linux.g5.4xlarge.nvidia.gpu), trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m1-14)


@malfet
Contributor

malfet commented Feb 27, 2026

@pytorchbot revert -m "Looks like introduced some new distributed breakages, see https://hud.pytorch.org/hud/pytorch/pytorch/1b9046a794cd2f8d882adf47d5612407cf43c1d2/1?per_page=50&name_filter=test%20(distr&mergeEphemeralLF=true" -c nosignal

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@malfet
Contributor

malfet commented Mar 2, 2026

@pytorchbot revert -m "I'm not sure what's going on, but it breaks lint this time around, see https://hud.pytorch.org/hud/pytorch/pytorch/7c8edff72bcf501f2fb70a4b3149718a905c2471/1?per_page=50&name_filter=lint&mergeEphemeralLF=true" -c nosignal

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Mar 2, 2026
@pytorchmergebot
Collaborator

@pianpwk your PR has been successfully reverted.

@albanD
Collaborator

albanD commented Mar 2, 2026

It's because of a land race with the PR that enables lint for plain assert.
The base of this PR is too old.

postmath pushed a commit to postmath/pytorch that referenced this pull request Mar 3, 2026
aten.max/min.dim returns (values, indices), and strategies currently allow S(reduction_dim) -> P(max/min), P(max/min). This is invalid for indices, and we should ban sharding on the reduction dim. Rewrites as a single-dim strategy.
Pull Request resolved: pytorch#175776
Approved by: https://github.com/wconstab
postmath pushed a commit to postmath/pytorch that referenced this pull request Mar 3, 2026
[ghstack-poisoned]
@pianpwk
Contributor Author

pianpwk commented Mar 4, 2026

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


pytorchmergebot pushed a commit to anatoliylitv/pytorch that referenced this pull request Mar 4, 2026
pytorchmergebot pushed a commit that referenced this pull request Mar 5, 2026
Upstreaming from autoparallel: https://github.com/meta-pytorch/autoparallel/blob/454780d2a27456a380c0d8e997c8fc2cf82ef5d8/autoparallel/shardings/propagation_rules.py#L630

The previous strategy required full-Replicate: we can passthrough on non-padded dims, and allow Partial inputs when pad value = 0 (arguable if we should fix this). Rewritten as a single-dim strategy

Pull Request resolved: #175656
Approved by: https://github.com/wconstab
ghstack dependencies: #175776
Vighaneshs pushed a commit to Vighaneshs/pytorch that referenced this pull request Mar 5, 2026
sandy-gags pushed a commit to sandy-gags/pytorch that referenced this pull request Mar 12, 2026

ghstack-source-id: b7b09ff
Pull Request resolved: pytorch/pytorch#175776
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
@github-actions github-actions Bot deleted the gh/pianpwk/102/head branch April 4, 2026 02:23

Labels

ci-no-td Do not run TD on this PR ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: distributed (dtensor) release notes category Reverted
