
ctx.save_for_backward doesn't save torch.Tensor subclasses fully #47117

@mlamarre

🐛 Bug

Saving a torch.Tensor subclass with ctx.save_for_backward only saves the underlying base torch.Tensor. The subclass type and any additional attributes are lost (object slicing, in C++ terminology).

To Reproduce

The example below follows the Extending PyTorch doc; LoggingTensor is copy-pasted from there.

import torch
import logging

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Changed from the documentation to avoid infinite recursion via __repr__
        logging.info(f"func: {func.__name__}")
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

class SquareFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        print('forward x type', type(x), 'x data_ptr', x.data_ptr())
        ctx.save_for_backward(x)
        y = torch.mul(x, x)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        print('backward x type', type(x), 'x data_ptr', x.data_ptr())
        return 2 * x * grad_output

lt = LoggingTensor(torch.rand((3, 3)))
lt.requires_grad_(True)
y = SquareFunction.apply(lt)
y.backward(torch.ones_like(y))
assert lt.grad is not None  # this works

Expected behavior

I would expect the subclass type to be preserved.

Expected:

forward x type <class '__main__.LoggingTensor'> x data_ptr 1715819930816
backward x type <class '__main__.LoggingTensor'> x data_ptr 1715819930816

Current result:

forward x type <class '__main__.LoggingTensor'> x data_ptr 1715819930816
backward x type <class 'torch.Tensor'> x data_ptr 1715819930816

Environment

Collecting environment information...
PyTorch version: 1.7.0
Is debug build: True
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Enterprise
CMake version: version 3.18.0

Python version: 3.7 (64-bit runtime)
Is CUDA available: True

Versions of relevant libraries:
[pip] numpy==1.18.1
[pip] torch==1.7.0

[conda] blas                      1.0                         mkl
[conda] cudatoolkit               11.0.221             h74a9793_0
[conda] mkl                       2020.1                      216
[conda] mkl-service               2.3.0            py37hb782905_0
[conda] mkl_fft                   1.0.15           py37h14836fe_0
[conda] mkl_random                1.1.1            py37h47e9c7a_0
[conda] numpy                     1.18.1           py37h93ca92e_0
[conda] numpy-base                1.18.1           py37hc3f5095_1
[conda] pytorch                   1.7.0           py3.7_cuda110_cudnn8_0    pytorch

cc @hameerabbasi @rgommers @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @jlin27 @mruberry

Labels

enhancement (not as big of a feature, but technically not a bug; should be easy to fix)
module: __torch_function__
module: autograd (related to torch.autograd, and the autograd engine in general)
module: docs (related to our documentation, both in docs/ and docblocks)
needs design (we want to add this feature but we need to figure out how first)
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
