
[DTensor] Fix squeeze() removing non-singleton sharded dimensions.#166862

Closed
mansiag05 wants to merge 1 commit intopytorch:mainfrom
mansiag05:fix-issue-166124

Conversation

@mansiag05
Collaborator

@mansiag05 mansiag05 commented Nov 3, 2025

Fix a bug where squeeze() incorrectly removes dimensions that are locally singleton (size 1) but not globally singleton (size = mesh_size). Added a custom handler that checks the global shape before squeezing the local tensor.

Fixes #166124

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk
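To make the mismatch concrete, here is a minimal, mesh-free sketch (hypothetical helper names, operating on shapes only) contrasting the current local squeeze with one that consults the global shape:

```python
def naive_local_squeeze(local_shape):
    """What a plain squeeze() does locally: drop every size-1 dim."""
    return tuple(s for s in local_shape if s != 1)

def correct_local_squeeze(local_shape, global_shape):
    """Keep a local size-1 dim whenever the dim is not globally singleton."""
    return tuple(l for l, g in zip(local_shape, global_shape) if g != 1)

# Global [4, 8] sharded on dim 0 over a 4-rank mesh: each rank holds [1, 8].
print(naive_local_squeeze((1, 8)))            # (8,)  -- over-squeezed
print(correct_local_squeeze((1, 8), (4, 8)))  # (1, 8) -- sharded dim kept
```

The second helper is what the custom handler effectively has to compute: which dims are singleton in the *global* shape, not the local one.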

@pytorch-bot

pytorch-bot Bot commented Nov 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166862

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures

As of commit 105160b with merge base 0cd681d:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot Bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Nov 3, 2025
@mansiag05
Collaborator Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot Bot added the topic: not user facing topic category label Nov 3, 2025
@mansiag05
Collaborator Author

cc @stmcgovern @skpark-rh

@skpark-rh
Collaborator

I don't think there is a bug in squeeze; I actually think it is working as intended. I do think there is an issue with the unsqueeze logic. If you squeeze the dimensions, it should modify the tensor and remove every size-1 dim from all the shapes. When I do an unsqueeze(0), I do see that the full tensor becomes (1, 32) instead of (1, 4, 8). My intuition tells me that either there is a misunderstanding about the API or it is not working as intended; I am leaning toward the first.

@skpark-rh
Collaborator

At the very least I think the test for squeeze that is skipped needs to be removed and utilized for this fix. (test/distributed/tensor/test_dtensor_ops.py:507)

@janeyx99 janeyx99 requested a review from XilunWu November 7, 2025 16:50
@janeyx99 janeyx99 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Nov 7, 2025
@mansiag05
Collaborator Author

Thanks for taking a look at this @skpark-rh. I totally get where you're coming from; the global shape reporting does look correct at first glance. Let me show you what's actually going wrong under the hood.

The DTensor output shape correctly reports [4, 8] after squeeze. But the sneaky part is that the local tensors on each rank are getting incorrectly squeezed, and that breaks everything downstream.
The problem is that each rank's local tensor went from [1, 8] to [8], but the DTensor still thinks it's managing a [4, 8] tensor with Shard(0). When you try to gather them back with full_tensor(), it crashes because the shapes don't line up.

Also, there's actually a FIXME comment in the code at torch/distributed/tensor/_ops/_view_ops.py:431-440 that describes this exact issue. So this was a known issue that just hadn't been fixed yet!

And great catch on line 507! You're absolutely right: there is a skipped test, and it looks like it was skipped because of this bug. The OpInfo tests were failing because full_tensor() would crash after squeeze. Since this PR fixes the crash, we can finally enable the skipped test.

Does this make sense? Happy to clarify anything! 😊
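To illustrate the crash mechanism, here is a toy sketch of the Shard(0) gather check (a hypothetical helper, not the real full_tensor() implementation), operating on shapes only:

```python
def gather_shard0(local_shapes, global_shape):
    """Toy full_tensor() check for Shard(0): locals concatenate along dim 0,
    so every local shape must match the tracked global shape on all other dims."""
    for s in local_shapes:
        if len(s) != len(global_shape) or tuple(s[1:]) != tuple(global_shape[1:]):
            raise ValueError(f"local shape {s} incompatible with global {global_shape}")
    return (sum(s[0] for s in local_shapes),) + tuple(global_shape[1:])

# Healthy case: four [1, 8] shards reassemble into the global [4, 8].
print(gather_shard0([(1, 8)] * 4, (4, 8)))  # (4, 8)
# After the bad local squeeze, the shards are [8] and the gather blows up.
```

Calling `gather_shard0([(8,)] * 4, (4, 8))` raises, which is the shape-mismatch failure described above.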

@skpark-rh
Collaborator

I see. So when the global tensor is [4, 8, 1], the dtensor when sharding(0) would create local tensors of [1, 8, 1]. The squeeze should only remove the 3rd dim and keep the first dim.

Comment thread torch/distributed/tensor/_dispatch.py Outdated
@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions Bot added the Stale label Jan 10, 2026
Contributor

@wconstab wconstab left a comment


I think this looks correct. If we cannot find a less nuclear way to support squeeze than using a custom handler, I'll spend more time looking at whether we covered all the cases in this handler. But I wanted to let others give suggestions.

One thought, which might not be a good idea, is to make a more minimal special case for squeeze in the dispatch path that would just mutate the args/kwargs (a no-op when the 'dim' arg is present, but computing a new 'dim' arg based on the globally singleton dims when it is not) and then do the rest of dispatch the normal way. It might be strictly better to just use the whole override approach as in this PR.
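The arg-mutation idea can be sketched as follows (a hypothetical helper operating on shapes, not the actual dispatch code):

```python
def rewrite_squeeze_args(global_shape, dim=None):
    """Hypothetical pre-dispatch rewrite: when no 'dim' is given, pin the
    squeeze to only the globally singleton dims so the local op cannot
    over-squeeze sharded dims whose local size happens to be 1."""
    if dim is not None:
        return (dim % len(global_shape),)  # user already pinned the dim
    return tuple(i for i, s in enumerate(global_shape) if s == 1)

print(rewrite_squeeze_args((4, 8, 1)))         # (2,) -- only the true singleton
print(rewrite_squeeze_args((4, 8, 1), dim=2))  # (2,) -- no-op normalization
```

With the dims pinned up front, the rest of dispatch could proceed normally, which is the "minimal special case" being floated here.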

Contributor

@tianyu-l tianyu-l left a comment


Instead of adding a special handler in dispatch.py, I wonder if it's better to adjust the arg in sharding_prop.py.

We have a few ops (view, new_empty, etc.) for which we have to adjust the shape, and still use the per-op strategies.
https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_sharding_prop.py#L155
Here it's similar in that

  • we can keep the strategy
  • and only modify the non-Tensor arg (the dim)

Comment thread torch/distributed/tensor/_dispatch.py Outdated
…torch#166124)

Fix bug where squeeze() incorrectly removes dimensions that are locally
singleton (size=1) but globally not (size=mesh_size). Added custom handler
to check global shape before squeezing local tensor.
@pytorch-bot pytorch-bot Bot added the release notes: distributed (dtensor) release notes category label Jan 27, 2026
@mansiag05
Collaborator Author

Hello @tianyu-l,

I looked into the op_to_shape_and_stride_idx, however, I'm not sure how to apply that pattern here. For view and new_empty, we modify the shape arg but keep the same op. For squeeze, the challenge is:

  • squeeze.default has no dim arg - it squeezes all local singleton dims
  • To squeeze only globally singleton dims, we need to use squeeze.dims with explicit dims

This means we need to change the op variant, not just modify the arg. The current redistribute_schema pattern only modifies args while op_call stays the same.

Could you help me understand how this can be approached?
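For concreteness, the variant rewrite being discussed might look like this (a sketch with hypothetical names, operating on shapes only, not the actual sharding-prop code):

```python
def canonicalize_squeeze(op_name, global_shape, args):
    """Hypothetical canonicalization: map every squeeze variant to
    aten.squeeze.dims with explicit, globally singleton dims."""
    n = len(global_shape)
    if op_name == "aten.squeeze.default":
        # No dim arg: squeeze all globally singleton dims.
        dims = tuple(i for i, s in enumerate(global_shape) if s == 1)
    elif op_name == "aten.squeeze.dim":
        d = args[0] % n  # normalize a possibly negative dim
        dims = (d,) if global_shape[d] == 1 else ()
    else:  # already aten.squeeze.dims
        dims = tuple(d % n for d in args[0] if global_shape[d % n] == 1)
    return "aten.squeeze.dims", dims

print(canonicalize_squeeze("aten.squeeze.default", (4, 8, 1), ()))
# ('aten.squeeze.dims', (2,))
```

This is exactly the "change the op variant" step: squeeze.default has no dims to modify, so the rewrite has to produce a squeeze.dims call instead.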

stmcgovern added a commit to stmcgovern/pytorch that referenced this pull request Jan 27, 2026
Extend `dim_squeeze` to handle multiple dimensions by normalizing all
dim variants to a target dimension set. This unifies the logic into a
single code path.

- Extend `dim_squeeze()` type signature to `DimsType | None`
- Normalize all cases to `target_dims: set[int]`
- Single return path: keep dims that are size > 1 or not targeted
- Register `aten.squeeze.dims` using existing torch.squeeze mapping
- Add test_squeeze_variants to test all squeeze variants with DTensor

Note: op_db test remains xfail due to pre-existing bug where local
squeeze removes sharded dims with local size 1 (see PR pytorch#166862).

Fixes pytorch#173521
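The unified rule described in the commit message ("keep dims that are size > 1 or not targeted") can be sketched in plain Python (a shape-level sketch, not the actual dim_squeeze implementation):

```python
def dim_squeeze(shape, dims=None):
    """Sketch of the unified squeeze rule: normalize every variant
    (no dim, one dim, many dims) to a target dim set, then keep a dim
    if its size > 1 or it is not targeted."""
    n = len(shape)
    if dims is None:
        target = set(range(n))       # squeeze(): every dim is a candidate
    elif isinstance(dims, int):
        target = {dims % n}          # squeeze(dim)
    else:
        target = {d % n for d in dims}  # squeeze(dims=(...))
    return tuple(s for i, s in enumerate(shape) if s > 1 or i not in target)

print(dim_squeeze((4, 8, 1)))        # (4, 8)
print(dim_squeeze((1, 4, 1), 0))     # (4, 1)
print(dim_squeeze((1, 4, 1), (0, 2)))  # (4,)
```

Applied to the *global* shape, this is what keeps a sharded dim of global size 4 alive even when its local size is 1.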
@tianyu-l tianyu-l requested a review from pianpwk February 2, 2026 08:35
@tianyu-l
Contributor

tianyu-l commented Feb 2, 2026

This means we need to change the op variant, not just modify the arg. The current redistribute_schema pattern only modifies args while op_call stays the same.

What I meant was not to reuse code for op_to_shape_and_stride_idx; instead we could invent new functions in sharding prop to achieve what squeeze handler does.

Given your squeeze_handler is for squeeze ops only for now, and the handler complexity is limited, I think the PR is acceptable. cc @pianpwk to review too

@wconstab
Contributor

wconstab commented Feb 2, 2026

@stmcgovern you have another PR for squeeze.dims - do you mind aligning with @mansiag05 on an overall approach?

I'm ok with this PR- let's get it cleaned up and land-ready, then i'll review

@stmcgovern
Collaborator

@wconstab Thanks. I opened #173563 thinking that it could be orthogonal to this PR. Looking at this a bit more and following the interesting discussion here, I do think that we can avoid adding a custom handler here if we rewrite squeeze to squeeze.dims in _sharding_prop.py. I'll investigate a bit more and coordinate with @mansiag05.

if dim_arg is not None:  # leading condition reconstructed; the review quoted this hunk mid-block
    dim_normalized = dim_arg if dim_arg >= 0 else dim_arg + len(global_shape)
    singleton_dims = (dim_normalized,) if global_shape[dim_normalized] == 1 else ()
else:
    singleton_dims = tuple(i for i, size in enumerate(global_shape) if size == 1)
Contributor


just wondering: is it possible to construct a test case where mesh_dim_size > 1, but tensor_dim_size < mesh_dim_size? e.g. shard size 4 on mesh dim with 8 ranks.

I'm wondering if some ranks would see size [0, ...], and this squeeze logic would not work.
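The uneven-shard scenario can be checked with a small sketch, assuming the default Shard(0) chunking follows torch.chunk semantics (ceil-sized chunks first, so trailing ranks can end up with empty shards; this is an assumption of the sketch, not a quote from the DTensor code):

```python
import math

def shard0_local_sizes(dim_size, num_ranks):
    """Assumed chunking: ceil(dim_size / num_ranks)-sized chunks from the
    front, with any shortfall absorbed by the trailing ranks (possibly 0)."""
    chunk = math.ceil(dim_size / num_ranks) if dim_size else 0
    sizes, remaining = [], dim_size
    for _ in range(num_ranks):
        take = min(chunk, remaining)
        sizes.append(take)
        remaining -= take
    return sizes

print(shard0_local_sizes(4, 8))  # [1, 1, 1, 1, 0, 0, 0, 0]
```

Under this assumption, a tensor dim of size 4 on an 8-rank mesh dim gives the last four ranks local size 0 on that dim, which is precisely the case where a local-size-1 check (or a `size == 1` test against the local shape) would misbehave.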

stmcgovern added a commit to stmcgovern/pytorch that referenced this pull request Feb 3, 2026
(Same commit message as the Jan 27 commit above.)
@stmcgovern
Collaborator

I updated #173563 to include the fix for the local/global singleton mismatch FIXME (superseding this PR). It leverages the squeeze.dims strategy support to turn all squeeze op variants into squeeze.dims in _sharding_prop.py and then handle the dims in one place. It avoids the custom handler, but touches _dispatch.py. @wconstab @tianyu-l @pianpwk @mansiag05

@wconstab
Contributor

wconstab commented Feb 6, 2026

iiuc this PR is no longer needed in favor of #173563? (can you close this one if so, or clarify)

@github-actions github-actions Bot closed this Mar 8, 2026
pytorchmergebot pushed a commit that referenced this pull request Mar 23, 2026
Fixes #173521
Fixes #166124
Extend `dim_squeeze` to handle multiple dimensions by normalizing all dim variants to a target dimension set. This unifies the logic into a single code path.
Fix the long-standing FIXME in dim_squeeze where squeeze(dim=None) could incorrectly remove sharded dimensions whose local size happened to be 1 (despite global size > 1). Canonicalizes all squeeze variants to squeeze.dims at the sharding propagation level using global shape to determine which dimensions are truly singleton.

Strategy validator: 74 correct, 0 incorrect, 0 missing. This is without the P(max/min) - R rules mentioned below.

- Add test_squeeze_variants to test all squeeze variants with DTensor

~~Note: op_db test remains xfail due to pre-existing bug where local squeeze removes sharded dims with local size 1 (see PR #166862).~~ That PR is/will be closed in favor of this approach that avoids a custom handler

Pull Request resolved: #173563
Approved by: https://github.com/wconstab
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
(Same commit message as the merged commit above.)
nklshy-aws pushed a commit to nklshy-aws/pytorch that referenced this pull request Apr 7, 2026
(Same commit message as the merged commit above.)

Labels

oncall: distributed Add this issue/PR to distributed oncall triage queue open source release notes: distributed (dtensor) release notes category Stale topic: not user facing topic category triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[DTensor] squeeze() causes incorrect dimension when sharded dimension size equals mesh size.

8 participants