Commit 647282c

albanD authored and facebook-github-bot committed
Add forward AD gradcheck (#57633)
Summary: Pull Request resolved: #57633

Test Plan: Imported from OSS

Reviewed By: agolynski

Differential Revision: D28387765

Pulled By: albanD

fbshipit-source-id: ed15049b5bdacca54f775b50ef166d540ba0b847
1 parent bc30c31 · commit 647282c

4 files changed

Lines changed: 288 additions & 71 deletions
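
For orientation, here is a minimal sketch of how the new flag is exercised, mirroring the tests added below (gradcheck, check_forward_ad, and fast_mode come from this diff; the inputs and the lambda are illustrative):

import torch
from torch.autograd import gradcheck

# gradcheck prefers double-precision inputs.
x = torch.rand(2, dtype=torch.double, requires_grad=True)
y = torch.rand(2, dtype=torch.double, requires_grad=True)

# Besides the usual numerical-vs-backward comparison, check_forward_ad=True
# also compares the Jacobian computed with forward-mode AD.
gradcheck(lambda a, b: (a + b, b), (x, y), check_forward_ad=True, fast_mode=True)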

test/test_autograd.py

Lines changed: 60 additions & 0 deletions
@@ -3930,9 +3930,26 @@ def check(fast_mode):
             x = x.expand((2, 2))
             with self.assertRaisesRegex(RuntimeError, 'The 0th input has a dimension with stride 0'):
                 gradcheck(lambda x: x, (x,), raise_exception=False, fast_mode=fast_mode)
+
         check(fast_mode=True)
         check(fast_mode=False)

+    @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
+    def test_gradcheck_validates_input_mkldnn(self):
+        # forward mode testing is not allowed for mkldnn inputs
+        # Update tolerances below to make sure the gradients match even in single precision floats
+        # Use the warning assert to hide the float32 warning
+        x = torch.ones(1).to_mkldnn().requires_grad_()
+        with self.assertWarnsRegex(UserWarning, "Input #0 requires gradient and is not a double precision"):
+            with self.assertRaisesRegex(ValueError, 'MKLDNN inputs are not support for forward AD gradcheck.'):
+                gradcheck(lambda x: x.to_dense(), (x,), raise_exception=False, fast_mode=False, check_forward_ad=True,
+                          atol=1e-1, rtol=1e-1)
+
+        with self.assertWarnsRegex(UserWarning, "Input #0 requires gradient and is not a double precision"):
+            with self.assertRaisesRegex(ValueError, 'MKLDNN inputs are not support for forward AD gradcheck.'):
+                gradcheck(lambda x: x.to_dense(), (x,), raise_exception=False, fast_mode=True, check_forward_ad=True,
+                          atol=1e-1, rtol=1e-1)
+
     @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
     def test_gradcheck_test_outputs(self):
         def check(fast_mode):
@@ -4223,6 +4240,49 @@ def fn2(x):
         check(fast_mode=True)
         check(fast_mode=False)

+    def test_gradcheck_forward_ad(self):
+        def fn(x, y):
+            return x + y, y
+
+        def bad_fn(x, y):
+            # Hacky way to check if we're currently inside a forward ad level
+            is_running_forward_ad = fwAD._current_level >= 0
+
+            if is_running_forward_ad:
+                y_p, y_d = fwAD.unpack_dual(y)
+                y = fwAD.make_dual(y_p, y_d * 1.1)
+
+            return x + y, y
+
+        err_msg = "Jacobian computed with forward mode mismatch for output 0 with respect to input 1"
+
+        for fast_mode in [True, False]:
+            # Test for all inputs and outputs being real
+            x = torch.rand(2, dtype=torch.double, requires_grad=True)
+            y = torch.rand(2, dtype=torch.double, requires_grad=True)
+
+            gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
+            with self.assertRaisesRegex(RuntimeError, err_msg):
+                gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
+
+            def basic_mul(x):
+                return torch.view_as_real(x * 1j)
+            gradcheck(basic_mul, x, check_forward_ad=True, fast_mode=fast_mode)
+
+            # Test for one input and one output being complex
+            x = torch.rand(2, dtype=torch.cdouble, requires_grad=True)
+
+            gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
+            with self.assertRaisesRegex(RuntimeError, err_msg):
+                gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
+
+            # Test for all inputs and outputs being complex
+            y = torch.rand(2, dtype=torch.cdouble, requires_grad=True)
+
+            gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
+            with self.assertRaisesRegex(RuntimeError, err_msg):
+                gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
+
     def test_version_counter(self):
         x = torch.randn(1, 2)
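A note on test_gradcheck_forward_ad above: bad_fn deliberately rescales y's tangent by 1.1 while a forward-AD level is active, so the forward-mode Jacobian no longer agrees with the numerically estimated one and gradcheck is expected to raise. Below is a minimal sketch of the dual-number API the test relies on, assuming the test file's alias import torch.autograd.forward_ad as fwAD; the tensors and the function are illustrative:

import torch
import torch.autograd.forward_ad as fwAD

x = torch.rand(2, dtype=torch.double)
t = torch.ones_like(x)  # tangent: the direction of the directional derivative

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, t)   # pack primal and tangent into a dual tensor
    out = dual_x + dual_x           # fn in the test above is also a simple addition
    _, tangent = fwAD.unpack_dual(out)
    # For f(x) = x + x the JVP is t + t, so the propagated tangent is 2 * t.
    assert torch.allclose(tangent, 2 * t)
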

test/test_overrides.py

Lines changed: 3 additions & 2 deletions
@@ -806,7 +806,6 @@ def run_test(fast_mode):
             # Tensor-likes.
             expected_used_attrs = {
                 'data',
-                'device',
                 'dtype',
                 'is_floating_point',
                 'is_sparse',
@@ -820,6 +819,7 @@ def run_test(fast_mode):
             }
             if fast_mode:
                 expected_used_attrs.add('is_complex')
+                expected_used_attrs.add('device')
             self.assertEqual(expected_used_attrs, total_used_attrs)

             expected_used_calls = {
@@ -833,8 +833,9 @@ def run_test(fast_mode):
                 torch.add,
             }
             if fast_mode:
-                expected_used_attrs.add(torch.Tensor.is_complex)
+                expected_used_calls.add(torch.Tensor.is_complex)
             self.assertEqual(expected_used_calls, total_used_calls)
+        run_test(fast_mode=True)
         run_test(fast_mode=False)

 class TestNamedTuple(TestCase):

tools/autograd/derivatives.yaml

Lines changed: 3 additions & 1 deletion
@@ -845,7 +845,7 @@
 - name: mul.Tensor(Tensor self, Tensor other) -> Tensor
   self: mul_tensor_backward(grad, other, self.scalar_type())
   other: mul_tensor_backward(grad, self, other.scalar_type())
-  result: other_t * self_p.conj() + self_t * other_p.conj()
+  result: other_t * self_p + self_t * other_p

 - name: mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
   self: mul_tensor_backward(grad, other, self.scalar_type())
@@ -1272,9 +1272,11 @@

 - name: view_as_real(Tensor(a) self) -> Tensor(a)
   self: at::view_as_complex(grad.contiguous())  # gx0 + 1j * gx1
+  result: at::view_as_real(self_t)

 - name: view_as_complex(Tensor(a) self) -> Tensor(a)
   self: at::view_as_real(grad.contiguous())  # [gx, gy]
+  result: at::view_as_complex(self_t)

 - name: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor
   condition: non_differentiable
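
In derivatives.yaml, the result: lines added or changed above define forward-mode (JVP) formulas, where *_p names denote primals and *_t names denote tangents. For mul this is the product rule applied to the tangents; the conjugations in the previous formula are dropped because a JVP propagates tangents directly, whereas the conjugation belongs to the backward (vector-Jacobian) formulas for complex tensors. A small sanity check of the updated mul entry, as a sketch that assumes a build where forward AD covers complex mul (which is what the new tests exercise):

import torch
import torch.autograd.forward_ad as fwAD

# Names mirror the derivatives.yaml convention: *_p is a primal, *_t a tangent.
self_p = torch.rand(2, dtype=torch.cdouble)
other_p = torch.rand(2, dtype=torch.cdouble)
self_t = torch.rand(2, dtype=torch.cdouble)
other_t = torch.rand(2, dtype=torch.cdouble)

with fwAD.dual_level():
    out = fwAD.make_dual(self_p, self_t) * fwAD.make_dual(other_p, other_t)
    _, tangent = fwAD.unpack_dual(out)
    # Product rule without conjugation, matching the updated yaml entry.
    assert torch.allclose(tangent, other_t * self_p + self_t * other_p)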
