[CPU] Add torch.trace for complex tensors #50380
anjali411 wants to merge 10 commits into gh/anjali411/79/base
Conversation
💊 CI failures summary and remediations (as of commit 76d05d5; more details on the Dr. CI page):
🚧 1 fixed upstream failure: these were probably caused by upstream breakages that were already fixed. Please rebase on the
// all integer types get promoted to kLong
if (result.scalar_type() == at::kLong) {
  *result.data_ptr<int64_t>() = sum;
Separated the dispatch because this cast causes a compile error for complex otherwise.
This seems like a good opportunity to use an if_constexpr to fix the problem. Test for std::is_integral<scalar_t>. Check the if_constexpr docs for how to use it.
Still gives a compile error.
Codecov Report
@@              Coverage Diff               @@
##    gh/anjali411/79/base    #50380    +/- ##
=============================================
  Coverage         80.71%    80.71%
=============================================
  Files              1904      1904
  Lines            206598    206600        +2
=============================================
+ Hits             166750    166753        +3
+ Misses            39848     39847        -1
using accscalar_t = at::acc_type<scalar_t, false>;
accscalar_t sum = 0;
const auto* t_data = self.data_ptr<scalar_t>();
if (result.scalar_type() == at::kLong) {
Hm, is this test right? What about other integral types besides Long?
Surprisingly, this is kind of the right test.

if (dtype == at::kLong) {
  ...
}

might be clearer. Add a comment here explaining that all integer types (including bool) return kLong from get_dtype(), and that this branch handles the integer types.
But why is the new branch needed? It seems like, as in the previous kernel, the only difference is the assignment to result at the end?
Yeah, it is a bit confusing with no comment, so I added one for clarity.
@mruberry I updated the PR based on your comments. Could you take a look?
Looks like CUDA supports bool and bfloat16, so those need to be added to the OpInfo's CUDA dtypes.
Differential Revision: [D25949361](https://our.internmc.facebook.com/intern/diff/D25949361)
ezyang left a comment:
Approving to unblock, but please give the if_constexpr trick a try; I think you can avoid the duplication with it.
c10::guts::if_constexpr<std::is_integral<accscalar_t>::value>(
    // all integer types get promoted to kLong
    [&] (auto _) { *result.data_ptr<int64_t>() = sum; },  // then-case, invalid for non-integral types
    [&] (auto _) { *result.data_ptr<scalar_t>() = sum; }  // else-case, invalid for integral types
The compiler eagerly type-checks the else-case even when the condition is true (and the other way round). You need to prevent it from doing so, so that only the branch that's actually called is type-checked. You can do this by routing the offending expressions through the `_` you get passed as an argument. See https://github.com/pytorch/pytorch/blob/master/c10/util/C%2B%2B17.h#L257
@anjali411 merged this pull request in e544d74.
Summary: Pull Request resolved: pytorch#50380
Test Plan: Imported from OSS
Reviewed By: ezyang
Differential Revision: D25949361
Pulled By: anjali411
fbshipit-source-id: 9910bc5b532c9bf3add530221d643b2c41c62d01
Stack from ghstack:
Differential Revision: D25949361