torch.sgn for complex tensors #39955
Conversation
cc @dylanbespalko for the Vec256 changes
['sigmoid', torch.sigmoid],
['sigmoid_', torch.sigmoid_],
['sign', torch.sign],
['sgn', torch.sgn],
Example::

    >>> x = torch.randn(4, dtype=torch.cfloat)
The example doesn't really demonstrate the function because of the random values. Maybe pick some angles, like 0, 45, 90 degrees?
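For reference, a doctest-style example along those lines could use inputs with angles of 0, 45, and 90 degrees plus the zero case. The values and the printed output below are illustrative only, not the wording that ended up in the docstring:

```
>>> t = torch.tensor([3 + 0j, 1 + 1j, 0 + 2j, 0 + 0j])
>>> t.sgn()
tensor([1.0000+0.0000j, 0.7071+0.7071j, 0.0000+1.0000j, 0.0000+0.0000j])
```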
cos_angle = angle.cos()
sin_angle = angle.sin()
expected = cos_angle + 1j * sin_angle
self.assertEqual(x.sgn(), expected)
Can this instead assert that x.sgn has the same angle as x.angle and that the absolute value is everywhere one?
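A sketch of that alternative check, written as a standalone snippet rather than the exact test assertions (the random input and the tolerance-free `allclose` comparisons are assumptions for illustration):

```
import torch

x = torch.randn(4, dtype=torch.cfloat)
y = x.sgn()

# sgn keeps the angle of the input ...
assert torch.allclose(y.angle(), x.angle())
# ... and normalizes the modulus to one (for nonzero inputs)
assert torch.allclose(y.abs(), torch.ones_like(y.abs()))
```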
Added some notes. This also still needs to update torch/_overrides.py, docs/source/tensors.rst, docs/source/torch.rst, docs/source/name_inference.rst, aten/src/ATen/core/aten_interned_strings.h and tools/derivatives.yaml. Can this also be tested by TestTensorDeviceOps and TestTorchMathOps?
ezyang
left a comment
Don't compute the normalized vector using trig functions.
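In other words, the result should come from dividing by the modulus rather than reconstructing it from `cos`/`sin` of the angle. A rough illustration of the two routes giving the same values (not the actual kernel change; `torch.complex` and the random input are used here only for the comparison, and the input is assumed nonzero):

```
import torch

x = torch.randn(4, dtype=torch.cfloat)

# trig route: rebuild the unit vector from the angle
angle = x.angle()
via_trig = torch.complex(angle.cos(), angle.sin())

# direct route: divide by the modulus, which is what torch.sgn computes
via_div = x / x.abs()

assert (via_trig - via_div).abs().max() < 1e-6
```

The direct route also avoids two transcendental evaluations per element.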
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x == 0`.

TODO:
1. add tests for backward (waiting on gradcheck PR for complex)
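A quick interactive check of that behaviour, including the `x == 0` case (the literal values below are just for illustration and the printed output is approximate):

```
>>> import torch
>>> z = torch.tensor([4 - 3j, 0 + 0j])
>>> z.sgn()            # 4 - 3j has modulus 5, so its sgn is 0.8 - 0.6j; 0 maps to 0 + 0j
tensor([0.8000-0.6000j, 0.0000+0.0000j])
>>> z.sgn().abs()
tensor([1., 0.])
```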
mruberry
left a comment
Minor doc nits but otherwise this looks good to me!
albanD
left a comment
I think this still needs the gradients for the real case to be tested, even if you postpone the check for the complex version.
I was thinking of waiting to merge this PR until #43208 is merged, and then adding tests for sgn.
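For the real-valued case such a test would presumably be a plain gradcheck. A sketch, under the assumption that `sgn` has a (zero) backward registered for real inputs, with the input kept away from 0 where the function is not differentiable:

```
import torch
from torch.autograd import gradcheck

# real double input bounded away from zero, where sgn is not differentiable
base = torch.tensor([0.5, -1.5, 2.0, -0.25], dtype=torch.double, requires_grad=True)

# sgn is piecewise constant on the reals, so the analytical and numerical
# gradients should both be zero away from the origin
assert gradcheck(lambda t: t.sgn(), (base,))
```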
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x == 0`.

This PR doesn't test the correctness of the gradients. That will be done as part of auditing all the ops in the future, once we decide the autograd behavior (JAX vs TF) and add gradcheck.

Differential Revision: [D23460526](https://our.internmc.facebook.com/intern/diff/D23460526)
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x == 0`.

Also updates the backward definitions of `torch.div` and `torch.abs`.

Differential Revision: [D23460526](https://our.internmc.facebook.com/intern/diff/D23460526)
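To make the connection to `torch.abs` concrete: with this change, the gradient of `abs` at nonzero complex inputs can be expressed through `sgn`. A small numerical check of that relation, written as a sketch of the expected behaviour under PyTorch's convention for C -> R functions rather than the derivatives.yaml entry itself:

```
import torch

x = torch.randn(4, dtype=torch.cdouble, requires_grad=True)
grad_out = torch.randn(4, dtype=torch.double)

x.abs().backward(grad_out)

# for x != 0 the gradient of abs should be the upstream gradient scaled by sgn(x)
expected = grad_out * x.detach().sgn()
assert (x.grad - expected).abs().max() < 1e-10
```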
Codecov Report
@@ Coverage Diff @@
## gh/anjali411/34/base #39955 +/- ##
========================================================
- Coverage 67.86% 67.85% -0.01%
========================================================
Files 384 384
Lines 50026 50029 +3
========================================================
Hits 33948 33948
- Misses 16078 16081 +3
Continue to review full report at Codecov.
@anjali411 merged this pull request in 58b6ab6.
ghstack-source-id: e4ab67a Pull Request resolved: pytorch/pytorch#39955
This PR disables autograd for all C -> C, R -> C functions which are not included in the whitelist `GRADIENT_IMPLEMENTED_FOR_COMPLEX`. In practice, there will be a RuntimeError during forward computation when the outputs are differentiable:

```
>>> x = torch.randn(4, 4, requires_grad=True, dtype=torch.cdouble)
>>> x.pow(3)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: pow does not support automatic differentiation for outputs with complex dtype.
```

The implicit assumption here is that all the C -> R functions have correct backward definitions. So before merging this PR, the following functions must be tested and verified to have correct backward definitions: `torch.abs` (updated in #39955), `torch.angle`, `torch.norm`, `torch.irfft`, `torch.istft`.
Summary:
Pull Request resolved: #45461

This PR disables autograd for all C -> C, R -> C functions which are not included in the whitelist `GRADIENT_IMPLEMENTED_FOR_COMPLEX`. In practice, there will be a RuntimeError during forward computation when the outputs are differentiable:

```
>>> x = torch.randn(4, 4, requires_grad=True, dtype=torch.cdouble)
>>> x.pow(3)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: pow does not support automatic differentiation for outputs with complex dtype.
```

The implicit assumption here is that all the C -> R functions have correct backward definitions. So before merging this PR, the following functions must be tested and verified to have correct backward definitions: `torch.abs` (updated in #39955), `torch.angle`, `torch.norm`, `torch.irfft`, `torch.istft`.

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D23998156

Pulled By: anjali411

fbshipit-source-id: 370eb07fe56ac84dd8e2233ef7bf3a3eb8aeb179
Summary:
Pull Request resolved: pytorch#39955

resolves pytorch#36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x == 0`.

This PR doesn't test the correctness of the gradients. That will be done as part of auditing all the ops in the future, once we decide the autograd behavior (JAX vs TF) and add gradcheck.

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D23460526

Pulled By: anjali411

fbshipit-source-id: 70fc4e14e4d66196e27cf188e0422a335fc42f92
Stack from ghstack:
resolves #36323 by adding `torch.sgn` for complex tensors. `torch.sgn` returns `x/abs(x)` for `x != 0` and returns `0 + 0j` for `x == 0`.

Also updates the backward definition of `torch.div`, `torch.abs`.

Differential Revision: D23460526