
Commit 7c6878f

Revert "Allow Tensor-likes in torch.autograd.gradcheck (#43877)"
This reverts commit f9a0d0c. [ghstack-poisoned]
1 parent 77cc7d1 commit 7c6878f

4 files changed: 21 additions & 174 deletions
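For context: the reverted PR taught gradcheck to accept "Tensor-likes", objects that are not torch.Tensor but expose __torch_function__. A minimal sketch of such an object, written in the instance-method style this era of PyTorch used (the name TensorLike is hypothetical, for illustration only):

import torch

class TensorLike:
    "Hypothetical duck-typed tensor: not a torch.Tensor subclass."
    def __init__(self, data):
        self.data = data  # the real torch.Tensor underneath

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        def unwrap(v):
            return v.data if isinstance(v, TensorLike) else v
        # Unwrap the arguments and run the real op on plain tensors.
        return func(*map(unwrap, args), **{k: unwrap(v) for k, v in kwargs.items()})

t = TensorLike(torch.ones(3))
print(torch.add(t, t))  # dispatches via __torch_function__ -> tensor([2., 2., 2.])

Before this revert, torch.autograd.gradcheck probed such objects through is_tensor_like; after it, only genuine torch.Tensor inputs are considered.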


test/test_overrides.py

Lines changed: 16 additions & 142 deletions
@@ -644,107 +644,27 @@ def test(self):
 
 generate_tensor_like_override_tests(TestTorchFunctionOverride)
 
-class Wrapper:
-    "Basic data container that knows how to unwrap itself"
-    def __init__(self, data):
-        self.__dict__["_data"] = data
-        self.__dict__["used_attrs"] = set()
-        self.__dict__["used_calls"] = set()
-
-    def __getattr__(self, name):
-        if name in self.__dict__:
-            return self.__dict__[name]
-        self.used_attrs.add(name)
-
-        val = getattr(self._data, name)
-
-        # If it's a method
-        if callable(val):
-            c = getattr(type(self._data), name)
-            # Don't append self to args if classmethod/staticmethod
-            if c is val:
-                return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=a, kwargs=kw))
-            # Otherwise append self to args
-            return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=(self,) + a, kwargs=kw))
-
-        return wrap(val)
-
-    def __setattr__(self, name, value):
-        if name in self.__dict__:
-            self.__dict__[name] = value
-
-        self.used_attrs.add(name)
-        setattr(self._data, name, unwrap(value))
-
-    def __setitem__(self, key, value):
-        self._data[unwrap(key)] = unwrap(value)
-
-    def __getitem__(self, key):
-        return wrap(self._data[unwrap(key)])
-
-    def __torch_function__(self, func, types, args=(), kwargs=None):
-        if kwargs is None:
-            kwargs = {}
-        self.used_calls.add(func)
-        args = unwrap(tuple(args))
-        kwargs = {k: unwrap(v) for k, v in kwargs.items()}
-
-        return wrap(func(*args, **kwargs))
-
-    def __add__(self, other):
-        return self.__torch_function__(torch.add, (Wrapper,), (self, other))
-
-    def __sub__(self, other):
-        return self.__torch_function__(torch.sub, (Wrapper,), (self, other))
-
-    def __truediv__(self, other):
-        return self.__torch_function__(torch.true_divide, (Wrapper,), (self, other))
-
-    def __floordiv__(self, other):
-        return self.__torch_function__(torch.floor_divide, (Wrapper,), (self, other))
-
-    def __ge__(self, other):
-        return self.__torch_function__(torch.ge, (Wrapper,), (self, other))
-
-    def __gt__(self, other):
-        return self.__torch_function__(torch.gt, (Wrapper,), (self, other))
-
-    def __lt__(self, other):
-        return self.__torch_function__(torch.lt, (Wrapper,), (self, other))
-
-    def __le__(self, other):
-        return self.__torch_function__(torch.le, (Wrapper,), (self, other))
-
-    def __eq__(self, other):
-        return self.__torch_function__(torch.eq, (Wrapper,), (self, other))
-
-    def __ne__(self, other):
-        return self.__torch_function__(torch.ne, (Wrapper,), (self, other))
-
-    def __bool__(self):
-        return self.__torch_function__(torch.Tensor.__bool__, (Wrapper,), (self,))
-
-    def __int__(self):
-        return self.__torch_function__(torch.Tensor.__int__, (Wrapper,), (self,))
-
+class TestEinsumOverride(TestCase):
+    "Regression test for gh-38479"
+    def test_wrapper(self):
+        class Wrapper():
+            "Basic data container that knows how to unwrap itself"
+            def __init__(self, data):
+                self.data = data
 
-# unwrap inputs if necessary
-def unwrap(v):
-    if type(v) in {tuple, list}:
-        return type(v)(unwrap(vi) for vi in v)
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                if kwargs is None:
+                    kwargs = {}
 
-    return v._data if isinstance(v, Wrapper) else v
+                # unwrap inputs if necessary
+                def unwrap(v):
+                    return v.data if isinstance(v, Wrapper) else v
 
-# wrap inputs if necessary
-def wrap(v):
-    if type(v) in {tuple, list}:
-        return type(v)(wrap(vi) for vi in v)
+                args = map(unwrap, args)
+                kwargs = {k: unwrap(v) for k, v in kwargs.items()}
 
-    return Wrapper(v) if isinstance(v, torch.Tensor) else v
+                return func(*args, **kwargs)
 
-class TestEinsumOverride(TestCase):
-    "Regression test for gh-38479"
-    def test_wrapper(self):
         x = Wrapper(torch.randn(5))
         y = Wrapper(torch.randn(4))
         self.assertTrue(torch.allclose(torch.einsum('i,j->ij', x, y),
@@ -758,51 +678,5 @@ def test_wrapper(self):
                                        torch.nn.functional.bilinear(a, c, b)))
 
 
-class TestGradCheckOverride(TestCase):
-    "Test that wrappers work with gradcheck."
-    def test_gradcheck(self):
-        from torch.autograd import gradcheck
-
-        a = wrap(torch.tensor(5.0, dtype=torch.double))
-        b = wrap(torch.tensor(6.0, dtype=torch.double))
-
-        a.requires_grad = True
-        b.requires_grad = True
-
-        gradcheck(torch.add, (a, b), raise_exception=False)
-
-        total_used_attrs = a.used_attrs.union(b.used_attrs)
-        total_used_calls = a.used_calls.union(b.used_calls)
-
-        # These attributes (and the functions below) may change
-        # if the gradcheck implementation changes. It's best to
-        # aim for attributes that may be commonly present on other
-        # Tensor-likes.
-        self.assertEqual(total_used_attrs, {
-            'data',
-            'dtype',
-            'is_floating_point',
-            'is_sparse',
-            'layout',
-            'nelement',
-            'new_zeros',
-            'requires_grad',
-            'retain_grad',
-            'size',
-            'stride',
-        })
-
-        self.assertEqual(total_used_calls, {
-            torch.Tensor.new_zeros,
-            torch.Tensor.size,
-            torch.Tensor.is_floating_point,
-            torch.Tensor.nelement,
-            torch.Tensor.retain_grad,
-            torch.Tensor.stride,
-            torch.autograd.grad,
-            torch.add,
-        })
-
-
 if __name__ == '__main__':
     unittest.main()
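Aside: the surviving TestEinsumOverride keeps only the minimal override pattern. A standalone, hedged sketch of that pattern, mirroring the added lines above (not itself part of the commit):

import torch

class Wrapper:
    "Minimal Tensor-like that forwards every torch.* call to its wrapped tensor."
    def __init__(self, data):
        self.data = data

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        def unwrap(v):
            return v.data if isinstance(v, Wrapper) else v
        # Unwrap all arguments, run the real op, return a plain tensor.
        return func(*map(unwrap, args), **{k: unwrap(v) for k, v in kwargs.items()})

x, y = Wrapper(torch.randn(5)), Wrapper(torch.randn(4))
out = torch.einsum('i,j->ij', x, y)                        # dispatches through Wrapper
assert torch.allclose(out, x.data.reshape(5, 1) * y.data)  # plain outer product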

torch/autograd/__init__.py

Lines changed: 0 additions & 17 deletions
@@ -17,7 +17,6 @@
 from .gradcheck import gradcheck, gradgradcheck
 from .grad_mode import no_grad, enable_grad, set_grad_enabled
 from .anomaly_mode import detect_anomaly, set_detect_anomaly
-from ..overrides import has_torch_function, handle_torch_function
 from . import profiler
 from . import functional
 
@@ -172,22 +171,6 @@ def grad(
             used when computing outputs (and therefore their grad is always zero)
             is an error. Defaults to ``False``.
     """
-    outputs = (outputs,) if isinstance(outputs, torch.Tensor) else tuple(outputs)
-    inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
-    overridable_args = outputs + inputs
-    if has_torch_function(overridable_args):
-        return handle_torch_function(
-            grad,
-            overridable_args,
-            outputs,
-            inputs,
-            grad_outputs=grad_outputs,
-            retain_graph=retain_graph,
-            create_graph=create_graph,
-            only_inputs=only_inputs,
-            allow_unused=allow_unused,
-        )
-
     if not only_inputs:
         warnings.warn("only_inputs argument is deprecated and is ignored now "
                       "(defaults to True). To accumulate gradient for other "

torch/autograd/gradcheck.py

Lines changed: 5 additions & 6 deletions
@@ -2,7 +2,6 @@
 from torch.types import _TensorOrTensors
 from torch._six import container_abcs, istuple
 import torch.testing
-from torch.overrides import is_tensor_like
 from itertools import product
 import warnings
 from typing import Callable, Union, Optional
@@ -18,12 +17,12 @@ def zero_gradients(x):
 
 
 def make_jacobian(input, num_out):
-    if is_tensor_like(input):
+    if isinstance(input, torch.Tensor):
         if not input.is_floating_point() and not input.is_complex():
             return None
         if not input.requires_grad:
             return None
-        return input.new_zeros((input.nelement(), num_out), dtype=input.dtype, layout=torch.strided)
+        return torch.zeros(input.nelement(), num_out, dtype=input.dtype)
     elif isinstance(input, container_abcs.Iterable) and not isinstance(input, str):
         jacobians = list(filter(
             lambda x: x is not None, (make_jacobian(elem, num_out) for elem in input)))
@@ -35,7 +34,7 @@ def make_jacobian(input, num_out):
 
 
 def iter_tensors(x, only_requiring_grad=False):
-    if is_tensor_like(x):
+    if isinstance(x, torch.Tensor):
         if x.requires_grad or not only_requiring_grad:
             yield x
     elif isinstance(x, container_abcs.Iterable) and not isinstance(x, str):
@@ -254,13 +253,13 @@ def fail_test(msg):
         return False
 
     tupled_inputs = _as_tuple(inputs)
-    if not check_sparse_nnz and any(t.is_sparse for t in tupled_inputs if isinstance(t, torch.Tensor)):
+    if any(t.is_sparse for t in tupled_inputs if isinstance(t, torch.Tensor)) and not check_sparse_nnz:
         return fail_test('gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False.')
 
     # Make sure that gradients are saved for at least one input
     any_input_requiring_grad = False
     for idx, inp in enumerate(tupled_inputs):
-        if is_tensor_like(inp) and inp.requires_grad:
+        if isinstance(inp, torch.Tensor) and inp.requires_grad:
             if not (inp.dtype == torch.float64 or inp.dtype == torch.complex128):
                 warnings.warn(
                     'The {}th input requires gradient and '
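Note the allocation change in make_jacobian: input.new_zeros(...) routes through the input itself (so a Tensor-like could supply its own storage, with the layout pinned to strided), while torch.zeros(...) always builds a fresh plain dense tensor. A small sketch of the two allocations on an ordinary tensor (illustrative only; variable names are mine):

import torch

t = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
num_out = 2

# This commit's allocation: always a fresh plain tensor.
jac_after = torch.zeros(t.nelement(), num_out, dtype=t.dtype)

# The reverted allocation: dispatches through the input's new_zeros,
# so a __torch_function__ wrapper could intercept it.
jac_before = t.new_zeros((t.nelement(), num_out), dtype=t.dtype, layout=torch.strided)

assert jac_after.shape == jac_before.shape == (12, 2)

For plain tensors the two are equivalent; the difference only matters for duck-typed inputs, which this commit stops supporting.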

torch/overrides.py

Lines changed: 0 additions & 9 deletions
@@ -1164,12 +1164,3 @@ def is_tensor_method_or_property(func: Callable) -> bool:
     of ``torch.Tensor``.
     """
     return func in get_tensor_methods() or func.__name__ == "__get__"
-
-def is_tensor_like(inp):
-    """
-    Returns ``True`` if the passed-in input is a tensor-like.
-
-    Currently, this occurs whenever there's a ``__torch_function__``
-    attribute on the input.
-    """
-    return type(inp) is torch.Tensor or hasattr(inp, "__torch_function__")
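For reference, the removed helper is a pure duck test. A sketch of its behavior (the function body is copied from the deleted lines above; Duck is a hypothetical example class):

import torch

def is_tensor_like(inp):
    # Copied from the deleted helper above.
    return type(inp) is torch.Tensor or hasattr(inp, "__torch_function__")

class Duck:
    def __torch_function__(self, func, types, args=(), kwargs=None):
        pass  # a real implementation would forward to func

assert is_tensor_like(torch.zeros(1))    # a genuine tensor
assert is_tensor_like(Duck())            # anything advertising __torch_function__
assert not is_tensor_like([1.0, 2.0])    # plain containers are not tensor-like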
