🐛 Describe the bug
The documentation for all_gather_into_tensor states that it supports either concatenation or stacking of the input tensors. But all_gather_tensor_inplace, which is what Dynamo remaps to, simply calls the functional version and copies the result to output_tensor. The functional version performs a concatenation, hence stacking is unsupported when compiling.
To fix, modify the last line of all_gather_tensor_inplace from
return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))
to
result = all_gather_tensor(input_tensor, gather_dim, group, tag)
if result.shape == output_tensor.shape:
    return output_tensor.copy_(result)
else:
    stacked_result = torch.stack(torch.split(result, input_tensor.shape[0], dim=0), dim=0)
    if stacked_result.shape == output_tensor.shape:
        return output_tensor.copy_(stacked_result)
    else:
        raise ValueError("informative error message goes here")
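For illustration, here is a standalone sketch of the re-stacking step used in the proposed fix (the world_size and tensor values below are made up; this is not the actual gather result):

import torch

world_size = 4
input_tensor = torch.arange(2)            # per-rank input, shape (2,)
result = torch.arange(world_size * 2)     # concatenated gather result, shape (8,)

# Split the concatenated result into one chunk per rank and stack the chunks,
# recovering the (world_size, 2) "stacked" layout described in the documentation.
stacked_result = torch.stack(torch.split(result, input_tensor.shape[0], dim=0), dim=0)
assert stacked_result.shape == (world_size, 2)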
The following code snippet, taken from the documentation, fails under torch.compile:
tensor_in = torch.arange(2, dtype=torch.int64, device=device)
tensor_out = torch.zeros(world_size, 2, dtype=torch.int64, device=device)
dist.all_gather_into_tensor(tensor_out, tensor_in)
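A minimal sketch of one way to trigger the failure under torch.compile (the gather wrapper is illustrative; it assumes dist is initialized and device / world_size are defined as in the documentation example):

def gather(tensor_out, tensor_in):
    dist.all_gather_into_tensor(tensor_out, tensor_in)
    return tensor_out

torch.compile(gather)(tensor_out, tensor_in)  # raises the error below; the eager call succeeds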
On 4 XLA devices (TPUs), the error is below. I expect a similar failure on GPUs.
torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_method copy_(*(FakeTensor(..., device='xla:0', size=(4, 2)), FakeTensor(..., device='xla:0', size=(8,))), **{}): got RuntimeError('expand: attempting to expand a dimension of length 8 -> 2!')
Versions
Collecting environment information...
PyTorch version: N/A
Is debug build: N/A
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: N/A
OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.31.6
Libc version: glibc-2.31
Python version: 3.10.17 (main, Apr 9 2025, 18:21:04) [GCC 10.2.1 20210110] (64-bit runtime)
Python platform: Linux-5.19.0-1022-gcp-x86_64-with-glibc2.31
Is CUDA available: N/A
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: N/A
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 240
On-line CPU(s) list: 0-239
Thread(s) per core: 2
Core(s) per socket: 60
Socket(s): 2
NUMA node(s): 2
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7B12
Stepping: 0
CPU MHz: 2249.998
BogoMIPS: 4499.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 3.8 MiB
L1i cache: 3.8 MiB
L2 cache: 60 MiB
L3 cache: 480 MiB
NUMA node0 CPU(s): 0-59,120-179
NUMA node1 CPU(s): 60-119,180-239
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr arat npt nrip_save umip rdpid
Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] torch==2.8.0a0+gitb44306d
[pip3] torch-xla==2.8.0+gitd82e15c
[pip3] torchvision==0.22.0a0+5f03dc5
[conda] Could not collect
cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @amjames @Lucaskabela @jataylo @H-Huang @chenyang78