Error when using pp.func.jacrev #319

@lahavlipson

🐛 Describe the bug

The minimal example for pp.func.jacrev works fine, but when I call .Inv() inside the function being differentiated, it raises an error. The code I run is:

import pypose as pp
import torch
def func(pose, points):
    return pose.Inv() @ points
pose = pp.randn_SE3(1)
points = torch.randn(1, 3)
jacobian = pp.func.jacrev(func)(pose, points)

which throws the error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 7
      5 pose = pp.randn_SE3(1)
      6 points = torch.randn(1, 3)
----> 7 jacobian = pp.func.jacrev(func)(pose, points)

File ~/anaconda3/envs/myenv/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File ~/Dropbox/Research/pypose/pypose/func/jac.py:57, in jacrev.<locals>.wrapper_fn(*args, **kwargs)
     55 @retain_ltype()
     56 def wrapper_fn(*args, **kwargs):
---> 57     return jac_func(*args, **kwargs)

File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py:492, in jacrev.<locals>.wrapper_fn(*args)
    489 @wraps(func)
    490 def wrapper_fn(*args):
    491     error_if_complex("jacrev", args, is_input=True)
--> 492     vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
    493     if has_aux:
    494         output, vjp_fn, aux = vjp_out

File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/_functorch/vmap.py:38, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
     35 @functools.wraps(f)
     36 def fn(*args, **kwargs):
     37     with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 38         return f(*args, **kwargs)

File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py:294, in _vjp_with_argnums(func, argnums, has_aux, *primals)
    292     diff_primals = _slice_argnums(primals, argnums, as_tuple=False)
    293     tree_map_(partial(_create_differentiable, level=level), diff_primals)
--> 294 primals_out = func(*primals)
    296 if has_aux:
    297     if not (isinstance(primals_out, tuple) and len(primals_out) == 2):

Cell In[2], line 4, in func(pose, points)
      3 def func(pose, points):
----> 4     return pose.Inv() @ points

File ~/Dropbox/Research/pypose/pypose/lietensor/lietensor.py:993, in LieTensor.Inv(self)
    989 def Inv(self):
    990     r'''
    991     See :meth:`pypose.Inv`
    992     '''
--> 993     return self.ltype.Inv(self)

File ~/Dropbox/Research/pypose/pypose/lietensor/lietensor.py:374, in SE3Type.Inv(self, X)
    372 def Inv(self, X):
    373     X = X.tensor() if hasattr(X, 'ltype') else X
--> 374     out = SE3_Inv.apply(X)
    375     return LieTensor(out, ltype=SE3_type)

File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/autograd/function.py:542, in Function.apply(cls, *args, **kwargs)
    539     return super().apply(*args, **kwargs)  # type: ignore[misc]
    541 if cls.setup_context == _SingleLevelFunction.setup_context:
--> 542     raise RuntimeError(
    543         "In order to use an autograd.Function with functorch transforms "
    544         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    545         "staticmethod. For more details, please see "
    546         "https://pytorch.org/docs/master/notes/extending.func.html"
    547     )
    549 return custom_function_call(cls, *args, **kwargs)

RuntimeError: In order to use an autograd.Function with functorch transforms (vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod. For more details, please see https://pytorch.org/docs/master/notes/extending.func.html
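The error message points to the root cause. When a torch.func transform (here jacrev) is active, torch.autograd.Function.apply rejects any Function that does not override the setup_context staticmethod. The traceback shows SE3Type.Inv going through SE3_Inv.apply, so SE3_Inv presumably still uses the older forward(ctx, ...) style. The sketch below reproduces the same behavior in plain PyTorch (2.0 or newer assumed). It uses a hypothetical Square function, not PyPose code, and also shows the setup_context style that torch.func accepts.

```python
import torch
from torch.func import jacrev


class SquareOldStyle(torch.autograd.Function):
    """Hypothetical example: legacy style, ctx is passed to forward()."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


class SquareFuncStyle(torch.autograd.Function):
    """Hypothetical example: forward() without ctx, plus setup_context()."""

    # forward/backward use only torch ops, so torch.func may derive a vmap rule
    # (jacrev vmaps the backward pass over the rows of the Jacobian).
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x ** 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Saving tensors for backward moves here instead of forward().
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


def try_jacrev(fn, x):
    """Return the Jacobian, or the RuntimeError message if jacrev rejects fn."""
    try:
        return jacrev(fn)(x)
    except RuntimeError as err:
        return str(err)


x = torch.tensor([1.0, 2.0, 3.0])
old_result = try_jacrev(lambda t: SquareOldStyle.apply(t), x)
new_result = try_jacrev(lambda t: SquareFuncStyle.apply(t), x)
print(old_result)  # the same "must override the setup_context" error as above
print(new_result)  # 3x3 Jacobian equal to torch.diag(2 * x)
```

If this diagnosis is right, the fix on the PyPose side would be to restructure SE3_Inv (and any similar Functions) in the second style. That is an assumption based on the traceback and has not been verified against the PyPose source.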

This issue seems to be related to #262.

Versions

I am using PyTorch 2.1.0 and PyPose at commit 375740e (origin/main at the time this issue was created).

Collecting environment information...
PyTorch version: 2.1.0
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Pop!_OS 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.2.0-19ubuntu1) 11.2.0
Clang version: Could not collect
CMake version: version 3.22.4
Libc version: glibc-2.35

Python version: 3.10.11 | packaged by conda-forge | (main, May 10 2023, 18:58:44) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.19.0-76051900-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.6.124
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] functorch==1.14.0a0+b71aa0b
[pip3] lietorch==0.2
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.22.4
[pip3] pytorch3d==0.7.5
[pip3] torch==2.1.0
[pip3] torch-scatter==2.1.2
[pip3] torchaudio==2.1.0
[pip3] torchvision==0.16.0
[conda] blas                      1.0                         mkl    conda-forge
[conda] cudatoolkit               11.3.1               h2bc3f7f_2  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] functorch                 1.14.0a0+b71aa0b           dev_0    <develop>
[conda] libblas                   3.9.0            12_linux64_mkl    conda-forge
[conda] libcblas                  3.9.0            12_linux64_mkl    conda-forge
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] liblapack                 3.9.0            12_linux64_mkl    conda-forge
[conda] lietorch                  0.2                      pypi_0    pypi
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-service               2.4.0           py310h7f8727e_0  
[conda] mkl_fft                   1.3.1           py310hd6ae3a3_0  
[conda] mkl_random                1.2.2           py310h00e6091_0  
[conda] numpy                     1.22.4          py310h4ef5377_0    conda-forge
[conda] pytorch                   2.1.0           py3.10_cuda11.8_cudnn8.7.0_0    pytorch
[conda] pytorch-cuda              11.8                 h7e8668a_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] pytorch-scatter           2.1.2           py310_torch_2.1.0_cu118    pyg
[conda] pytorch3d                 0.7.5           py310_cu118_pyt210    pytorch3d
[conda] torch                     2.1.1                    pypi_0    pypi
[conda] torchaudio                2.1.0               py310_cu118    pytorch
[conda] torchtriton               2.1.0                     py310    pytorch
[conda] torchvision               0.16.0              py310_cu118    pytorch
