Port bmm and baddbmm from TH to ATen #42553
anjali411 wants to merge 30 commits into gh/anjali411/46/base
Conversation
💊 CI failures summary and remediations: as of commit 909b366, ci.pytorch.org reported 1 failed job (more details on the Dr. CI page).
zasdfgbnm left a comment:
Not finished yet. Will post more comments later.
Ports `torch.bmm` and `torch.baddbmm` from TH to ATen and adds support for complex dtypes. Also removes dead TH code for Level 2 functions.
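For reviewers less familiar with these ops, here is a minimal pure-Python sketch of the semantics being ported (illustrative only; the actual kernels dispatch to BLAS/cuBLAS and do no Python-level looping):

```python
def bmm(batch1, batch2):
    # Batched matrix multiply: (b, n, m) x (b, m, p) -> (b, n, p),
    # computed as one independent matmul per batch entry.
    b = len(batch1)
    n, m, p = len(batch1[0]), len(batch2[0]), len(batch2[0][0])
    out = [[[0.0] * p for _ in range(n)] for _ in range(b)]
    for i in range(b):
        for r in range(n):
            for c in range(p):
                out[i][r][c] = sum(batch1[i][r][k] * batch2[i][k][c]
                                   for k in range(m))
    return out

def baddbmm(inp, batch1, batch2, beta=1.0, alpha=1.0):
    # out = beta * inp + alpha * (batch1 @ batch2), per batch entry.
    prod = bmm(batch1, batch2)
    b, n, p = len(prod), len(prod[0]), len(prod[0][0])
    return [[[beta * inp[i][r][c] + alpha * prod[i][r][c]
              for c in range(p)] for r in range(n)] for i in range(b)]
```

The complex-dtype support added by this PR means the same recurrence now also runs for `complex64`/`complex128` inputs.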
Per @gchanan's request, ports from TH to ATen should also beef up test coverage (in particular, various discontiguity patterns on inputs/outputs, and proper runtime errors for arguments on different devices).
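As a rough sketch of what a discontiguity-pattern check looks like (using NumPy here purely for illustration, not the actual PyTorch test code): a transposed view is non-contiguous, and a correct batched-matmul kernel must give the same result for it as for a contiguous copy of the same data.

```python
import numpy as np

rng = np.random.default_rng(0)
# Transposing the last two axes yields a non-contiguous view of shape (4, 3, 5).
a = rng.standard_normal((4, 5, 3)).transpose(0, 2, 1)
assert not a.flags["C_CONTIGUOUS"]
b = rng.standard_normal((4, 5, 2))

# The result on the discontiguous view must match the result on a contiguous copy.
expected = np.matmul(np.ascontiguousarray(a), b)
actual = np.matmul(a, b)
assert np.allclose(actual, expected)
```

The PyTorch tests do the analogous thing with `torch.Tensor.transpose` and strided views, and additionally check that mixing CPU and CUDA arguments raises a proper runtime error.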
@anjali411 Could you please rebase? Looks like there are lots of flaky tests.
Ports `torch.bmm` and `torch.baddbmm` from TH to ATen and adds support for complex dtypes. Also removes dead TH code for Level 2 functions. Closes #24539
```python
@skipCUDAIf(torch.version.cuda == "10.1", "flaky on CUDA 10.1")
@onlyOnCPUAndCUDA
@dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes())
@dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=AMPERE_OR_ROCM) +
```
Please don't do this. We test on all dtypes on purpose, to make sure every dtype is exercised: if a dtype is supported, it should run correctly; if it is not supported, it should raise an error.
Synced offline: the CPU bmm and baddbmm have multiple code paths; some of them support bfloat16 and float16, some don't. So depending on the input, half and bfloat16 may or may not be supported. https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/LinearAlgebra.cpp#L498
So @zasdfgbnm, @ngimel and I agreed to add full support for torch.float16 and torch.bfloat16 in a follow-up PR and leave this one as is.
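A minimal sketch of the dtype-dependent dispatch being described (hypothetical names and dtype sets; the real logic lives in `aten/src/ATen/native/LinearAlgebra.cpp`): the fast path accepts only a subset of dtypes, so whether float16/bfloat16 work depends on which path a given input selects.

```python
# Hypothetical illustration: the fast (BLAS-backed) path handles only some
# dtypes, while the fallback loop path accepts more, including half/bfloat16.
FAST_PATH_DTYPES = {"float32", "float64", "complex64", "complex128"}

def bmm_cpu_path(dtype, use_fast_path):
    if use_fast_path:
        if dtype not in FAST_PATH_DTYPES:
            raise RuntimeError(f"fast path does not support {dtype}")
        return "fast"
    return "fallback"
```

This is why a blanket `@dtypes` list would be misleading here: the same dtype can succeed or fail depending on the input-dependent path, which is what the follow-up PR is meant to resolve.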
Ports `torch.bmm` and `torch.baddbmm` from TH to ATen and adds support for complex dtypes. Also removes dead TH code for Level 2 functions. Closes #24539 Differential Revision: [D24893511](https://our.internmc.facebook.com/intern/diff/D24893511)
zasdfgbnm left a comment:
LGTM! Thanks for working on this!
Codecov Report
```
@@             Coverage Diff              @@
##   gh/anjali411/46/base   #42553   +/- ##
===========================================
  Coverage        81.22%   81.22%
===========================================
  Files             1837     1837
  Lines           198087   198087
===========================================
+ Hits            160893   160897       +4
+ Misses           37194    37190       -4
```
@anjali411 merged this pull request in e1ee3bf.
Summary: Now that #42553 is merged, we can delete a bit of code from the tests and enable some of the skipped complex tests. Unfortunately, `test_pinverse_complex_xfailed` and `test_symeig_complex_xfailed` had bugs, and it wasn't caught automatically that these tests xpass. We need to be more careful with `unittest.expectedFailure` next time. Pull Request resolved: #47910 Reviewed By: zhangguanheng66 Differential Revision: D25052130 Pulled By: mruberry fbshipit-source-id: 29512995c024b882f9cb78b7bede77733d5762d0
```cpp
/* LEVEL 3 BLAS FUNCTIONS */

#ifndef __HIP_PLATFORM_HCC__
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200
```
Is this macro `CUDA_VERSION >= 11200` intended? If you mean CUDA 11.2, it should be `11020`. I'm not sure CUDA 11.2 was a thing back in November 2020. 😅
No harm done, workaround is good.
@xwang233 No, my bad! We should fix that to avoid confusion in the future.
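To make the confusion concrete: CUDA's `cuda.h` encodes `CUDA_VERSION` as `major * 1000 + minor * 10`, so CUDA 11.2 is `11020`, whereas `11200` would mean a hypothetical CUDA 11.20. A tiny sketch of the encoding:

```python
def cuda_version_macro(major, minor):
    # CUDA's cuda.h defines CUDA_VERSION as major * 1000 + minor * 10,
    # e.g. 10.1 -> 10010, 11.2 -> 11020.
    return major * 1000 + minor * 10
```

So the guard above effectively never fires on CUDA 11.2, which is why it was harmless in practice but still worth fixing.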
…ytorch#45737) Summary: This PR updates derivatives for a few functions so that `gradgradcheck` for `torch.cholesky` passes ([ref](pytorch#45267 (comment))). Some tests (that call `bmm_cuda`) fail with `RuntimeError: _th_bmm_out not supported on CUDAType for ComplexDouble` until PR pytorch#42553 is merged. Ref. pytorch#33152 Pull Request resolved: pytorch#45737 Reviewed By: bdhirsh Differential Revision: D24279917 Pulled By: anjali411 fbshipit-source-id: 7b696d2cfc2ef714332c2e3e5d207e257be67744
Summary: This is to satisfy the request at pytorch#42553 (comment). See also pytorch#47124 Pull Request resolved: pytorch#47079 Reviewed By: ejguan Differential Revision: D24735356 Pulled By: ngimel fbshipit-source-id: 122fceb4902658f350c2fd6f92455adadd0ec2a4
Summary: Pull Request resolved: pytorch#42553 Ports `torch.bmm` and `torch.baddbmm` from TH to ATen and adds support for complex dtypes. Also removes dead TH code for Level 2 functions. Closes pytorch#24539 Test Plan: Imported from OSS Reviewed By: ansley Differential Revision: D24893511 Pulled By: anjali411 fbshipit-source-id: 0eba3f2aec99c48b3018a5264ee7789279cfab58
Stack from ghstack:
Ports `torch.bmm` and `torch.baddbmm` from TH to ATen and adds support for complex dtypes. Also removes dead TH code for Level 2 functions. Closes #24539
Differential Revision: D24893511