LieTensor becomes torch.Tensor in DataParallel #258

@zeroAska

Description

🐛 Describe the bug

I'm trying to use DataParallel for multi-GPU training with pypose's LieTensor. However, I noticed that when I send a LieTensor pose object to the data-parallel model, the pose becomes a torch.Tensor:

Example:

import torch
import torch.nn as nn
import pypose as pp


class DataParallelModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, T):
        # T is passed in as a pp.LieTensor, but arrives here as a plain torch.Tensor
        print("type of T is ", type(T))
        assert isinstance(T, pp.LieTensor)
        return T @ x


def test_parallel():
    a = pp.randn_SO3(4)
    a.requires_grad = True

    net = DataParallelModel()
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net).cuda()
        net(torch.rand(4), a)


test_parallel()

Output:

Let's use 4 GPUs!
type of T is  <class 'torch.Tensor'>
type of T is  <class 'torch.Tensor'>
type of T is  <class 'torch.Tensor'>
type of T is  <class 'torch.Tensor'>
....
Original Traceback (most recent call last):
  File ".../python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File ".../python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "test_backprop.py", line 65, in forward
    assert(isinstance(T, pp.LieTensor))
AssertionError  
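
A possible workaround, sketched here and not confirmed by the pypose maintainers: pass the pose through DataParallel as a plain tensor and restore the subclass inside forward. The sketch uses only PyTorch. `PoseTensor` is a hypothetical stand-in for `pp.LieTensor`, and `RewrapModel` is likewise an illustrative name; with pypose the re-wrap would presumably go through the library's own constructor instead of `as_subclass`, which is an assumption not tested here.

```python
import torch
import torch.nn as nn


class PoseTensor(torch.Tensor):
    """Hypothetical stand-in for a tensor subclass such as pp.LieTensor."""


class RewrapModel(nn.Module):
    def forward(self, x, T_raw):
        # T_raw is a plain torch.Tensor here; restore the subclass per replica.
        T = T_raw.as_subclass(PoseTensor)
        assert isinstance(T, PoseTensor)
        return T * x  # placeholder for the real group action


raw = torch.rand(4, 4, requires_grad=True)  # plain tensor holding the pose data
net = RewrapModel()
if torch.cuda.device_count() > 1:
    # Same multi-GPU guard as in the report; on one device the model runs directly.
    net = nn.DataParallel(net).cuda()

out = net(torch.rand(4, 1), raw)
print(tuple(out.shape))
```

The type check now runs after the inputs have been split across devices, so it no longer depends on the subclass surviving that step. The gathered output may still come back as a plain tensor and might need the same re-wrap.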

Versions

% python collect_env.py 
Collecting environment information...
PyTorch version: 1.13.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

GCC version: (GCC) 7.3.1 
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.2.5

Python version: 3.7.16
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: Tesla V100-SXM2-16GB
GPU 1: Tesla V100-SXM2-16GB
GPU 2: Tesla V100-SXM2-16GB
GPU 3: Tesla V100-SXM2-16GB

Nvidia driver version: 515.105.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] lietorch==0.4
[pip3] numpy==1.21.6
[pip3] numpydoc==1.5.0
[pip3] torch==1.13.1
[pip3] torch-vision==0.1.6.dev0
[conda] Could not collect
