🐛 Describe the bug
The minimal example of `pp.func.jacrev` works fine, but when I use `.Inv()` it throws an error. For comparison, both the working and the failing snippets are shown below.
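The minimal example I mean is essentially the following (reconstructed here for comparison; it is the `pose @ points` case from the `pp.func.jacrev` documentation), and it runs without error:

```python
import pypose as pp
import torch

def func(pose, points):
    return pose @ points  # no .Inv(): jacrev handles this fine

pose = pp.randn_SE3(1)
points = torch.randn(1, 3)
jacobian = pp.func.jacrev(func)(pose, points)  # no error
```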
The code that I run, which differs only by the `.Inv()` call, is:

```python
import pypose as pp
import torch

def func(pose, points):
    return pose.Inv() @ points

pose = pp.randn_SE3(1)
points = torch.randn(1, 3)
jacobian = pp.func.jacrev(func)(pose, points)
```

which throws the error:
```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[2], line 7
5 pose = pp.randn_SE3(1)
6 points = torch.randn(1, 3)
----> 7 jacobian = pp.func.jacrev(func)(pose, points)
File ~/anaconda3/envs/myenv/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
76 @wraps(func)
77 def inner(*args, **kwds):
78 with self._recreate_cm():
---> 79 return func(*args, **kwds)
File ~/Dropbox/Research/pypose/pypose/func/jac.py:57, in jacrev.<locals>.wrapper_fn(*args, **kwargs)
55 @retain_ltype()
56 def wrapper_fn(*args, **kwargs):
---> 57 return jac_func(*args, **kwargs)
File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py:492, in jacrev.<locals>.wrapper_fn(*args)
489 @wraps(func)
490 def wrapper_fn(*args):
491 error_if_complex("jacrev", args, is_input=True)
--> 492 vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
493 if has_aux:
494 output, vjp_fn, aux = vjp_out
File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/_functorch/vmap.py:38, in doesnt_support_saved_tensors_hooks.<locals>.fn(*args, **kwargs)
35 @functools.wraps(f)
36 def fn(*args, **kwargs):
37 with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 38 return f(*args, **kwargs)
File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py:294, in _vjp_with_argnums(func, argnums, has_aux, *primals)
292 diff_primals = _slice_argnums(primals, argnums, as_tuple=False)
293 tree_map_(partial(_create_differentiable, level=level), diff_primals)
--> 294 primals_out = func(*primals)
296 if has_aux:
297 if not (isinstance(primals_out, tuple) and len(primals_out) == 2):
Cell In[2], line 4, in func(pose, points)
3 def func(pose, points):
----> 4 return pose.Inv() @ points
File ~/Dropbox/Research/pypose/pypose/lietensor/lietensor.py:993, in LieTensor.Inv(self)
989 def Inv(self):
990 r'''
991 See :meth:`pypose.Inv`
992 '''
--> 993 return self.ltype.Inv(self)
File ~/Dropbox/Research/pypose/pypose/lietensor/lietensor.py:374, in SE3Type.Inv(self, X)
372 def Inv(self, X):
373 X = X.tensor() if hasattr(X, 'ltype') else X
--> 374 out = SE3_Inv.apply(X)
375 return LieTensor(out, ltype=SE3_type)
File ~/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/autograd/function.py:542, in Function.apply(cls, *args, **kwargs)
539 return super().apply(*args, **kwargs) # type: ignore[misc]
541 if cls.setup_context == _SingleLevelFunction.setup_context:
--> 542 raise RuntimeError(
543 "In order to use an autograd.Function with functorch transforms "
544 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
545 "staticmethod. For more details, please see "
546 "https://pytorch.org/docs/master/notes/extending.func.html"
547 )
549 return custom_function_call(cls, *args, **kwargs)
RuntimeError: In order to use an autograd.Function with functorch transforms (vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod. For more details, please see https://pytorch.org/docs/master/notes/extending.func.html
```
This issue seems related to #262
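For reference, the restructuring that the error message asks for looks roughly like this (`MySquare` is a hypothetical toy function, not PyPose's actual `SE3_Inv`; this is just the pattern from the PyTorch docs linked in the error):

```python
import torch

class MySquare(torch.autograd.Function):
    # functorch-compatible style: forward takes no ctx argument,
    # and setup_context stores what backward needs.
    @staticmethod
    def forward(x):
        return x ** 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return 2.0 * x * grad_out

x = torch.randn(3)
jac = torch.func.jacrev(MySquare.apply)(x)  # works once setup_context exists
```

Presumably `SE3_Inv` (and the other LieTensor autograd functions hit by functorch transforms) would need the same split between `forward` and `setup_context`.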
Versions
I am using PyTorch 2.1.0 and PyPose at commit 375740e (origin/main as of creating this issue).
```
Collecting environment information...
PyTorch version: 2.1.0
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
OS: Pop!_OS 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.2.0-19ubuntu1) 11.2.0
Clang version: Could not collect
CMake version: version 3.22.4
Libc version: glibc-2.35
Python version: 3.10.11 | packaged by conda-forge | (main, May 10 2023, 18:58:44) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.19.0-76051900-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.6.124
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] functorch==1.14.0a0+b71aa0b
[pip3] lietorch==0.2
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.22.4
[pip3] pytorch3d==0.7.5
[pip3] torch==2.1.0
[pip3] torch-scatter==2.1.2
[pip3] torchaudio==2.1.0
[pip3] torchvision==0.16.0
[conda] blas 1.0 mkl conda-forge
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] functorch 1.14.0a0+b71aa0b dev_0 <develop>
[conda] libblas 3.9.0 12_linux64_mkl conda-forge
[conda] libcblas 3.9.0 12_linux64_mkl conda-forge
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
[conda] liblapack 3.9.0 12_linux64_mkl conda-forge
[conda] lietorch 0.2 pypi_0 pypi
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py310h7f8727e_0
[conda] mkl_fft 1.3.1 py310hd6ae3a3_0
[conda] mkl_random 1.2.2 py310h00e6091_0
[conda] numpy 1.22.4 py310h4ef5377_0 conda-forge
[conda] pytorch 2.1.0 py3.10_cuda11.8_cudnn8.7.0_0 pytorch
[conda] pytorch-cuda 11.8 h7e8668a_5 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] pytorch-scatter 2.1.2 py310_torch_2.1.0_cu118 pyg
[conda] pytorch3d 0.7.5 py310_cu118_pyt210 pytorch3d
[conda] torch 2.1.1 pypi_0 pypi
[conda] torchaudio 2.1.0 py310_cu118 pytorch
[conda] torchtriton 2.1.0 py310 pytorch
[conda] torchvision 0.16.0 py310_cu118 pytorch
```