[DTensor] Fix squeeze() removing non-singleton sharded dimensions. #166862
mansiag05 wants to merge 1 commit into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166862
Note: Links to docs will display an error until the docs builds have been completed.
❌ 4 New Failures as of commit 105160b with merge base 0cd681d. The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
I don't think there is a bug in squeeze; I actually think it is working as intended. I do think there is an issue with the unsqueeze logic. If you squeeze the dimensions, it should modify the tensor and remove all size-1 dims from the shape. When I do an unsqueeze(0), I do see that the full tensor becomes (1, 32) instead of (1, 4, 8). My intuition tells me that either there is a misunderstanding about the API or it is not working as intended; I am leaning toward the first.
At the very least, I think the skip on the squeeze test needs to be removed and the test utilized for this fix.
Thanks for taking a look at this @skpark-rh. I totally get where you're coming from; the global shape reporting does look correct at first glance. Let me show you what's actually going wrong under the hood. The DTensor output shape correctly reports the squeezed global shape, but the underlying local tensor does not match it: the local squeeze also removes the sharded dim whose local size happens to be 1, even though its global size is greater than 1. There's actually a FIXME comment in the code acknowledging exactly this local/global singleton mismatch. Great catch on line 507! You're absolutely right, there's a skipped test. Looks like it was skipped because of this bug! The OpInfo tests were failing on it. Does this make sense? Happy to clarify anything! 😊
I see. So when the global tensor is [4, 8, 1], the DTensor, when sharded on dim 0, would create local tensors of [1, 8, 1]. The squeeze should only remove the third dim and keep the first dim.
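To make that failure mode concrete, here is a minimal repro sketch of the behavior described above. It assumes a 4-rank gloo setup (e.g. `torchrun --nproc_per_node=4 repro.py`); the expected-vs-buggy shapes in the comments follow the discussion in this thread rather than a captured log.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (4,))

# Global shape [4, 8, 1], sharded on dim 0: every rank holds a local [1, 8, 1].
x = distribute_tensor(torch.zeros(4, 8, 1), mesh, [Shard(0)])

y = x.squeeze()
# Expected: only the globally singleton dim 2 is removed, so the global shape
# is [4, 8] and each local shape is [1, 8]. With the bug, the local squeeze
# also drops the sharded dim 0 (local size 1), leaving a local [8] that no
# longer lines up with the reported global shape.
print(y.shape, y.to_local().shape)
```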
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as `Stale`.
wconstab left a comment:
I think this looks correct. If we cannot find a less nuclear way to support squeeze than using a custom handler, I'll spend more time looking at whether we covered all the cases in this handler. But I wanted to let others give suggestions first.
One thought, which might not be a good idea, is to make a more minimal special case for squeeze in the dispatch path that would just mutate the args/kwargs (a noop when the 'dim' arg is present, but computing a new 'dim' arg based on global singleton dims when dim is not present) and then do the rest of dispatch the normal way (see the sketch below). It might be strictly better to just use the whole override approach as in this PR.
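A rough sketch of that minimal special case, assuming a hypothetical helper invoked from the dispatch path (the function name and hook point are illustrative, not actual DTensor internals):

```python
import torch

def maybe_rewrite_squeeze(op_call, args, kwargs, global_shape):
    # Noop when an explicit 'dim' is already present (i.e. not the no-arg overload).
    if op_call is not torch.ops.aten.squeeze.default:
        return op_call, args, kwargs
    # Otherwise compute the dims that are singleton in the *global* shape and
    # re-dispatch as squeeze.dims, so the rest of dispatch proceeds normally.
    dims = [i for i, size in enumerate(global_shape) if size == 1]
    return torch.ops.aten.squeeze.dims, (args[0], dims), kwargs
```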
tianyu-l left a comment:
Instead of adding a special handler in dispatch.py, I wonder if it's better to adjust the arg in sharding_prop.py. We have a few ops (view, new_empty, etc.) for which we have to adjust the shape and still use the per-op strategies.
https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_sharding_prop.py#L155
Here it's similar, in that:
- we can keep the strategy
- and only modify the non-Tensor arg (the `dim`)

A sketch of that adjustment follows.
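For the explicit-dim case, the adjustment could be as small as normalizing the requested dim against the global shape, in the spirit of the view/new_empty shape adjustments linked above (a sketch with illustrative names, not the actual `_sharding_prop.py` code):

```python
from typing import Optional, Sequence

def adjust_squeeze_dim(dim_arg: int, global_shape: Sequence[int]) -> Optional[int]:
    # Normalize a possibly negative dim against the *global* rank of the tensor.
    dim = dim_arg % len(global_shape)
    # squeeze(dim) on a non-singleton dim is a no-op, so only a globally
    # singleton dim should survive into the strategy's adjusted arg.
    return dim if global_shape[dim] == 1 else None
```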
(pytorch#166124) Fix bug where squeeze() incorrectly removes dimensions that are locally singleton (size=1) but globally not (size=mesh_size). Added custom handler to check global shape before squeezing local tensor.
Force-pushed from 63674b9 to 105160b.
Hello @tianyu-l, I looked into the sharding_prop.py approach you suggested. The catch is that squeeze() without a dim dispatches to a different overload than squeeze(dim), so this means we need to change the op variant, not just modify the arg. The current redistribute_schema pattern only modifies args while op_call stays the same. Could you help me understand how this can be approached? (The relevant overload schemas are shown below.)
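For reference, the three aten squeeze overloads can be inspected directly. Roughly, `squeeze.default` takes no dim, `squeeze.dim` takes a single int, and `squeeze.dims` takes a list of ints, which is why going from squeeze() to an explicit dim list changes `op_call` itself:

```python
import torch

# Print the registered schemas for the three squeeze overloads.
print(torch.ops.aten.squeeze.default._schema)  # no dim argument
print(torch.ops.aten.squeeze.dim._schema)      # single int dim
print(torch.ops.aten.squeeze.dims._schema)     # list of dims
```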
Extend `dim_squeeze` to handle multiple dimensions by normalizing all dim variants to a target dimension set. This unifies the logic into a single code path.
- Extend `dim_squeeze()` type signature to `DimsType | None`
- Normalize all cases to `target_dims: set[int]`
- Single return path: keep dims that are size > 1 or not targeted
- Register `aten.squeeze.dims` using existing torch.squeeze mapping
- Add test_squeeze_variants to test all squeeze variants with DTensor

Note: op_db test remains xfail due to pre-existing bug where local squeeze removes sharded dims with local size 1 (see PR pytorch#166862). Fixes pytorch#173521
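A condensed sketch of the normalization this commit message describes. The real `dim_squeeze` operates on DTensor dim maps, so this only illustrates the keep/drop decision; the name and signature are illustrative:

```python
from typing import Optional, Sequence, Union

DimsType = Union[int, Sequence[int]]

def squeezed_dims(global_shape: Sequence[int],
                  dims: Optional[DimsType] = None) -> tuple:
    ndim = len(global_shape)
    # Normalize all dim variants to a target dimension set.
    if dims is None:
        target_dims = set(range(ndim))      # squeeze() targets every dim
    elif isinstance(dims, int):
        target_dims = {dims % ndim}         # normalize negative dims
    else:
        target_dims = {d % ndim for d in dims}
    # Single return path: keep dims whose size is > 1 or that are not targeted.
    return tuple(i for i, size in enumerate(global_shape)
                 if size > 1 or i not in target_dims)
```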
What I meant was not to reuse the code for view/new_empty directly, but to follow the same pattern of adjusting the args in _sharding_prop.py. Given your finding that the op variant itself has to change, rewriting the op there (e.g. squeeze to squeeze.dims) sounds reasonable to me.
@stmcgovern you have another PR for squeeze.dims - do you mind aligning with @mansiag05 on an overall approach? I'm OK with this PR; let's get it cleaned up and land-ready, then I'll review.
@wconstab Thanks. I opened #173563 thinking that it could be orthogonal to this PR. Looking at this a bit more and following the interesting discussion here, I do think we can avoid adding a custom handler here if we rewrite squeeze to squeeze.dims in _sharding_prop.py. I'll investigate a bit more and coordinate with @mansiag05.
```python
    dim_normalized = dim_arg if dim_arg >= 0 else dim_arg + len(global_shape)
    singleton_dims = (dim_normalized,) if global_shape[dim_normalized] == 1 else ()
else:
    singleton_dims = tuple(i for i, size in enumerate(global_shape) if size == 1)
```
Just wondering: is it possible to construct a test case where mesh_dim_size > 1 but tensor_dim_size < mesh_dim_size? E.g. shard a size-4 tensor dim on a mesh dim with 8 ranks.
I'm wondering if some ranks would see size [0, ...], and this squeeze logic would not work. (See the sketch below.)
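A quick sketch of the uneven split in that question, using `torch.tensor_split` to mimic how a size-4 dim lands on 8 ranks (DTensor's `Shard` placement similarly leaves the trailing ranks with empty locals, as far as I understand):

```python
import torch

t = torch.arange(4)
# Splitting a size-4 dim across 8 ranks: the first 4 ranks get one element,
# the last 4 get none, so their local size along that dim is 0, not 1.
local_sizes = [chunk.numel() for chunk in torch.tensor_split(t, 8)]
print(local_sizes)  # [1, 1, 1, 1, 0, 0, 0, 0]
```

On the empty ranks the dim is not even locally singleton, which is one more reason to key the decision off the global shape, as the snippet above does.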
I updated #173563 to include the fix for the local/global singleton mismatch FIXME (superseding this PR). It leverages the squeeze.dims strategy support to turn all squeeze op variants into squeeze.dims in _sharding_prop.py and then handles the dims in one place. It avoids the custom handler, but touches _dispatch.py. @wconstab @tianyu-l @pianpwk @mansiag05
IIUC this PR is no longer needed in favor of #173563? (Can you close this one if so, or clarify.)
Fixes #173521
Fixes #166124
Extend `dim_squeeze` to handle multiple dimensions by normalizing all dim variants to a target dimension set. This unifies the logic into a single code path. Fix the long-standing FIXME in dim_squeeze where squeeze(dim=None) could incorrectly remove sharded dimensions whose local size happened to be 1 (despite global size > 1). Canonicalizes all squeeze variants to squeeze.dims at the sharding propagation level, using the global shape to determine which dimensions are truly singleton.
Strategy validator: 74 correct, 0 incorrect, 0 missing. This is without the P(max/min) - R rules mentioned below.
- Add test_squeeze_variants to test all squeeze variants with DTensor

~~Note: op_db test remains xfail due to pre-existing bug where local squeeze removes sharded dims with local size 1 (see PR #166862).~~ That PR is/will be closed in favor of this approach, which avoids a custom handler.
Pull Request resolved: #173563
Approved by: https://github.com/wconstab
Fix bug where squeeze() incorrectly removes dimensions that are locally singleton (size=1) but globally not (size=mesh_size). Added custom handler to check global shape before squeezing local tensor.
Fixes #166124
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk