Add allowlist for complex backward #45461
anjali411 wants to merge 7 commits into gh/anjali411/59/base
Conversation
```python
'eq_', 'ne_', 'add', '__radd__', 'sum', '_conj', 'sin', 'cos', 'mul', 'sinh',
'cosh', '__rmul__', 'sgn', 'view_as_real', 'real', 'imag', 'asin', 'acos', 'sub',
'div', 'cat', 'view_as_complex', 'neg', 'complex', 'select', '_s_where', 'as_strided',
'_fft_with_size'
```
_fft_with_size doesn't use complex; it uses (..., 2)-shaped real tensors.
fft failures are real, though.
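As context for the (..., 2) layout mentioned above, here is a minimal sketch using `torch.view_as_real` and `torch.view_as_complex` (both of which appear in the allowlist):

```python
import torch

# A complex tensor and its (..., 2)-shaped real view: the last
# dimension holds the real and imaginary parts.
z = torch.randn(4, dtype=torch.cdouble)
r = torch.view_as_real(z)          # shape (4, 2), dtype torch.float64
assert r.shape == (4, 2) and r.dtype == torch.float64

# view_as_complex maps the (..., 2) layout back to a complex tensor.
z2 = torch.view_as_complex(r)
assert torch.equal(z, z2)
```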
💊 CI failures summary and remediations
As of commit 5d2c5fe (more details on the Dr. CI page):
❄️ 1 failure tentatively classified as flaky, but reruns have not yet been triggered to confirm.
This PR disables autograd for all C -> C, R -> C functions which are not included in the whitelist `GRADIENT_IMPLEMENTED_FOR_COMPLEX`. In practice, there will be a RuntimeError during forward computation when the outputs are differentiable:

```
>>> x=torch.randn(4, 4, requires_grad=True, dtype=torch.cdouble)
>>> x.pow(3)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: pow does not support automatic differentiation for outputs with complex dtype.
```

The implicit assumption here is that all the C -> R functions have correct backward definitions. So before merging this PR, the following functions must be tested and verified to have correct backward definitions: `torch.abs`, `torch.angle`, `torch.norm`, `torch.irfft`, `torch.istft`.
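For contrast, a minimal sketch of an op that stays differentiable because it is allowlisted (using `sin` and `real` from the list above):

```python
import torch

x = torch.randn(4, requires_grad=True, dtype=torch.cdouble)

# sin and real are both in GRADIENT_IMPLEMENTED_FOR_COMPLEX, so the
# forward pass does not hit the new RuntimeError and backward runs.
x.sin().real.sum().backward()
print(x.grad)  # complex gradient, shape (4,)
```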
albanD left a comment:
I did not check the C->R functions nor the fact that the functions in the list are actually properly implemented.
The rest of the codegen looks good except the TensorList support.
Also it would be nice to show what the new generated code looks like for a sample function.
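Not the actual codegen, just a hedged sketch of its shape: `emit_complex_guards` and `GUARD_TEMPLATE` are hypothetical names, and the real allowlist check is the hunk quoted below. The idea is that for an op outside the allowlist, the generated VariableType body gains a runtime guard per differentiable Tensor output.

```python
# Hypothetical sketch of the codegen step; the template string stands
# in for the C++ that would be emitted into the generated body.
GRADIENT_IMPLEMENTED_FOR_COMPLEX = {'sin', 'cos', 'add', 'mul'}  # abbreviated

GUARD_TEMPLATE = (
    'throw_error_for_complex_fns_backward_not_implemented({out}, "{name}");'
)

def emit_complex_guards(base_name, differentiable_outputs, body):
    if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
        return body
    guards = '\n'.join(
        GUARD_TEMPLATE.format(out=arg['name'], name=base_name)
        for arg in differentiable_outputs
        if arg['type'] == 'Tensor'
    )
    # In the real generated code the call sits after the outputs are
    # computed, which is why the guard is appended rather than prepended.
    return body + '\n' + guards

# For pow, the emitted guard would read:
#   throw_error_for_complex_fns_backward_not_implemented(result, "pow");
print(emit_complex_guards('pow', [{'name': 'result', 'type': 'Tensor'}], '...body...'))
```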
```python
if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
    return body
for arg in differentiable_outputs:
    if arg['type'] == 'Tensor':
```
What about TensorList?
In particular, functions like unbind() will return such objects.
The functions that are differentiable and return TensorList are: torch.unbind, torch.split (both of which have correct backward definitions for complex). So I think it's OK to just not do anything for that case. However, I'll add a check for the TensorList type to cover any functions that may be added in the future.
There are no others? Sounds good.
But yes, a check is nice to make sure we don't break this in the future.
There's _cudnn_rnn_backward but it's non-differentiable.
Added `split`, `split_with_sizes`, `unsafe_split`, `split_with_sizes_backward` to the list, and also added a check to error out for TensorList outputs otherwise.
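A hedged sketch of what that codegen-time TensorList check could look like (`check_tensorlist_outputs` is a hypothetical name; the real check lives in the autograd codegen):

```python
# Hypothetical sketch: non-allowlisted ops with a TensorList output
# fail at codegen time instead of silently skipping the runtime guard.
GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
    'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
    # ... plus the entries shown in the allowlist above
}

def check_tensorlist_outputs(base_name, differentiable_outputs):
    if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
        return
    for arg in differentiable_outputs:
        if arg['type'] == 'TensorList':
            raise RuntimeError(
                f'{base_name} returns a TensorList; add it to '
                'GRADIENT_IMPLEMENTED_FOR_COMPLEX or implement a complex '
                'backward for it.'
            )

check_tensorlist_outputs('split', [{'name': 'result', 'type': 'TensorList'}])   # ok
check_tensorlist_outputs('new_op', [{'name': 'result', 'type': 'TensorList'}])  # raises
```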
| "but one of the arguments requires grad."); | ||
| } | ||
|
|
||
| inline void throw_error_for_complex_fns_backward_not_implemented(const Tensor& tensor, const char* name) { |
nit: name looks overly verbose.
Thanks, looks good. cc @robieta I'm expecting a mild increase in instruction count here for AD benchmarks; it will be more pronounced for operators on tensor lists.
I think the test failures are OK. I'm not sure why the XLA build is failing, but it doesn't look related to this PR.
@ezyang I think (almost) all the TensorList ops we have are actually in the allowlist, so the impact should be minimal.
@anjali411 merged this pull request in 415ed43.
Summary: Pull Request resolved: pytorch#45461
Test Plan: Imported from OSS
Reviewed By: malfet
Differential Revision: D23998156
Pulled By: anjali411
fbshipit-source-id: 370eb07fe56ac84dd8e2233ef7bf3a3eb8aeb179
Stack from ghstack:
This PR disables autograd for all C -> C, R -> C functions which are not included in the allowlist `GRADIENT_IMPLEMENTED_FOR_COMPLEX`. In practice, there will be a RuntimeError during forward computation when the outputs are differentiable (see the example in the description above).

The implicit assumption here is that all the C -> R functions have correct backward definitions. So before merging this PR, the following functions must be tested and verified to have correct backward definitions: `torch.abs` (updated in #39955), `torch.angle`, `torch.norm`, `torch.irfft`, `torch.istft`.

Differential Revision: D23998156