Closed
Description
🐛 Describe the bug
I'm trying to use DataParallel for multi-GPU training with PyPose's LieTensor. However, I noticed that when I pass a LieTensor pose object to the DataParallel model, the pose arrives in forward() as a plain torch.Tensor:
Example:
import torch
import torch.nn as nn
import pypose as pp

class DataParallelModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, T):
        # T is expected to still be a LieTensor inside each replica
        print("type of T is ", type(T))
        assert(isinstance(T, pp.LieTensor))
        return T @ x

def test_parallel():
    a = pp.randn_SO3(4)
    a.requires_grad = True
    net = DataParallelModel()
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net).cuda()
    net(torch.rand(4), a)

test_parallel()
Output:
Let's use 4 GPUs!
type of T is <class 'torch.Tensor'>
type of T is <class 'torch.Tensor'>
type of T is <class 'torch.Tensor'>
type of T is <class 'torch.Tensor'>
....
Original Traceback (most recent call last):
  File ".../python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File ".../python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "test_backprop.py", line 65, in forward
    assert(isinstance(T, pp.LieTensor))
AssertionError
Versions
% python collect_env.py
Collecting environment information...
PyTorch version: 1.13.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
GCC version: (GCC) 7.3.1
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.2.5
Python version: 3.7.16
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: Tesla V100-SXM2-16GB
GPU 1: Tesla V100-SXM2-16GB
GPU 2: Tesla V100-SXM2-16GB
GPU 3: Tesla V100-SXM2-16GB
Nvidia driver version: 515.105.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] lietorch==0.4
[pip3] numpy==1.21.6
[pip3] numpydoc==1.5.0
[pip3] torch==1.13.1
[pip3] torch-vision==0.1.6.dev0
[conda] Could not collect