Skip to content

Commit f9a0d0c

Browse files
hameerabbasi authored and facebook-github-bot committed
Allow Tensor-likes in torch.autograd.gradcheck (#43877)
Summary: Fixes #42942

Pull Request resolved: #43877

Reviewed By: zou3519

Differential Revision: D23493257

Pulled By: ezyang

fbshipit-source-id: 6cdaabe17157b484e9491189706ccc15420ac239
1 parent c8914af commit f9a0d0c

4 files changed

Lines changed: 174 additions & 24 deletions

File tree

test/test_overrides.py

Lines changed: 142 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -644,27 +644,107 @@ def test(self):
644644

645645
generate_tensor_like_override_tests(TestTorchFunctionOverride)
646646

647-
class TestEinsumOverride(TestCase):
648-
"Regression test for gh-38479"
649-
def test_wrapper(self):
650-
class Wrapper():
651-
"Basic data container that knows how to unwrap itself"
652-
def __init__(self, data):
653-
self.data = data
647+
class Wrapper:
648+
"Basic data container that knows how to unwrap itself"
649+
def __init__(self, data):
650+
self.__dict__["_data"] = data
651+
self.__dict__["used_attrs"] = set()
652+
self.__dict__["used_calls"] = set()
653+
654+
def __getattr__(self, name):
655+
if name in self.__dict__:
656+
return self.__dict__[name]
657+
self.used_attrs.add(name)
658+
659+
val = getattr(self._data, name)
660+
661+
# If it's a method
662+
if callable(val):
663+
c = getattr(type(self._data), name)
664+
# Don't append self to args if classmethod/staticmethod
665+
if c is val:
666+
return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=a, kwargs=kw))
667+
# Otherwise append self to args
668+
return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=(self,) + a, kwargs=kw))
669+
670+
return wrap(val)
671+
672+
def __setattr__(self, name, value):
673+
if name in self.__dict__:
674+
self.__dict__[name] = value
675+
676+
self.used_attrs.add(name)
677+
setattr(self._data, name, unwrap(value))
678+
679+
def __setitem__(self, key, value):
680+
self._data[unwrap(key)] = unwrap(value)
681+
682+
def __getitem__(self, key):
683+
return wrap(self._data[unwrap(key)])
684+
685+
def __torch_function__(self, func, types, args=(), kwargs=None):
686+
if kwargs is None:
687+
kwargs = {}
688+
self.used_calls.add(func)
689+
args = unwrap(tuple(args))
690+
kwargs = {k: unwrap(v) for k, v in kwargs.items()}
691+
692+
return wrap(func(*args, **kwargs))
693+
694+
def __add__(self, other):
695+
return self.__torch_function__(torch.add, (Wrapper,), (self, other))
696+
697+
def __sub__(self, other):
698+
return self.__torch_function__(torch.sub, (Wrapper,), (self, other))
654699

655-
def __torch_function__(self, func, types, args=(), kwargs=None):
656-
if kwargs is None:
657-
kwargs = {}
700+
def __truediv__(self, other):
701+
return self.__torch_function__(torch.true_divide, (Wrapper,), (self, other))
658702

659-
# unwrap inputs if necessary
660-
def unwrap(v):
661-
return v.data if isinstance(v, Wrapper) else v
703+
def __floordiv__(self, other):
704+
return self.__torch_function__(torch.floor_divide, (Wrapper,), (self, other))
662705

663-
args = map(unwrap, args)
664-
kwargs = {k: unwrap(v) for k, v in kwargs.items()}
706+
def __ge__(self, other):
707+
return self.__torch_function__(torch.ge, (Wrapper,), (self, other))
665708

666-
return func(*args, **kwargs)
709+
def __gt__(self, other):
710+
return self.__torch_function__(torch.gt, (Wrapper,), (self, other))
667711

