[DTensor] squeeze() causes incorrect dimension when sharded dimension size equals mesh size. #166124

@mansiag05

Description
🐛 Describe the bug

DTensor.squeeze() produces an incorrect shape when a sharded dimension's size equals the mesh size.

import os
import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor

def main():
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 4))
    dist.init_process_group(backend="gloo")
    mesh = DeviceMesh("cpu", torch.arange(world_size))
    
    # Create tensor with shape [4, 8] where dim 0 size (4) == world_size (4)
    global_tensor = torch.arange(32).reshape(4, 8).float() * 10 + rank
    
    if rank == 0:
        print("=" * 60)
        print(f"Global tensor shape: {global_tensor.shape}")
        print(f"World size: {world_size}")
    
    # Shard along dimension 0
    dtensor = distribute_tensor(global_tensor, mesh, [Shard(0)])
    local_shape_before = dtensor._local_tensor.shape
    if rank == 0:
        print(f"Global shape: {dtensor.shape}")
        print(f"Local shape on each rank: {dtensor._local_tensor.shape}")
        print(f"Sharding: {dtensor.placements}")

    squeezed = dtensor.squeeze()
    local_shape_after = squeezed._local_tensor.shape
    if rank == 0:
        print()
        print("After squeeze:")
        print(f"Global shape: {squeezed.shape}")
        print(f"Local shape on each rank: {squeezed._local_tensor.shape}")
        print(f"Sharding: {squeezed.placements}")
        print()
    
    if local_shape_before != local_shape_after:
        if rank == 0:
            print(f"Local tensor shape changed from {local_shape_before} to {local_shape_after}")    

    try:
        if rank == 0:
            print("Attempting to gather full tensor...")
        full = squeezed.full_tensor()
        
        if rank == 0:
            print(f"Full tensor shape: {full.shape}")
            print(f"Expected: {global_tensor.shape}")
            if full.shape != global_tensor.shape:
                print("SHAPE MISMATCH")
    except Exception as e:
        if rank == 0:
            print(f"FAILED with error: {type(e).__name__}")    
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
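The repro reads RANK and WORLD_SIZE from the environment, so it is presumably launched with torchrun, using 4 processes to match the sharded dimension (the filename repro.py is an assumption):

```shell
# Launch 4 local processes; torchrun sets RANK, WORLD_SIZE, MASTER_ADDR, etc.
torchrun --nproc_per_node=4 repro.py
```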

Expected: the global tensor ([4, 8]) has no size-1 dimensions, so squeeze() should be a no-op.

Bug:

  • After squeeze(): the local shape becomes [8] instead of staying [1, 8]
  • full_tensor() returns shape [32] instead of [4, 8]
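A minimal, single-process sketch of plain squeeze() semantics, which suggests the bug comes from applying squeeze to the local shard shape rather than the global shape (squeeze_shape is a hypothetical reference helper, not DTensor internals):

```python
def squeeze_shape(shape):
    """Reference semantics of squeeze(): drop every dimension of size 1."""
    return tuple(d for d in shape if d != 1)

# Global view: [4, 8] has no size-1 dims, so squeeze() should be a no-op.
assert squeeze_shape((4, 8)) == (4, 8)

# Local shard under Shard(0) with world_size == 4: each rank holds [1, 8].
# Squeezing the *local* shape collapses the sharded dim to [8], which matches
# the incorrect [8] local / [32] global shapes observed in the repro above.
assert squeeze_shape((1, 8)) == (8,)
```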

Versions

Collecting environment information...
PyTorch version: 2.10.0a0+gitd73c283
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Fedora Linux 42 (Workstation Edition) (x86_64)
GCC version: (GCC) 15.2.1 20250808 (Red Hat 15.2.1-1)
Clang version: Could not collect
CMake version: version 4.0.0
Libc version: glibc-2.41

Python version: 3.13.7 (main, Aug 14 2025, 00:00:00) [GCC 15.2.1 20250808 (Red Hat 15.2.1-1)] (64-bit runtime)
Python platform: Linux-6.16.10-200.fc42.x86_64-x86_64-with-glibc2.41
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 7 PRO 7840HS w/ Radeon 780M Graphics
CPU family: 25
Model: 116
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 1
Frequency boost: enabled
CPU(s) scaling MHz: 60%
CPU max MHz: 5137.9038
CPU min MHz: 419.4210
BogoMIPS: 7585.32
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor smca fsrm flush_l1d amd_lbr_pmc_freeze
Virtualization: AMD-V
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 8 MiB (8 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15

Versions of relevant libraries:
[pip3] flake8==7.3.0
[pip3] flake8-bugbear==24.12.12
[pip3] flake8-comprehensions==3.16.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==2024.24.12
[pip3] flake8-pyi==25.5.0
[pip3] flake8_simplify==0.22.0
[pip3] mypy==1.16.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.1.0
[pip3] onnx==1.18.0
[pip3] onnx-ir==0.1.9
[pip3] onnxscript==0.4.0
[pip3] optree==0.17.0
[pip3] pytorch_sphinx_theme2==0.1.0
[pip3] torch==2.10.0a0+git7c39b2e

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @tianyu-l @XilunWu @SherlockNoMad

Metadata

    Labels: module: dtensor, oncall: distributed