🐛 Describe the bug
When a DTensor with Shard(dim) placement on a dynamically-sized dimension is
redistributed to Replicate, the gathered dimension's symbolic shape becomes
n*ceil(s/n) instead of s. This corrupted expression propagates through all
downstream operations (view, transpose, reshape, to_local), causing shape
mismatches in compilation pipelines.
Root cause: Shard._to_replicate_tensor in placement_types.py
- torch.chunk uses ceiling division: local_size = (s + n - 1) // n
- all_gather concatenates n chunks: gathered = n * ((s + n - 1) // n)
- _maybe_unpad_tensor checks "s % n != 0", which evaluates to False for even s,
so no unpadding occurs
- Result has symbolic size n*ceil(s/n) ≠ s (SymPy cannot simplify)
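The padded-size arithmetic above can be sketched in plain Python (an illustration of the described behavior, not the actual DTensor code; `gathered_size` is a hypothetical helper name):

```python
def gathered_size(s: int, n: int) -> int:
    """Size of the gathered dim after chunk + all_gather, per the steps above."""
    local = (s + n - 1) // n  # torch.chunk pads via ceiling division
    return n * local          # all_gather concatenates n local chunks

# For even s with n=2 the numeric result equals s, so the "s % n != 0" guard
# skips unpadding; only the *symbolic* form n*ceil(s/n) fails to reduce to s.
assert gathered_size(100, 2) == 100  # even seqlen: sizes agree numerically
assert gathered_size(101, 2) == 102  # odd seqlen: padding would need removal
```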
Reproduction
Run: torchrun --nproc_per_node=2 repro_redistribute_bug.py
with the following code saved as repro_redistribute_bug.py:
import copy
import io
import sys

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)


class AttentionBlock(nn.Module):
    def __init__(self, dim, n_heads, head_dim):
        super().__init__()
        self.wq = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.n_heads = n_heads
        self.head_dim = head_dim

    def forward(self, x):
        q = self.wq(x)  # (batch, seqlen, n_heads * head_dim)
        bs = q.shape[0]
        q = q.view(bs, -1, self.n_heads, self.head_dim)
        q = q.transpose(1, 2)  # (batch, n_heads, seqlen, head_dim)
        return q


class TransformerLayer(nn.Module):
    def __init__(self, dim, n_heads, head_dim):
        super().__init__()
        self.tok_embeddings = nn.Embedding(1000, dim)
        self.norm = nn.RMSNorm(dim)
        self.attention = AttentionBlock(dim, n_heads, head_dim)

    def forward(self, input_ids):
        h = self.tok_embeddings(input_ids)
        h = self.norm(h)
        q = self.attention(h)  # DTensor (batch, n_heads, seqlen, head_dim) Shard(1)
        if isinstance(q, DTensor):
            return q.to_local()
        return q


def run():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))
    model = TransformerLayer(dim=64, n_heads=8, head_dim=8).to(device)
    # TP plan: SequenceParallel shards seqlen, PrepareModuleInput gathers it back
    parallelize_module(
        model,
        mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
                use_local_output=False,
            ),
            "norm": SequenceParallel(use_local_output=False),
            "attention": PrepareModuleInput(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "attention.wq": ColwiseParallel(use_local_output=False),
        },
    )

    # Custom backend to inspect symbolic shapes in the Dynamo FX graph
    def debug_backend(gm, example_inputs):
        if rank == 0:
            # Use print_readable to get annotated shapes
            old_stdout = sys.stdout
            sys.stdout = buf = io.StringIO()
            try:
                gm.print_readable()
            finally:
                sys.stdout = old_stdout
            code = buf.getvalue()
            for line in code.split("\n"):
                if "prim_to_local" in line and ":" in line and "=" in line:
                    print(f"  Actual: {line.strip()}")
        from torch._inductor.compile_fx import compile_fx

        return compile_fx(gm, example_inputs)

    compiled = torch.compile(copy.deepcopy(model), backend=debug_backend, dynamic=True)
    input_ids = torch.randint(0, 1000, (1, 100), device=device)
    torch._dynamo.mark_dynamic(input_ids, 1)  # seqlen is dynamic
    if rank == 0:
        print("Compiling with dynamic seqlen...")
        print("Expected: prim_to_local shape should be [1, 4, s_N, 8]")
        print("Bug: prim_to_local shape is [1, 4, 2*(((s_N + 1)//2)), 8]")
        print()
    out = compiled(input_ids)
    if rank == 0:
        print(f"\n  Runtime output shape: {out.shape}")
    # Verify runtime correctness with different seqlen
    out2 = compiled(torch.randint(0, 1000, (1, 200), device=device))
    if rank == 0:
        assert out.shape == torch.Size([1, 4, 100, 8])
        assert out2.shape == torch.Size([1, 4, 200, 8])
        print(f"  Reuse with seqlen=200: {out2.shape}")
        print("  Runtime shapes are correct (bug only affects symbolic shapes)")
    dist.destroy_process_group()


if __name__ == "__main__":
    run()
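The SymPy non-simplification mentioned in the root cause can also be checked in isolation (a sketch assuming SymPy is installed; this is independent of the PyTorch code paths):

```python
import sympy as sp

s = sp.Symbol("s", positive=True, integer=True)
n = sp.Symbol("n", positive=True, integer=True)

# Symbolic size of the gathered dimension: n chunks of ceil(s/n) elements each.
gathered = n * sp.ceiling(s / n)

# Without a divisibility fact about s, this does not reduce to s, so the
# expression survives in the traced graph instead of simplifying away.
assert sp.simplify(gathered - s) != 0
```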
Versions
Collecting environment information...
PyTorch version: 2.11.0a0+git1c3d76e
Is debug build: False
CUDA used to build PyTorch: 12.9
ROCM used to build PyTorch: N/A
OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.5.0 20240719 (Red Hat 11.5.0-14)
Clang version: Could not collect
CMake version: version 4.2.1
Libc version: glibc-2.34
Python version: 3.12.12 | packaged by Anaconda, Inc. | (main, Oct 21 2025, 20:16:04) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.13.2-0_fbk8_0_g8695f611147d-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.9.86
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: NVIDIA H100
GPU 1: NVIDIA H100
GPU 2: NVIDIA H100
GPU 3: NVIDIA H100
GPU 4: NVIDIA H100
GPU 5: NVIDIA H100
GPU 6: NVIDIA H100
GPU 7: NVIDIA H100
Nvidia driver version: 580.82.07
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 384
On-line CPU(s) list: 0-383
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9654 96-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 2
Core(s) per socket: 96
Socket(s): 2
Stepping: 1
Frequency boost: enabled
CPU(s) scaling MHz: 100%
CPU max MHz: 2400.0000
CPU min MHz: 1500.0000
BogoMIPS: 4792.67
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor smca fsrm flush_l1d debug_swap
Virtualization: AMD-V
L1d cache: 6 MiB (192 instances)
L1i cache: 6 MiB (192 instances)
L2 cache: 192 MiB (192 instances)
L3 cache: 768 MiB (24 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-95,192-287
NUMA node1 CPU(s): 96-191,288-383
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable; IBPB: disabled; STIBP: disabled; PBRSB-eIBRS: Not affected; BHI: Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] flake8==7.3.0
[pip3] flake8-new-union-types==0.4.1
[pip3] flake8-pep604==1.1.0
[pip3] intel-cmplr-lib-ur==2025.3.2
[pip3] intel-openmp==2025.3.2
[pip3] mkl-include==2025.3.1
[pip3] mkl-static==2025.3.1
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.2.6
[pip3] nvidia-cudnn-frontend==1.18.0
[pip3] onemkl-license==2025.3.1
[pip3] optree==0.18.0
[pip3] tbb==2022.3.1
[pip3] tbb-devel==2022.3.1
[pip3] tcmlib==1.4.1
[pip3] torch==2.11.0a0+git1c3d76e
[pip3] torchdata==0.11.0
[pip3] torchmonarch==0.3.0
[pip3] torchtitan==0.2.1
[pip3] torchvision==0.25.0a0+1e53952
[pip3] torchx-nightly==2026.1.26
[pip3] triton==3.6.0+git9844da95
[pip3] umf==1.0.3
[conda] intel-cmplr-lib-ur 2025.3.2 pypi_0 pypi
[conda] intel-openmp 2025.3.2 pypi_0 pypi
[conda] mkl-include 2025.3.1 pypi_0 pypi
[conda] mkl-static 2025.3.1 pypi_0 pypi
[conda] numpy 2.2.6 pypi_0 pypi
[conda] nvidia-cudnn-frontend 1.18.0 pypi_0 pypi
[conda] onemkl-license 2025.3.1 pypi_0 pypi
[conda] optree 0.18.0 pypi_0 pypi
[conda] tbb 2022.3.1 pypi_0 pypi
[conda] tbb-devel 2022.3.1 pypi_0 pypi
[conda] tcmlib 1.4.1 pypi_0 pypi
[conda] torch 2.11.0a0+git1c3d76e pypi_0 pypi
[conda] torchdata 0.11.0 pypi_0 pypi
[conda] torchmonarch 0.3.0 pypi_0 pypi
[conda] torchtitan 0.2.1 pypi_0 pypi
[conda] torchvision 0.25.0a0+1e53952 dev_0
[conda] torchx-nightly 2026.1.26 pypi_0 pypi
[conda] triton 3.6.0+git9844da95 pypi_0 pypi
[conda] umf 1.0.3 pypi_0 pypi
cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan @chauhang @penguinwu @ezyang @bobrenjc93 @laithsakka @tianyu-l @XilunWu @SherlockNoMad @ppwwyyxx