712+
def __lt__(self, other):
713+
return self.__torch_function__(torch.lt, (Wrapper,), (self, other))
714+
715+
def __le__(self, other):
716+
return self.__torch_function__(torch.le, (Wrapper,), (self, other))
717+
718+
def __eq__(self, other):
719+
return self.__torch_function__(torch.eq, (Wrapper,), (self, other))
720+
721+
def __ne__(self, other):
722+
return self.__torch_function__(torch.ne, (Wrapper,), (self, other))
723+
724+
def __bool__(self):
725+
return self.__torch_function__(torch.Tensor.__bool__, (Wrapper,), (self,))
726+
727+
def __int__(self):
728+
return self.__torch_function__(torch.Tensor.__int__, (Wrapper,), (self,))
729+
730+
731+
# unwrap inputs if necessary
732+
def unwrap(v):
733+
if type(v) in {tuple, list}:
734+
return type(v)(unwrap(vi) for vi in v)
735+
736+
return v._data if isinstance(v, Wrapper) else v
737+
738+
# wrap inputs if necessary
739+
def wrap(v):
740+
if type(v) in {tuple, list}:
741+
return type(v)(wrap(vi) for vi in v)
742+
743+
return Wrapper(v) if isinstance(v, torch.Tensor) else v
744+
745+
class TestEinsumOverride(TestCase):
746+
"Regression test for gh-38479"
747+
def test_wrapper(self):
668748
x = Wrapper(torch.randn(5))
669749
y = Wrapper(torch.randn(4))
670750
self.assertTrue(torch.allclose(torch.einsum('i,j->ij', x, y),
@@ -678,5 +758,51 @@ def unwrap(v):
678758
torch.nn.functional.bilinear(a, c, b)))
679759

680760

761+
class TestGradCheckOverride(TestCase):
762+
"Test that wrappers work with gradcheck."
763+
def test_gradcheck(self):
764+
from torch.autograd import gradcheck
765+
766+
a = wrap(torch.tensor(5.0, dtype=torch.double))
767+
b = wrap(torch.tensor(6.0, dtype=torch.double))
768+
769+
a.requires_grad = True
770+
b.requires_grad = True
771+
772+
gradcheck(torch.add, (a, b), raise_exception=False)
773+
774+
total_used_attrs = a.used_attrs.union(b.used_attrs)
775+
total_used_calls = a.used_calls.union(b.used_calls)
776+
777+
# These attributes (and the functions below) may change
778+
# if the gradcheck implementation changes. It's best to
779+
# aim for attributes that may be commonly present on other
780+
# Tensor-likes.
781+
self.assertEqual(total_used_attrs, {
782+
'data',
783+
'dtype',
784+
'is_floating_point',
785+
'is_sparse',
786+
'layout',
787+
'nelement',
788+
'new_zeros',
789+
'requires_grad',
790+
'retain_grad',
791+
'size',
792+
'stride',
793+
})
794+
795+
self.assertEqual(total_used_calls, {
796+
torch.Tensor.new_zeros,
797+
torch.Tensor.size,
798+
torch.Tensor.is_floating_point,
799+
torch.Tensor.nelement,
800+
torch.Tensor.retain_grad,
801+
torch.Tensor.stride,
802+
torch.autograd.grad,
803+
torch.add,
804+
})
805+
806+
681807
if __name__ == '__main__':
682808
unittest.main()

torch/autograd/__init__.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .gradcheck import gradcheck, gradgradcheck
1717
from .grad_mode import no_grad, enable_grad, set_grad_enabled
1818
from .anomaly_mode import detect_anomaly, set_detect_anomaly
19+
from ..overrides import has_torch_function, handle_torch_function
1920
from . import profiler
2021
from . import functional
2122

