[DTensor] squeeze_.dim updates spec but doesn't squeeze local tensor #174136

@stmcgovern

Description

When squeezing a sharded dimension, squeeze_.dim on a DTensor updates the DTensor's spec/shape but does NOT actually execute the squeeze on the underlying local tensor, so the shape the DTensor reports no longer matches the shape of its local tensor.

Reproduction

Save the script below as repro.py and run it with:

torchrun --nproc_per_node=2 repro.py

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

def main():
    dist.init_process_group(backend='gloo')
    rank = dist.get_rank()
    mesh = init_device_mesh('cpu', (2,))

    x = torch.arange(4).reshape(1, 4).float()
    dt = distribute_tensor(x, mesh, [Shard(0)])

    print(f'[rank {rank}] BEFORE: DTensor shape={dt.shape}, local={dt._local_tensor.shape}')
    dt.squeeze_(0)
    print(f'[rank {rank}] AFTER:  DTensor shape={dt.shape}, local={dt._local_tensor.shape}')

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

Output:

[rank 0] BEFORE: DTensor shape=torch.Size([1, 4]), local=torch.Size([1, 4])
[rank 0] AFTER:  DTensor shape=torch.Size([4]), local=torch.Size([1, 4])

DTensor claims shape [4] but local tensor is still [1, 4].
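
To make the mismatch checkable, here is a minimal pure-Python sketch of a consistency test between a reported global shape and a local shard shape. The helper names (`expected_local_shape`, `is_consistent`) are hypothetical and not part of DTensor, and the sketch assumes a chunk-style split of the sharded dimension across ranks.

```python
import math


def expected_local_shape(global_shape, shard_dim, world_size, rank):
    """Local shape under an assumed chunk-style split of `shard_dim`.

    Hypothetical helper for illustration only; not a DTensor API.
    """
    size = global_shape[shard_dim]
    chunk = math.ceil(size / world_size) if size else 0
    start = min(rank * chunk, size)
    end = min(start + chunk, size)
    local = list(global_shape)
    local[shard_dim] = end - start
    return tuple(local)


def is_consistent(global_shape, local_shape, shard_dim, world_size, rank):
    """True if the local shape is what the global shape and sharding imply."""
    # A local shard always has the same number of dimensions as the global tensor.
    if len(global_shape) != len(local_shape):
        return False
    expected = expected_local_shape(global_shape, shard_dim, world_size, rank)
    return expected == tuple(local_shape)


# Shapes taken from the rank 0 output above.
before_ok = is_consistent((1, 4), (1, 4), shard_dim=0, world_size=2, rank=0)
after_ok = is_consistent((4,), (1, 4), shard_dim=0, world_size=2, rank=0)
print(before_ok, after_ok)
```

With the shapes from the output above, the BEFORE state passes the check and the AFTER state fails it, because the reported shape has one dimension while the local tensor still has two.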

Root Cause

In _dispatch.py, _dispatch_fast_path_python_tail has special handling for squeeze_.dim:

if op_call == aten.squeeze_.dim:
    args[0]._spec = output_spec
    return return_and_correct_aliasing(op_call, args, kwargs, args[0])

This updates the spec but the local tensor is never squeezed. Manual squeeze works:

dt._local_tensor.squeeze_(0)  # Works - shape becomes [4]

But dt.squeeze_(0) leaves the local tensor unchanged.

cc @ezyang @gchanan @kadeng @msaroufim @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @dcci @aditvenk @xmfan @tianyu-l @XilunWu @SherlockNoMad @ppwwyyxx @H-Huang

    Labels

    bot-mislabeled, bot-triaged, high priority, module: correctness (silent), module: dtensor, oncall: distributed, triage review
