🐛 Bug
The function `_object_to_tensor` is much slower for large objects since PyTorch 1.9. This affects `torch.distributed.all_gather_object` and related methods.
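For context, `_object_to_tensor` essentially pickles the object and wraps the resulting bytes in a `uint8` tensor so the payload can be exchanged over the process group. A rough stdlib-only sketch of that round trip (using a `bytearray` in place of the byte tensor; the helper's exact internals are assumed here, not quoted):

```python
import pickle

def object_to_bytes(obj):
    # Serialize the object to raw bytes, the first step _object_to_tensor
    # performs before wrapping the buffer in a uint8 tensor.
    data = pickle.dumps(obj)
    return bytearray(data), len(data)

def bytes_to_object(buf, size):
    # Inverse step: deserialize the first `size` bytes back into an object,
    # mirroring what _tensor_to_object does on the receiving side.
    return pickle.loads(bytes(buf[:size]))

obj = {"rank": 0, "payload": "x" * 1000}
buf, size = object_to_bytes(obj)
assert bytes_to_object(buf, size) == obj
```

The slow step reported below is the wrapping of the pickled buffer into a tensor, not the pickling itself.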
To Reproduce
Running

```python
from torch.distributed.distributed_c10d import _object_to_tensor
import time

start = time.time()
_object_to_tensor("x" * 50_000_000)
print("Time:", time.time() - start)
```

prints
PyTorch 1.8.1:

```
Time: 0.06174063682556152
```

PyTorch 1.9.1:
Expected behavior
I would expect both versions to be roughly equally fast.
Environment
Collecting environment information...
PyTorch version: 1.9.1+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Pop!_OS 20.04 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.9.1 (default, Dec 11 2020, 14:32:07) [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-5.13.0-7614-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 10.1.105
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2080 Ti
Nvidia driver version: 470.63.01
cuDNN version: Probably one of the following:
/usr/lib/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.8.0.4
/usr/lib/cuda-10.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.4
/usr/lib/cuda-10.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.4
/usr/lib/cuda-10.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.4
/usr/lib/cuda-10.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.4
/usr/lib/cuda-10.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.4
/usr/lib/cuda-10.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.4
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] efficientnet-pytorch==0.6.3
[pip3] mypy==0.910
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.1
[pip3] pytorch-lightning==1.4.2
[pip3] segmentation-models-pytorch==0.2.0
[pip3] torch==1.9.1
[pip3] torchaudio==0.9.0
[pip3] torchmetrics==0.5.0
[pip3] torchvision==0.10.0
[conda] efficientnet-pytorch 0.6.3 pypi_0 pypi
[conda] mypy 0.910 pypi_0 pypi
[conda] mypy-extensions 0.4.3 pypi_0 pypi
[conda] numpy 1.21.1 pypi_0 pypi
[conda] pytorch-lightning 1.4.2 pypi_0 pypi
[conda] segmentation-models-pytorch 0.2.0 pypi_0 pypi
[conda] torch 1.9.1 pypi_0 pypi
[conda] torchaudio 0.9.0 pypi_0 pypi
[conda] torchmetrics 0.5.0 pypi_0 pypi
[conda] torchvision 0.10.0 pypi_0 pypi
Additional context
The regression seems to have happened in this commit: ce05b7a. Apparently `byte_tensor = torch.tensor(byte_storage, dtype=torch.uint8)` is much slower than `byte_tensor = torch.ByteTensor(byte_storage)`:
```python
import torch
import io
import time
import pickle

f = io.BytesIO()
pickle.dump("x" * 50_000_000, f)
byte_storage = torch.ByteStorage.from_buffer(f.getvalue())

start = time.time()
byte_tensor = torch.tensor(byte_storage, dtype=torch.uint8)
print("torch.tensor time:", time.time() - start)

start = time.time()
byte_tensor = torch.ByteTensor(byte_storage)
print("torch.ByteTensor time:", time.time() - start)
```

prints

```
torch.tensor time: 6.084373235702515
torch.ByteTensor time: 0.001978158950805664
```
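A plausible explanation (an assumption on my part, not verified against the PyTorch internals): `torch.ByteTensor(storage)` can wrap or bulk-copy the underlying buffer, while `torch.tensor(...)` may treat the storage as a generic Python sequence and copy it element by element, paying Python-level overhead per byte. A pure-Python analogue of bulk vs. per-element copying shows the same shape of gap:

```python
import time

data = bytes(b"x" * 5_000_000)

# Bulk copy: a single memcpy-like operation over the whole buffer,
# analogous to wrapping a storage directly in a ByteTensor.
start = time.time()
bulk = bytearray(data)
bulk_time = time.time() - start

# Per-element copy: iterate the buffer as a generic Python sequence,
# analogous to the hypothesized torch.tensor(...) path.
start = time.time()
slow = bytearray(b for b in data)
elementwise_time = time.time() - start

assert bulk == slow
print(f"bulk: {bulk_time:.4f}s, element-wise: {elementwise_time:.4f}s")
```

On my machine the element-wise copy is orders of magnitude slower, consistent with the roughly 3000x gap measured above.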
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang @gcramer23