
[DTensor] Add sharding strategy for aten.squeeze.dims #173563

Closed
stmcgovern wants to merge 8 commits into pytorch:main from stmcgovern:173521-squeeze-dims

Conversation

@stmcgovern
Collaborator

@stmcgovern stmcgovern commented Jan 27, 2026

Fixes #173521
Fixes #166124
Extend dim_squeeze to handle multiple dimensions by normalizing all dim variants to a target dimension set. This unifies the logic into a single code path.
Fix the long-standing FIXME in dim_squeeze where squeeze(dim=None) could incorrectly remove sharded dimensions whose local size happened to be 1 (despite a global size > 1). This canonicalizes all squeeze variants to squeeze.dims at the sharding-propagation level, using the global shape to determine which dimensions are truly singleton.

Strategy validator: 74 correct, 0 incorrect, 0 missing. This is without the P(max/min) → R rules mentioned below.

  • Add test_squeeze_variants to test all squeeze variants with DTensor

Note: op_db test remains xfail due to pre-existing bug where local squeeze removes sharded dims with local size 1 (see PR #166862). That PR is/will be closed in favor of this approach that avoids a custom handler
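The dim-normalization idea described above can be sketched in plain Python. This is a simplified illustration only, not the actual PyTorch implementation; normalize_squeeze_dims is a hypothetical helper name, and the real logic additionally handles symbolic shapes:

```python
def normalize_squeeze_dims(global_shape, dims=None):
    """Canonicalize any squeeze variant to an explicit dim tuple,
    keeping only dimensions whose *global* size is 1.

    - squeeze()           -> dims=None: consider every dimension
    - squeeze(dim=d)      -> dims=(d,)
    - squeeze(dims=(...)) -> dims as given (negative dims allowed)
    """
    ndim = len(global_shape)
    if dims is None:
        candidates = range(ndim)
    else:
        # normalize negative dims (e.g. -1 -> ndim - 1)
        candidates = (d % ndim for d in dims)
    # Only globally-singleton dims are squeezed. A sharded dim whose
    # *local* size happens to be 1 is excluded, because its global
    # size is > 1 -- this is the FIXME case the PR fixes.
    return tuple(d for d in candidates if global_shape[d] == 1)


# Global shape (1, 8, 1): dim 1 may have local size 1 when sharded
# across 8 ranks, but it must not be squeezed.
print(normalize_squeeze_dims((1, 8, 1)))         # (0, 2)
print(normalize_squeeze_dims((1, 8, 1), (1,)))   # ()
print(normalize_squeeze_dims((1, 8, 1), (-1,)))  # (2,)
```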

@pytorch-bot

pytorch-bot Bot commented Jan 27, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@pytorch-bot pytorch-bot Bot added the release notes: distributed (dtensor) release notes category label Jan 27, 2026
@pytorch-bot

pytorch-bot Bot commented Jan 27, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/173563

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit eeac88f with merge base d428a3f (image):

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

  • Lint OSDC (unstable) / lintrunner-noclang-all / lint (gh)
    Error computing the main repository mapping: Encountered error while reading extension file 'requirements.bzl': no such package '@pip_deps//': no such package '@python3_10_x86_64-unknown-linux-gnu//': The current user is root, please run as non-root when using the hermetic Python interpreter. See https://github.com/bazelbuild/rules_python/pull/713.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@soulitzer soulitzer requested review from XilunWu and wconstab January 29, 2026 15:10
@soulitzer soulitzer self-assigned this Jan 29, 2026
@soulitzer soulitzer added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jan 29, 2026
@wconstab
Contributor

Looks correct to me. Though there are a couple of trivial rules missing for P(max) -> R or P(min) -> R when we squeeze away a size-1 partial; can you also fill those in?

@soulitzer soulitzer removed their assignment Jan 31, 2026
@stmcgovern stmcgovern marked this pull request as draft February 3, 2026 15:51
@stmcgovern stmcgovern marked this pull request as ready for review February 3, 2026 20:55
Contributor

@wconstab wconstab left a comment


I think I can be convinced that it's a good plan to remap all the squeeze variants to .dims when we need to normalize dims to avoid nonzero dim cases. But it'd help me if you updated the PR description to explain this a bit more clearly. Also, I didn't understand whether you're fixing the FIXME before landing or whether you're unable to for some reason.

@stmcgovern stmcgovern force-pushed the 173521-squeeze-dims branch 3 times, most recently from 988b42f to 9663d09 Compare March 5, 2026 20:59
@stmcgovern stmcgovern force-pushed the 173521-squeeze-dims branch from 9663d09 to 0fab03a Compare March 5, 2026 22:21
@stmcgovern stmcgovern force-pushed the 173521-squeeze-dims branch 4 times, most recently from 1787361 to 6801a99 Compare March 9, 2026 15:52
@stmcgovern
Collaborator Author

there are a couple trivial rules missing for P(max) -> R or P(min) -> R when we squeeze away a size-1 partial

I looked into adding this, but I don't really understand why it's needed by squeeze. The (recent) strategy validator does not expect/ask for it. Was it showing as missing in earlier strategy-validator implementations? @wconstab

@wconstab
Contributor

I think those P->R rules should be invalid; probably the validator is better now. @claude can you evaluate this PR overall and specifically say whether it has the complete and correct set of sharding rules for squeeze?

@claude

claude Bot commented Mar 10, 2026

Claude finished @wconstab's task in 4m 41s.


Evaluation: Squeeze Sharding Rules

  • Read PR diff and all changed files
  • Understand the existing squeeze sharding logic
  • Analyze the new sharding rules for correctness and completeness
  • Post evaluation

Overview

This PR introduces a two-level approach:

  1. _adjust_squeeze_to_global_singletons (_sharding_prop.py:972-1014) rewrites all squeeze variants to squeeze.dims with only globally-singleton dims before the DimMap strategy runs.
  2. dim_squeeze (_view_ops.py:468-484) then operates on local shapes, producing a DimMap that propagate_shape_and_sharding uses to derive output placements.

Sharding Rules — Complete and Correct

The rules produced by this pipeline cover all placement types:

| Input placement | Squeezed dim? | Output placement | Why |
| --- | --- | --- | --- |
| Replicate() | any | Replicate() | Trivially preserved |
| Shard(d) | d not squeezed | Shard(d') (shifted index) | InputDim(d) stays in the DimMap; shard_dim_map remaps it to its new position |
| Shard(d) | d squeezed, global size = 1 | Replicate() | All devices hold an identical size-1 slice; the dim is removed from the DimMap → shardable_dims[d]=[False, ...] → Shard → Replicate |
| Shard(d) | local size = 1, global size > 1 | Shard(d') (preserved) | The FIXME fix: _adjust_squeeze_to_global_singletons uses the global shape, so this dim is NOT in target_dims; the dim survives |
| Partial(op) | any | Partial(op) | Partial is a mesh-level reduction property, not a tensor-dim property; it passes through unchanged at _view_ops.py:735-738 |
| _StridedShard(d) | same as Shard | same as Shard (preserves split_factor) | Handled by the same isinstance(Shard \| _StridedShard) branches |
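The Shard(d) → Shard(d') index shift and the Partial pass-through from this table can be sketched as one stand-alone function. This is a simplified model with hypothetical names, not DTensor's actual DimMap machinery; placements are modeled as plain (kind, arg) tuples:

```python
def propagate_placement(placement, squeezed_dims):
    """Map one input placement to the output placement after squeezing
    `squeezed_dims` (all assumed globally size-1).

    A placement is modeled as a (kind, arg) pair:
      ("replicate", None), ("shard", tensor_dim), ("partial", reduce_op)
    """
    kind, arg = placement
    if kind == "replicate":
        return placement  # trivially preserved
    if kind == "partial":
        return placement  # mesh-level reduction property: passes through
    # kind == "shard"
    d = arg
    if d in squeezed_dims:
        # The sharded dim is globally singleton: every rank holds the
        # same size-1 slice, so the result is replicated.
        return ("replicate", None)
    # Shift the shard dim down by the number of removed dims before it.
    shift = sum(1 for s in squeezed_dims if s < d)
    return ("shard", d - shift)


# Squeeze dims (0, 2) out of a rank-4 tensor with global shape (1, 8, 1, 4):
print(propagate_placement(("shard", 1), (0, 2)))        # ("shard", 0)
print(propagate_placement(("shard", 3), (0, 2)))        # ("shard", 1)
print(propagate_placement(("shard", 0), (0, 2)))        # ("replicate", None)
print(propagate_placement(("partial", "max"), (0, 2)))  # ("partial", "max")
```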

P(max/min) → R is NOT needed

I agree with the author's assessment. The validator is correct not to expect Partial(max) → Replicate rules for squeeze. Partial is a per-mesh-dimension property indicating that each device holds a partial result requiring reduction. Squeezing a tensor dimension doesn't change whether the data needs reduction across mesh dimensions. A Partial("max") stays Partial("max") regardless of which dims are squeezed; this is handled correctly by the pass-through logic.

Op variant coverage

All 6 ATen squeeze variants are covered:

  • aten.squeeze.default, aten.squeeze.dim, aten.squeeze.dims
  • aten.squeeze_.default, aten.squeeze_.dim, aten.squeeze_.dims

squeeze_copy variants don't need explicit handling — they decompose to squeeze + clone before reaching DTensor dispatch (confirmed in torch/_decomp/__init__.py:507).

Symbolic shape handling

The implementation is conservative with symbolic shapes: guard_or_false(global_shape[d] == 1) returns False for unknown sizes, meaning symbolic dimensions are never treated as singleton. This is the safe default.
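This conservative behavior can be mimicked with a stand-in guard. Assumption: this only models the described semantics of guard_or_false (it is not the real symbolic-shape implementation), and an unknown/symbolic size is represented here as None:

```python
def guard_or_false(cond):
    """Stand-in for the symbolic-shape guard: return the condition when
    it is statically known, and False when it cannot be decided (modeled
    here as cond being None)."""
    if cond is None:
        return False
    return bool(cond)


def is_global_singleton(size):
    # size is an int, or None for an unknown (symbolic) size
    return guard_or_false(None if size is None else size == 1)


# Unknown sizes are never treated as singleton, so they are never squeezed.
print([is_global_singleton(s) for s in (1, 4, None)])  # [True, False, False]
```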

One issue worth noting

The inplace squeeze_ special case in _dispatch.py:489:

if op_call == aten.squeeze_.dim:
    args[0]._spec = output_spec
    return return_and_correct_aliasing(op_call, args, kwargs, args[0])

This only matches squeeze_.dim, not squeeze_.default or squeeze_.dims. Since the PR now registers squeeze_.default and the effective local op becomes squeeze_.dims, the inplace special case at line 489 won't trigger for those variants. When a squeeze_.default call results in a placement change (e.g., Shard(0) on a globally-singleton dim → Replicate), the else branch at line 500 would raise a RuntimeError about placement changes not being supported.

In practice this is a narrow edge case (you'd need to call squeeze_.default or squeeze_.dims directly on a DTensor with Shard(d) where d has global size 1). But it would be more robust to update line 489 to:

if op_call in (aten.squeeze_.dim, aten.squeeze_.default, aten.squeeze_.dims):

Verdict

The sharding rules are complete and correct for all squeeze variants. The DimMap + global-singleton rewrite approach is sound and properly handles the previously-broken FIXME case. The test coverage (test_squeeze_variants) is thorough, covering the key scenarios including the sharded-dim-local-size-1 bug.


@stmcgovern stmcgovern force-pushed the 173521-squeeze-dims branch from f504b95 to 0d4fcec Compare March 10, 2026 14:21
@stmcgovern
Collaborator Author

In practice this is a narrow edge case (you'd need to call squeeze_.default or squeeze_.dims directly on a DTensor with Shard(d) where d has global size 1). But it would be more robust to update line 489 to:

if op_call in (aten.squeeze_.dim, aten.squeeze_.default, aten.squeeze_.dims):

Added this from the Claude review and fixed CI failures (unexpected batch-norm successes).

@stmcgovern
Collaborator Author

I think it makes sense to fix an existing bug via #175798 before the strategy changes in this PR.

Contributor

@wconstab wconstab left a comment


LGTM, thanks! I'll take a look at the other PR you referenced now.

@stmcgovern stmcgovern force-pushed the 173521-squeeze-dims branch 3 times, most recently from 5304570 to e203dc1 Compare March 11, 2026 12:51
… FIXME

Fix the long-standing FIXME in dim_squeeze where squeeze with dim=None
could incorrectly remove sharded dimensions whose local size happened
to be 1 (despite global size > 1). The fix canonicalizes all squeeze
variants to squeeze.dims at the sharding propagation level, using
global shape to determine which dimensions are truly singleton.

The core change is _adjust_squeeze_to_global_singletons in sharding_prop,
which rewrites the op before dim_squeeze ever sees it. The dispatch layer
uses an effective_op pattern to execute the rewritten op and appends any
extra args (like the dims tuple) from the rewritten schema. Also registers
the missing aten.squeeze_.default strategy.
@stmcgovern stmcgovern force-pushed the 173521-squeeze-dims branch from e203dc1 to 4393ffb Compare March 20, 2026 17:45
batch_norm ops (_native_batch_norm_legit, native_batch_norm,
nn.functional.batch_norm) now pass in eager, local, multi-threaded,
and compiled modes thanks to squeeze.dims unblocking their
decomposition path. They only fail in unbacked tests (torch.compile
+ mark_unbacked) due to symbolic shape limitations in
DecompShardingStrategy. Move xfails from dtensor_fails_no_strategy
to ops_unbacked_dtensor_dde.

squeeze_copy errors in local crossref (shape mismatch) but passes
in multi-threaded — xfail it inline for TestLocalDTensorOps only.
@stmcgovern stmcgovern force-pushed the 173521-squeeze-dims branch from 4393ffb to cfb8978 Compare March 20, 2026 18:05
The squeeze.dims strategy registration enables DecompShardingStrategy
to propagate sharding through op decompositions that internally use
squeeze.dims. Previously these ops failed because squeeze.dims had
no sharding strategy, blocking the entire decomposition path.

Remove 10 ops from dtensor_fails_no_strategy: nansum,
adaptive_avg_pool1d/2d/3d, adaptive_max_pool1d/2d/3d,
avg_pool1d/2d/3d.
Restore squeeze_copy xfail in dtensor_fails_no_strategy (no strategy
exists for squeeze_copy). Gate effective_op behind needs_redistribute
and use self._squeeze_inplace_ops set to avoid hot-path overhead. Add
CommDebugMode assertions verifying squeeze ops are communication-free.
@stmcgovern
Collaborator Author

stmcgovern commented Mar 23, 2026

I think it makes sense to fix an existing bug via #175798 before the strategy changes in this PR.

Since that PR needs more work, I think landing this one first is the right step. CI now passes after multiple rounds of resolving unexpected successes and getting the correct tests to skip or run for each op.

@stmcgovern
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 23, 2026
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
Fixes pytorch#173521
Fixes pytorch#166124
Extend `dim_squeeze` to handle multiple dimensions by normalizing all dim variants to a target dimension set. This unifies the logic into a single code path.
Fix the long-standing FIXME in dim_squeeze where squeeze(dim=None) could incorrectly remove sharded dimensions whose local size happened to be 1 (despite global size > 1). Canonicalizes all squeeze variants to squeeze.dims at the sharding propagation level using global shape to determine which dimensions are truly singleton.

Strategy validator: 74 correct, 0 incorrect, 0 missing. This is without the P(max/min) - R rules mentioned below.

- Add test_squeeze_variants to test all squeeze variants with DTensor

~~Note: op_db test remains xfail due to pre-existing bug where local squeeze removes sharded dims with local size 1 (see PR pytorch#166862).~~ That PR is/will be closed in favor of this approach that avoids a custom handler

Pull Request resolved: pytorch#173563
Approved by: https://github.com/wconstab
nklshy-aws pushed a commit to nklshy-aws/pytorch that referenced this pull request Apr 7, 2026

Labels

  • ciflow/trunk (Trigger trunk jobs on your pull request)
  • Merged
  • open source
  • release notes: distributed (dtensor) (release notes category)
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Development

Successfully merging this pull request may close these issues.

  • [DTensor] aten.squeeze.dims rules not implemented
  • [DTensor] squeeze() causes incorrect dimension when sharded dimension size equals mesh size.
