Update backward formula for torch.dot and add backward definition for torch.vdot #45074

anjali411 wants to merge 4 commits into gh/anjali411/57/base
Conversation
… torch.vdot [ghstack-poisoned]
💊 CI failures summary and remediations — as of commit bc6a696 (more details on the Dr. CI page):

🕵️ 6 new failures recognized by patterns. The following CI failures do not appear to be due to upstream breakages:
```python
('addr', (S, M), ((S,), (M,)), 'coef', (), (), (), ident, {'beta': 0.2, 'alpha': 0.6}),
('addr', (), ((S,), (M,)), 'broadcast_lhs_coef', (), (), (), ident, {'beta': 0.2, 'alpha': 0.6}),
('dot', (L,), ((L,),), '', (True,)),
('vdot', (L,), ((L,),), '', (True,)),
```
This should be `(False,)` at the end (to turn off JIT autodiff testing).
Alternatively, `('vdot', (L,), ((L,),))` does the trick.
cc @eellison: `torch.vdot` is not supported by JIT autodiff right now.
The question here is: Is it OK that autodiff doesn't support vdot? It's a new operator we added recently. Also, does the JIT still use the old autodiff pass?
It should be fine if it's not supported by JIT autodiff; having an autodiff formula is an optimization that allows fusion to occur in training.
Yes, it still uses the old autodiff pass. The way to define a backward is in the symbolic_script.cpp file. Currently there is really only autodiff coverage for pointwise ops, because those are the ops we codegen fusion for.
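For context, the backward formulas this PR touches can be sketched in plain Python. This is a paraphrase of the `derivatives.yaml` entries as I read them, not a verbatim quote, and the `dot_backward`/`vdot_backward` names are made up for illustration:

```python
# Assumed formulas (a sketch, not the shipped derivatives.yaml):
#   dot(a, b)  = sum(a * b)        -> grad_a = grad * conj(b), grad_b = grad * conj(a)
#   vdot(a, b) = sum(conj(a) * b)  -> grad_a = conj(grad) * b, grad_b = grad * a

def dot_backward(grad, a, b):
    return ([grad * y.conjugate() for y in b],
            [grad * x.conjugate() for x in a])

def vdot_backward(grad, a, b):
    return ([grad.conjugate() * y for y in b],
            [grad * x for x in a])

# Sanity check: for real inputs the two must agree, since conj is a
# no-op on reals and vdot reduces to dot.
a, b, g = [1.0 + 0j, 2.0 + 0j], [3.0 + 0j, 4.0 + 0j], 1.0 + 0j
assert dot_backward(g, a, b) == vdot_backward(g, a, b)
```

The only difference between the two ops is the conjugation of the first argument in the forward, which is why the gradients differ in where `conj` lands.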
…inition for torch.vdot" TODO: Add R -> C tests in #44744 (blocked on some JIT changes) [ghstack-poisoned]
Codecov Report

```
@@              Coverage Diff              @@
##   gh/anjali411/57/base   #45074   +/-  ##
============================================
  Coverage           ?      67.85%
============================================
  Files              ?         384
  Lines              ?       50020
  Branches           ?           0
============================================
  Hits               ?       33940
  Misses             ?       16080
  Partials           ?           0
```

Continue to review the full report at Codecov.
```cpp
Tensor correct_dtype_gradients(ScalarType self_st, Tensor gradient_result) {
  if (!at::isComplexType(self_st) && gradient_result.is_complex()) {
    // R -> C
    return at::real(gradient_result);
  }
  return gradient_result;
}
```
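For readers skimming the diff, the helper's logic can be sketched with plain Python complex numbers instead of tensors (the name is taken from the diff above; the boolean flag standing in for `ScalarType` is a simplification of mine):

```python
# Sketch of the R -> C gradient fixup: if the forward input was real but
# the computed gradient came out complex, only the real part is a valid
# gradient for that input, so the imaginary part is dropped.
def correct_dtype_gradients(input_is_complex, gradient_result):
    if not input_is_complex and isinstance(gradient_result, complex):
        return gradient_result.real  # R -> C case
    return gradient_result

assert correct_dtype_gradients(False, 3.0 + 4.0j) == 3.0
assert correct_dtype_gradients(True, 3.0 + 4.0j) == 3.0 + 4.0j
```

This situation arises exactly when a real tensor feeds an op like `vdot` whose other argument is complex, so the upstream gradient is complex even though the input is real.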
"correct_dtype_gradients" seems like a very generic name (e.g. it could be mistaken for a function that handles float -> double dtype conversion). Also, it looks like this function will be used quite a lot in autograd formulas.
Tossing some quick ideas out there:
- "handle_r_to_c", "handle_real_to_complex"
- "maybe_real_part"
It seems to me that at::real should be a no-op on real tensors, similar to how at::conj is a no-op on real tensors.
@zou3519 Makes sense. `handle_real_to_complex` should be clearer, I think.
@ezyang Yeah, I agree. The reason we previously disabled at::real for non-complex tensors was that it would be weird for real to return a view for non-complex tensors while at::imag returned a new tensor populated with zeros (which it did before). We can certainly enable just at::real for non-complex tensors if we want, though.
Sure, but now that at::conj is a no-op for real tensors, I think we should probably be OK with making at::real do the same as well. I don't care... too much about imag, I don't think it shows up in situations like this.
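The no-op behavior under discussion can be sketched with plain Python floats and complex numbers. This is an assumption about the desired semantics being proposed here, not the shipped `at::real`/`at::conj` implementation:

```python
# Proposed-semantics sketch: conj() and real() pass real inputs through
# unchanged and only do work on complex inputs.
def conj(x):
    return x.conjugate() if isinstance(x, complex) else x

def real(x):
    return x.real if isinstance(x, complex) else x

assert conj(3.0) == 3.0          # no-op on a real input
assert conj(1 + 2j) == 1 - 2j
assert real(3.0) == 3.0          # proposed: also a no-op on a real input
assert real(1 + 2j) == 1.0
```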
Besides Richard's comment, the rest of the PR looks reasonable.
…inition for torch.vdot" TODO: Add R -> C tests in #44744 (blocked on some JIT changes) [ghstack-poisoned]
zou3519 left a comment:
LGTM after the name change. The discussion on at::real being a no-op for real tensors seems orthogonal so I am approving to unblock.
…inition for torch.vdot" TODO: Add R -> C tests in #44744 (blocked on some JIT changes) Differential Revision: [D23975361](https://our.internmc.facebook.com/intern/diff/D23975361) [ghstack-poisoned]
@anjali411 merged this pull request in 18876b5.
… torch.vdot (pytorch#45074)

Summary: Pull Request resolved: pytorch#45074

TODO: Add R -> C tests in pytorch#44744 (blocked on some JIT changes)

Test Plan: Imported from OSS
Reviewed By: gchanan
Differential Revision: D23975361
Pulled By: anjali411
fbshipit-source-id: 3512bd2962b588a198bc317673bd18cc96ac823f
Stack from ghstack:
TODO: Add R -> C tests in #44744 (blocked on some JIT changes)
Differential Revision: D23975361