Conversation
Note that this PR implements formulas only for ops that are supported by OpInfo.
@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Note that this PR implements formulas only for ops that are supported by OpInfo. Slow gradcheck also passes for this PR and can be found here: #57976 Differential Revision: [D28387766](https://our.internmc.facebook.com/intern/diff/D28387766)
```yaml
self: handle_r_to_c(self.scalar_type(), grad)
tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (value / tensor2).conj())
tensor2: handle_r_to_c(tensor2.scalar_type(), -grad * (value * tensor1 / (tensor2 * tensor2)).conj())
result: self_t + maybe_multiply(tensor1_t / tensor2_p, value) - maybe_multiply(tensor2_t * (tensor1_p / tensor2_p) / tensor2_p, value)
```
(no action required) Could you actually "auto-elementwise" this and other pointwise operations? It looks like the formula (at least in the real case) is just the sum of the backward formulas for self, tensor1, and tensor2, with each grad replaced by the corresponding tangent.
Yes we could do it here. Will add it if it shows up again.
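The tangent formula being discussed can be sanity-checked outside of PyTorch. Below is a hedged, illustrative NumPy sketch (not the actual codegen or autograd machinery) that compares the forward-mode tangent for `addcdiv`, as written in the `result:` line above, against a finite-difference JVP of `f(self, tensor1, tensor2) = self + value * tensor1 / tensor2`:

```python
# Illustrative sketch only: numerically check the forward-mode formula
#   result_t = self_t + value * tensor1_t / tensor2_p
#            - value * tensor2_t * (tensor1_p / tensor2_p) / tensor2_p
# against a finite-difference JVP of f(s, t1, t2) = s + value * t1 / t2.
import numpy as np

rng = np.random.default_rng(0)
value = 0.7
s = rng.normal(size=3)
t1 = rng.normal(size=3)
t2 = np.abs(rng.normal(size=3)) + 1.0  # keep the denominator away from zero

s_t, t1_t, t2_t = rng.normal(size=3), rng.normal(size=3), rng.normal(size=3)

def f(s, t1, t2):
    return s + value * t1 / t2

# Analytic tangent, transcribed from the derivatives.yaml entry above
result_t = s_t + value * t1_t / t2 - value * t2_t * (t1 / t2) / t2

# Finite-difference JVP along the tangent direction
h = 1e-6
fd = (f(s + h * s_t, t1 + h * t1_t, t2 + h * t2_t) - f(s, t1, t2)) / h

# The two should agree up to finite-difference error
print(np.max(np.abs(result_t - fd)))
```

This is exactly the "sum of per-input rules with grads replaced by tangents" structure the comment above points out for elementwise ops.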
```yaml
self: maybe_multiply(grad, beta.conj())
mat1: mm_mat1_backward(grad, mat2, mat1.sizes(), mat1.strides(), alpha)
mat2: mm_mat2_backward(grad, mat1, mat2.sizes(), mat2.strides(), alpha)
result: maybe_multiply(self_t, beta) + maybe_multiply(mat1_t.mm(mat2_p), alpha) + maybe_multiply(mat1_p.mm(mat2_t), alpha)
```
It's interesting to note that this is just maybe_multiply(self_t, beta) added to maybe_multiply(formula_for_mm, alpha). Is there any chance we would want to dedup code between this and the mm formula in the future?
The main problem with such formulas is that they are not element-wise, so simply adding the formulas won't work.
They are also affine (not linear), so we would need to pass some arguments to a smarter auto_affine to handle this, which feels like a dangerous step to take.
Yeah, I agree the design for this would be tricky.
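The affine structure of the `addmm` tangent noted above can also be checked numerically. This is a hedged NumPy sketch (not PyTorch code): the tangent is `beta * self_t` plus `alpha` times the `mm` forward formula `mat1_t @ mat2_p + mat1_p @ mat2_t`, compared against a finite-difference JVP:

```python
# Illustrative sketch only: check the addmm forward-mode tangent
#   result_t = beta * self_t + alpha * (mat1_t @ mat2_p + mat1_p @ mat2_t)
# against a finite-difference JVP of
#   f(self, mat1, mat2) = beta * self + alpha * mat1 @ mat2.
import numpy as np

rng = np.random.default_rng(1)
alpha, beta = 0.5, 2.0
self_p = rng.normal(size=(3, 4))
mat1_p, mat2_p = rng.normal(size=(3, 5)), rng.normal(size=(5, 4))
self_t = rng.normal(size=(3, 4))
mat1_t, mat2_t = rng.normal(size=(3, 5)), rng.normal(size=(5, 4))

def f(s, m1, m2):
    return beta * s + alpha * (m1 @ m2)

# beta * self_t plus alpha * (the mm product-rule tangent)
result_t = beta * self_t + alpha * (mat1_t @ mat2_p + mat1_p @ mat2_t)

h = 1e-6
fd = (f(self_p + h * self_t, mat1_p + h * mat1_t, mat2_p + h * mat2_t)
      - f(self_p, mat1_p, mat2_p)) / h

print(np.max(np.abs(result_t - fd)))
```

Because `mat1 @ mat2` is bilinear, the only finite-difference error here is the `h * alpha * mat1_t @ mat2_t` cross term, which vanishes as `h → 0`; this is the product-rule structure that makes the formula "affine, not linear" in the inputs.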
zou3519 left a comment
The formulas lgtm from a real-numbers perspective, but I am not sure how to derive them for complex numbers.
Offline, Alban walked me through a derivation of the complex forward-mode AD derivatives for torch.sin and torch.conj, and those helped me understand enough to derive the complex formulas as well.
NB: we should update the one example above that wasn't updated for this PR, but other than that things lgtm
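The two examples mentioned in the comment above illustrate the key complex-AD distinction: sin is holomorphic, so its tangent is just `cos(z) * z_t`, while conj is not holomorphic and its tangent is `conj(z_t)`. A hedged NumPy sketch (illustrative only, not the PyTorch derivation itself) checking both rules by stepping along the tangent direction with a small real step:

```python
# Illustrative sketch: complex forward-mode rules for sin and conj.
# For holomorphic sin: d/dh sin(z + h * z_t) at h = 0 is cos(z) * z_t.
# conj is not holomorphic; its tangent is conj(z_t).
import numpy as np

rng = np.random.default_rng(2)
z = rng.normal(size=4) + 1j * rng.normal(size=4)
z_t = rng.normal(size=4) + 1j * rng.normal(size=4)
h = 1e-6

# sin: finite-difference JVP vs. the holomorphic chain rule
fd_sin = (np.sin(z + h * z_t) - np.sin(z)) / h
assert np.max(np.abs(fd_sin - np.cos(z) * z_t)) < 1e-4

# conj: the map is real-linear, so the finite difference is exact
fd_conj = (np.conj(z + h * z_t) - np.conj(z)) / h
assert np.max(np.abs(fd_conj - np.conj(z_t))) < 1e-10
```

The conj case is why the tangent of a general (non-holomorphic) op cannot be written as a single complex derivative times `z_t`.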
Summary: Pull Request resolved: pytorch#57768 Note that this PR implements formulas only for ops that are supported by OpInfo. Test Plan: Imported from OSS Reviewed By: zou3519, malfet Differential Revision: D28387766 Pulled By: albanD fbshipit-source-id: b4ba1cf1ac1dfd46cdd889385c9c2d5df3cf7a71