@@ -167,14 +168,27 @@ def grad(
167168
used when computing outputs (and therefore their grad is always zero)
168169
is an error. Defaults to ``False``.
169170
"""
171+
outputs = (outputs,) if isinstance(outputs, torch.Tensor) else tuple(outputs)
172+
inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
173+
overridable_args = outputs + inputs
174+
if has_torch_function(overridable_args):
175+
return handle_torch_function(
176+
grad,
177+
overridable_args,
178+
outputs,
179+
inputs,
180+
grad_outputs=grad_outputs,
181+
retain_graph=retain_graph,
182+
create_graph=create_graph,
183+
only_inputs=only_inputs,
184+
allow_unused=allow_unused,
185+
)
186+
170187
if not only_inputs:
171188
warnings.warn("only_inputs argument is deprecated and is ignored now "
172189
"(defaults to True). To accumulate gradient for other "
173190
"parts of the graph, please use torch.autograd.backward.")
174191

175-
outputs = (outputs,) if isinstance(outputs, torch.Tensor) else tuple(outputs)
176-
inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
177-
178192
if grad_outputs is None:
179193
grad_outputs = [None] * len(outputs)
180194
elif isinstance(grad_outputs, torch.Tensor):

torch/autograd/gradcheck.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from torch.types import _TensorOrTensors
33
from torch._six import container_abcs, istuple
44
import torch.testing
5+
from torch.overrides import is_tensor_like
56
from itertools import product
67
import warnings
78
from typing import Callable, Union, Optional
@@ -17,12 +18,12 @@ def zero_gradients(x):
1718

1819

1920
def make_jacobian(input, num_out):
20-
if isinstance(input, torch.Tensor):
21+
if is_tensor_like(input):
2122
if not input.is_floating_point() and not input.is_complex():
2223
return None
2324
if not input.requires_grad:
2425
return None
25-
return torch.zeros(input.nelement(), num_out, dtype=input.dtype)
26+
return input.new_zeros((input.nelement(), num_out), dtype=input.dtype, layout=torch.strided)
2627
elif isinstance(input, container_abcs.Iterable) and not isinstance(input, str):
2728
jacobians = list(filter(
2829
lambda x: x is not None, (make_jacobian(elem, num_out) for elem in input)))
@@ -34,7 +35,7 @@ def make_jacobian(input, num_out):
3435

3536

3637
def iter_tensors(x, only_requiring_grad=False):
37-
if isinstance(x, torch.Tensor):
38+
if is_tensor_like(x):
3839
if x.requires_grad or not only_requiring_grad:
3940
yield x
4041
elif isinstance(x, container_abcs.Iterable) and not isinstance(x, str):
@@ -253,13 +254,13 @@ def fail_test(msg):
253254
return False
254255

255256
tupled_inputs = _as_tuple(inputs)
256-
if any(t.is_sparse for t in tupled_inputs if isinstance(t, torch.Tensor)) and not check_sparse_nnz:
257+
if not check_sparse_nnz and any(t.is_sparse for t in tupled_inputs if isinstance(t, torch.Tensor)):
257258
return fail_test('gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False.')
258259

259260
# Make sure that gradients are saved for at least one input
260261
any_input_requiring_grad = False
261262
for idx, inp in enumerate(tupled_inputs):
262-
if isinstance(inp, torch.Tensor) and inp.requires_grad:
263+
if is_tensor_like(inp) and inp.requires_grad:
263264
if not (inp.dtype == torch.float64 or inp.dtype == torch.complex128):
264265
warnings.warn(
265266
'The {}th input requires gradient and '

torch/overrides.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,3 +1164,12 @@ def is_tensor_method_or_property(func: Callable) -> bool:
11641164
of ``torch.Tensor``.
11651165
"""
11661166
return func in get_tensor_methods() or func.__name__ == "__get__"
1167+
1168+
def is_tensor_like(inp):
1169+
"""
1170+
Returns ``True`` if the passed-in input is a tensor-like.
1171+
1172+
Currently, this occurs whenever there's a ``__torch_function__``
1173+
attribute on the input.
1174+
"""
1175+
return type(inp) is torch.Tensor or hasattr(inp, "__torch_function__")

0 commit comments

Comments (0)