Add complex support for torch.mean [CUDA]#47048
RockingJavaBean wants to merge 6 commits into pytorch:master
Conversation
💊 CI failures summary and remediations (as of commit de87d22; more details on the Dr. CI page):
🕵️ 1 new failure recognized by patterns. The following CI failures do not appear to be due to upstream breakages:
Force-pushed from d618111 to b0de1be
Codecov Report
@@ Coverage Diff @@
## master #47048 +/- ##
===========================================
+ Coverage 35.95% 53.27% +17.31%
===========================================
Files 438 2747 +2309
Lines 55454 254304 +198850
===========================================
+ Hits 19939 135476 +115537
- Misses 35515 118828 +83313
@anjali411 Thanks so much for reviewing this PR.
@@ -36,26 +36,40 @@ static void std_var_kernel_cuda(TensorIterator& iter, bool unbiased, bool take_s
template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
Setting acc_t = typename c10::scalar_value_type<scalar_t>::type should resolve the issue here: c10::scalar_value_type<scalar_t>::type returns scalar_t for all non-complex dtypes and returns T for c10::complex<T>.
Thanks so much for the tip. The latest code now uses c10::scalar_value_type<acc_t>::type to get the type of the factor, so the overload functions for complex numbers are no longer needed.
Hi @RockingJavaBean, I think we shouldn't need to define overload functions for complex types after the change I suggested in my comment. This PR looks great overall, and should be ready to merge after that change!
facebook-github-bot left a comment
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
  lambda n, d: n.mean(d),
- use_integral=False)
+ use_integral=False,
+ use_complex=True)
This test doesn't run on CUDA. Can you please extend the test for mean in tensor_op_tests to also test complex dtypes?
Thanks for pointing this out; the tests for complex dtypes have been added to tensor_op_tests.
anjali411
left a comment
Let's update the CUDA test for mean to test complex dtypes as well.
@anjali411 I'm really grateful for your tip on
anjali411
left a comment
LGTM, and the Windows test failure is an upstream test failure.
facebook-github-bot left a comment
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@anjali411 Thanks so much for reviewing this PR; the CUDA tests for
@RockingJavaBean Can you please rebase?
…h_mean_complex
@anjali411 Thank you so much for the kind reminder; I have just rebased this PR onto the latest code.
facebook-github-bot left a comment
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@anjali411 merged this pull request in f90da88.
Summary: Fixes pytorch#46982
Pull Request resolved: pytorch#47048
Reviewed By: heitorschueroff
Differential Revision: D24729895
Pulled By: anjali411
fbshipit-source-id: 8e948480eb87c37de810207edf909375c0380772
Fixes #46982