Add gradcheck for forward AD by default and functional API #49099
albanD wants to merge 41 commits into gh/albanD/67/base
Conversation
💊 CI failures summary (as of commit 7f4a14a; more details on the Dr. CI page): 🕵️ 15 new failures recognized by patterns. These CI failures do not appear to be due to upstream breakages.
This PR adds the option to check forward grad using gradcheck. The current logic is:

- Forward grad is always checked
- If the forward evaluation fails because an op is not implemented, the test silently passes

The goal is to make sure that all formulas that are added are properly tested without having to add a new test for each op.

The final logic, after the next PR that adds the remaining formulas, is going to be:

- Forward grad is always checked
- Failure with a not-implemented op is an actual failure
- Users should set `check_forward=False` if they explicitly don't want to test forward grads (which should not be the case internally)
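The transitional policy described in the PR body can be sketched as a small standalone helper. This is a hypothetical sketch, not the actual implementation (which lives inside `torch.autograd.gradcheck`); `run_forward_ad` and `fail_test` are assumed names standing in for gradcheck internals.

```python
def check_forward_grad(run_forward_ad, fail_test):
    """Hypothetical sketch of the transitional policy: forward-mode AD is
    always attempted, but a NotImplementedError (an op has no forward
    formula yet) silently passes for now, while any other error is a
    genuine gradcheck failure."""
    try:
        run_forward_ad()
    except NotImplementedError:
        # Transitional behavior from this PR: treated as a silent pass
        # until all remaining formulas land in the follow-up PR.
        return True
    except Exception as err:
        return fail_test('forward-mode AD check raised: %s' % err)
    return True
```

Under the final logic described above, the `NotImplementedError` branch would be removed so that missing formulas fail loudly.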
zou3519
left a comment
Some questions and minor comments, but this looks pretty reasonable to me.
```diff
     pass

-def _assertGradAndGradgradChecks(test_case, apply_fn, inputs):
+def _assertGradAndGradgradChecks(test_case, apply_fn, inputs, check_forward=True):
```
This PR will merge-conflict with #49120, but we can figure out what order we want them to go in.

Thanks for the ref. The merge conflict should indeed be fairly easy to handle.
RFC: pytorch/rfcs#11
torch/autograd/gradcheck.py (Outdated)
```python
for j, (a, n) in enumerate(zip(fw_analytical, numerical)):
    if a.numel() != 0 or n.numel() != 0:
        if not torch.allclose(a, n, rtol, atol):
            return fail_test('Jacobian mismatch for output %d with respect to input %d,\n'
                             'numerical:%s\nforward analytical:%s\n' % (i, j, n, a))
```
nit: `continue`s can reduce the indentation (and make that first condition a little easier to read), but it probably comes down to personal preference.

```python
if a.numel() == 0 and n.numel() == 0:
    continue
if torch.allclose(a, n, rtol, atol):
    continue
return fail_test(...)
```
Also, if `a.numel() == 0 and n.numel() == 0`, shouldn't we check that the shapes of the tensors are the same?
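Put together, the early-continue restructuring plus the suggested size check might look like the plain-Python sketch below. The types are stand-ins: each "tensor" is a flat list of floats, whereas the real gradcheck loop compares torch tensors with `torch.allclose`.

```python
import math


def fail_test(msg):
    raise AssertionError(msg)


def compare_jacobians(fw_analytical, numerical, rtol=1e-5, atol=1e-8):
    # Early-continue version of the reviewed loop, with the extra size
    # check from the follow-up question. Stand-in types: each "tensor" is
    # a flat list of floats; the real code uses torch.allclose on tensors.
    for j, (a, n) in enumerate(zip(fw_analytical, numerical)):
        if len(a) == 0 and len(n) == 0:
            continue  # both empty: sizes trivially agree, nothing to compare
        if len(a) == len(n) and all(
                math.isclose(x, y, rel_tol=rtol, abs_tol=atol)
                for x, y in zip(a, n)):
            continue
        fail_test('Jacobian mismatch for input %d,\nnumerical:%s\n'
                  'forward analytical:%s\n' % (j, n, a))
```

The size check doubles as the shape check for non-empty entries here; for real multi-dimensional tensors the comparison would be on `.shape` rather than length.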
```
    the second derivative.
check_undefined_grad (bool, options): if True, check if undefined output grads
    are supported and treated as zeros
check_forward(bool, options): if True, check the forward mode AD gradient
```
options -> optional (and we should state the default)

Copy-pasted the options from the one above... But I think you fixed this one in master? I'll update!

My PR got reverted haha. I'll re-submit it later today.
```
    exactly (default, 0.0) or be within this tolerance.
check_undefined_grad (bool, options): if True, check if undefined output grads
    are supported and treated as zeros, for ``Tensor`` outputs.
check_forward (bool, optional): if True, check the forward mode AD gradient
```
We should state the default (False).
```python
# break both forward and backward mode links
out = a.detach()
out, _ = fwAD.unpack_dual(out)
return out
```
`test_jvp_err_check_strict` doesn't call jvp with `fw_mode=True`, so the `unpack_dual` line here doesn't actually do anything meaningful, right?

On a related note, it would be good to update all of the jvp tests to test with both `fw_mode=True` and `fw_mode=False` for comprehensiveness. I see the following functions:

- test_jvp_output
- test_jvp_scalar
- test_jvp_err_check
- test_jvp_err_check_strict
- test_jvp_create_graph (has already been updated to test fw_mode=True)
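One way to cover both modes without duplicating each test body is a `subTest` loop. The sketch below uses a hypothetical `toy_jvp` on plain floats as a stand-in for `torch.autograd.functional.jvp`; the test class and method names are illustrative, not the PR's actual tests.

```python
import unittest


class TestJvpBothModes(unittest.TestCase):
    # Sketch of the suggestion above: run each jvp assertion under both
    # fw_mode values via subTest. toy_jvp is a hypothetical stand-in for
    # torch.autograd.functional.jvp, working on plain floats; both modes
    # must agree on the answer, which is exactly what the test checks.
    @staticmethod
    def toy_jvp(func, inp, v, fw_mode=False):
        eps = 1e-6
        out = func(inp)
        jvp = (func(inp + eps * v) - out) / eps  # same value in either mode
        return out, jvp

    def test_jvp_output_both_modes(self):
        for fw_mode in (True, False):
            with self.subTest(fw_mode=fw_mode):
                out, jvp = self.toy_jvp(lambda x: x * x, 3.0, 1.0,
                                        fw_mode=fw_mode)
                self.assertAlmostEqual(out, 9.0)
                self.assertAlmostEqual(jvp, 6.0, places=3)
```

With `subTest`, a failure under one `fw_mode` value is reported separately without masking the other.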
```python
self.assertIsNotNone(res[0].grad_fn)
self.assertIsNotNone(res[1].grad_fn)

gradcheck(lambda inp, v: autogradF.jvp(reducer, inp, v, create_graph=True), (inputs, v))
gradgradcheck(lambda inp, v: autogradF.jvp(reducer, inp, v, create_graph=True), (inputs, v))
res = autogradF.jvp(reducer, inputs, v, create_graph=True, fw_mode=False)
self._assert_same_struct(res[1], res[0])
self.assertIsNotNone(res[0].grad_fn)
self.assertIsNotNone(res[1].grad_fn)
```
I'm confused -- isn't `fw_mode=False` the default right now?

Also nit:

```python
for fw_mode in [True, False]:
    res = autogradF.jvp(reducer, inputs, v, create_graph=True, fw_mode=fw_mode)
    self._assert_same_struct(res[1], res[0])
    self.assertIsNotNone(res[0].grad_fn)
    self.assertIsNotNone(res[1].grad_fn)
```
It is, this is left over from an older version. My bad!
```python
# check_forward=False here as nested forward is not supported yet
gradcheck(lambda inp, v: autogradF.jvp(reducer, inp, v, create_graph=True), (inputs, v), check_forward=False)
gradgradcheck(lambda inp, v: autogradF.jvp(reducer, inp, v, create_graph=True), (inputs, v), check_forward=False)
gradcheck(lambda inp, v: autogradF.jvp(reducer, inp, v, create_graph=True, fw_mode=False), (inputs, v))
gradgradcheck(lambda inp, v: autogradF.jvp(reducer, inp, v, create_graph=True, fw_mode=False), (inputs, v))
```
Ditto -- isn't `fw_mode=False` the default? Am I missing an override somewhere?
zou3519
left a comment
The logic in the gradcheck and jvp changes all looks fine to me. I have some high-level comments:

- I would have probably separated this PR into 2 pieces: one that updates gradcheck to check forward mode, and another that turns on forward mode AD in the jvp API. But it's OK to leave it as is, since as the reviewer I figured out which tests correspond to which of these changes.
- API bikeshedding: in the APIs, `check_forward` or `fw_mode` can be ambiguous, especially when the user is familiar with terminology for the "NN module forward pass". Naming it `check_forward_mode_ad` is longer but removes the ambiguity.
- We should extend all of the jvp tests to test `fw_mode` on True and False (I think only some were updated).
ghstack-source-id: cc5c7fe Pull Request resolved: pytorch#49099
Differential Revision: [D25607502](https://our.internmc.facebook.com/intern/diff/D25607502)
```python
    return jacobian, reentrant, correct_grad_sizes, correct_grad_types


def get_analytical_jacobian_fw(fn, input, output):
```
To test the correctness of forward mode AD, computing the whole Jacobian using finite differences and "analytical" is not needed. It's enough to compute a directional derivative for a random direction (a Jacobian-vector product for a random tangent vector):

```python
v = ...  # random vector like input
analytical_jvp = pytorch_jvp(func, input, v)
numerical_jvp = finite_differences_jvp(func, input, v)
torch.allclose(analytical_jvp, numerical_jvp)


def finite_differences_jvp(func, inp, v):
    """Jacobian of `func` at `inp` multiplied by `v`"""
    # or using any other more accurate finite difference scheme
    return (func(inp + eps * v) - func(inp)) / (eps * norm(v))
```

So approximately computing a JVP by finite differences takes only two function calls. Avoiding computing the whole Jacobian would speed up testing a lot.
Yes, that is mostly true for backward mode as well.

The point here is more that having the full Jacobian is much more useful as an error message :)

A fast version of gradcheck that uses the strategy you're talking about is in the pipe, though, and would also be applied to forward AD. But the full Jacobian version will still be kept to ensure good error messages in gradcheck.
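As a concrete illustration of the two-evaluation strategy discussed above, here is a scalar-valued sketch: plain Python floats stand in for tensors, and a central difference is used as the "more accurate scheme" the comment alludes to. The function names are illustrative only.

```python
import math


def finite_difference_jvp(func, inp, v, eps=1e-6):
    # Approximates J(inp) @ v with two function evaluations, as the
    # comment describes; central differences improve on the one-sided
    # scheme at the same cost of two calls.
    return (func(inp + eps * v) - func(inp - eps * v)) / (2 * eps)


# For f(x) = x**2 the exact Jacobian-vector product is 2 * x * v.
inp, v = 1.5, 0.7
approx = finite_difference_jvp(lambda x: x * x, inp, v)
assert math.isclose(approx, 2.0 * inp * v, rel_tol=1e-4)
```

For an n-dimensional input this costs two function calls regardless of n, whereas a full finite-difference Jacobian costs O(n) calls, which is the speedup the comment is pointing at.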
ghstack-source-id: c451fce Pull Request resolved: pytorch#49099
ghstack-source-id: 97dfcb7 Pull Request resolved: pytorch#49099
This is being replaced by #57633.
Stack from ghstack:
RFC: pytorch/rfcs#11
This PR adds the option to check forward grad using gradcheck. The current logic is:

- Forward grad is always checked
- If the forward evaluation fails because an op is not implemented, the test silently passes

The goal is to make sure that all formulas that are added are properly tested without having to add a new test for each op.

The final logic after the next PR that adds the remaining formulas is going to be:

- Forward grad is always checked
- Failure with a not-implemented op is an actual failure
- Users should set `check_forward=False` if they explicitly don't want to test forward grads (which should not be the case internally)

Differential Revision: D25607502