Add view_as_real, view_as_complex for complex tensors#39099
anjali411 wants to merge 38 commits into gh/anjali411/32/base
Conversation
[ghstack-poisoned]
💊 CI failures summary and remediations — as of commit 93e3cf5: 💚 Looks good so far! There are no failures yet. (Automatically generated by Dr. CI.)
[WIP] TODO: 1. add documentation 2. add tests 3. autograd code [ghstack-poisoned]
```python
# remove this test after gradcheck support is added for non-holomorphic functions
x.real.sum().backward()
self.assertEqual(x.grad, torch.ones_like(x))
```
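The fragment above can be reconstructed as a standalone runnable check; the leaf tensor and its dtype are assumptions for illustration:

```python
import torch

# Complex leaf tensor; backward through the real part only.
x = torch.randn(3, dtype=torch.complex64, requires_grad=True)
x.real.sum().backward()

# For a real-valued loss sum(Re(x)), the gradient stored in x.grad
# is expected to be all ones (1+0j per element).
print(x.grad)
```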
What about a test where you manipulate both the real and imag parts and then backward through both those results?
I don't see a case that this test covers which hasn't already been covered
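A test along the lines the reviewer suggests could look like the following sketch (the specific tensor and coefficients are assumptions, not from the PR):

```python
import torch

# Manipulate both the real and imag parts, then backward through both results.
x = torch.randn(3, dtype=torch.complex64, requires_grad=True)
out = 2 * x.real + 3 * x.imag
out.sum().backward()

# For a real-valued loss, PyTorch stores dL/dRe + i*dL/dIm in x.grad,
# so each element's gradient is expected to be 2+3j.
print(x.grad)
```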
```python
setattr(TestAutogradDeviceType, test_name, do_test)
```
```python
class TestAutogradComplex(TestCase):
```
Should this class be made device generic so it runs on CPU and CUDA?
I think autograd tests already run for both CPU and CUDA.
These tests won't run on both but that's OK for now.
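One way to make the class device generic, as discussed above, is a plain loop over available devices; this is a hypothetical sketch using only stock `unittest`, not PyTorch's internal `instantiate_device_type_tests` machinery:

```python
import unittest
import torch

class TestAutogradComplex(unittest.TestCase):
    def test_view_as_real_backward(self):
        # Run the same check on CPU and, when available, CUDA.
        devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
        for device in devices:
            x = torch.randn(3, dtype=torch.complex64, device=device,
                            requires_grad=True)
            torch.view_as_real(x).sum().backward()
            self.assertEqual(x.grad.shape, x.shape)

if __name__ == '__main__':
    unittest.main()
```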
Differential Revision: [D22057886](https://our.internmc.facebook.com/intern/diff/D22057886) [ghstack-poisoned]
Differential Revision: [D22057886](https://our.internmc.facebook.com/intern/diff/D22057886)

This PR adds the following:

1. `view_as_real` and `view_as_complex` for complex tensors.
2. Replay logic for these complex views, which change tensor metadata (in this case, the dtype): `as_strided` doesn't carry tensor metadata, so this PR adds replay logic similar to what XLA does for view functions that change metadata.

There are four possible scenarios concerning this PR:

1. A view function that uses `as_strided` followed by one that uses replay (e.g. `.transpose(...)` followed by `view_as_real`).
2. A view function that uses replay followed by one that uses `as_strided` (e.g. `view_as_real` followed by `.transpose(...)`).
3. A view function that uses replay followed by another that uses replay (e.g. `view_as_real` followed by `view_as_complex`).
4. A view function that uses `as_strided` followed by another that uses `as_strided` (this is what always happens for non-XLA tensors right now).

This covers case 1: https://github.com/pytorch/pytorch/pull/39099/files#diff-88d45e2b30fab214a6aa074fd30eb152R130
This covers case 2: https://github.com/pytorch/pytorch/pull/39099/files#diff-88d45e2b30fab214a6aa074fd30eb152R154
Case 3 is handled by existing XLA logic here: https://github.com/pytorch/pytorch/pull/39099/files#diff-88d45e2b30fab214a6aa074fd30eb152L120

[ghstack-poisoned]
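The two new views and the first two chaining scenarios described above can be sketched as follows; the shapes in the comments are for illustration:

```python
import torch

a = torch.randn(2, 3, dtype=torch.complex64)

# view_as_real reinterprets a complex tensor as a float tensor with a
# trailing dimension of size 2 (real, imag); view_as_complex inverts it.
r = torch.view_as_real(a)        # float32, shape (2, 3, 2)
c = torch.view_as_complex(r)     # complex64, shape (2, 3)

# Case 1: an as_strided-based view followed by a replayed view.
case1 = torch.view_as_real(a.transpose(0, 1))   # shape (3, 2, 2)

# Case 2: a replayed view followed by an as_strided-based view.
case2 = torch.view_as_real(a).transpose(0, 1)   # shape (3, 2, 2)
```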
mruberry
left a comment
Cool!
Just remember to update docs/source/torch.rst.
albanD
left a comment
Thanks Mike for the update.
Good to go from my side!
@anjali411 merged this pull request in 8ec2ae9.
Summary: Pull Request resolved: pytorch#39099
Test Plan: Imported from OSS
Differential Revision: D22057886
Pulled By: anjali411
fbshipit-source-id: bad5ba7097ba0dd13f2c549b2463094dee9afa14
Stack from ghstack:
Differential Revision: D22057886