
Port bmm and baddbmm from TH to ATen #42553

Closed

anjali411 wants to merge 30 commits into gh/anjali411/46/base from gh/anjali411/46/head

Conversation

anjali411 (Contributor) commented Aug 4, 2020

Stack from ghstack:

Ports torch.bmm and torch.baddbmm from TH to ATen and adds support for complex dtypes. Also removes dead TH code for Level 2 functions.

Closes #24539

Differential Revision: D24893511
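
For context, a minimal usage sketch (illustrative only, not code from this PR) of the complex-dtype batched matmul this port enables; before the port, complex bmm on CUDA raised `_th_bmm_out not supported on CUDAType for ComplexDouble`:

```python
# Illustrative sketch, not code from the PR: batched matrix multiply on
# complex tensors.
import torch

a = torch.randn(10, 3, 4, dtype=torch.complex64)
b = torch.randn(10, 4, 5, dtype=torch.complex64)

c = torch.bmm(a, b)                                  # shape (10, 3, 5)

base = torch.randn(10, 3, 5, dtype=torch.complex64)
d = torch.baddbmm(base, a, b, beta=0.5, alpha=2.0)   # 0.5 * base + 2.0 * (a @ b)
```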

anjali411 added a commit that referenced this pull request Aug 4, 2020
ghstack-source-id: b737e98
Pull Request resolved: #42553
dr-ci bot commented Aug 4, 2020

💊 CI failures summary and remediations

As of commit 909b366 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

ci.pytorch.org: 1 failed



anjali411 requested a review from zasdfgbnm on August 4, 2020 19:30
anjali411 changed the title from "Port bmm and baddbmm from TH to ATen" to "[WIP] Port bmm and baddbmm from TH to ATen" on Aug 4, 2020
anjali411 added a commit that referenced this pull request Aug 5, 2020
ghstack-source-id: 3eba02a
Pull Request resolved: #42553
Comment thread aten/src/THC/THCBlas.cu Outdated
Comment thread aten/src/THC/THCBlas.cu Outdated
anjali411 requested a review from zasdfgbnm on August 5, 2020 22:26
anjali411 added a commit that referenced this pull request Aug 5, 2020
ghstack-source-id: 047dffe
Pull Request resolved: #42553
anjali411 changed the title from "[WIP] Port bmm and baddbmm from TH to ATen" to "Port bmm and baddbmm from TH to ATen" on Aug 6, 2020
zasdfgbnm (Collaborator) left a comment

Not finished yet. Will post more comments later.

Comment thread aten/src/ATen/cuda/CUDABlas.h Outdated
Comment thread aten/src/ATen/cuda/CUDABlas.cpp
Comment thread aten/src/ATen/cuda/CUDABlas.cpp Outdated
Comment thread aten/src/ATen/native/cuda/LinearAlgebra.cu Outdated
Comment thread aten/src/ATen/cuda/CUDABlas.cpp Outdated
Comment thread aten/src/ATen/native/cuda/LinearAlgebra.cu Outdated
anjali411 added a commit that referenced this pull request Aug 7, 2020
ghstack-source-id: f45d4e1
Pull Request resolved: #42553
anjali411 added a commit that referenced this pull request Aug 12, 2020
ghstack-source-id: 61c04d4
Pull Request resolved: #42553
Comment thread aten/src/ATen/cuda/CUDABlas.h Outdated
Comment thread aten/src/ATen/cuda/CUDABlas.h Outdated
Comment thread aten/src/ATen/native/cuda/LinearAlgebra.cu Outdated
Comment thread aten/src/ATen/native/cuda/LinearAlgebra.cu Outdated
Comment thread aten/src/ATen/native/cuda/LinearAlgebra.cu Outdated
Comment thread aten/src/ATen/native/cuda/LinearAlgebra.cu Outdated
Comment thread aten/src/ATen/native/cuda/LinearAlgebra.cu Outdated
Comment thread aten/src/ATen/native/cuda/LinearAlgebra.cu Outdated
anjali411 added a commit that referenced this pull request Aug 13, 2020
ghstack-source-id: b43d308
Pull Request resolved: #42553
ngimel (Collaborator) commented Aug 13, 2020

Per @gchanan's request, ports from TH to ATen should also beef up test coverage (in particular, various discontiguity patterns on inputs and outputs, and proper runtime errors for arguments on different devices).
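
A rough sketch of the kind of coverage being asked for here; the shapes, helper names, and assertion style below are made up for this example and are not the tests actually added in this PR:

```python
# Sketch only: exercise bmm with discontiguous inputs/outputs and check that
# mismatched devices raise a clear RuntimeError.
import torch

def check_bmm_noncontiguous(device="cuda"):
    a = torch.randn(4, 5, 3, device=device)
    b = torch.randn(4, 5, 6, device=device)

    # Discontiguous input: a transposed batch of matrices, compared against
    # a contiguous baseline.
    a_t = a.transpose(1, 2)                     # shape (4, 3, 5), non-contiguous
    expected = torch.bmm(a_t.contiguous(), b)   # shape (4, 3, 6)
    torch.testing.assert_allclose(torch.bmm(a_t, b), expected)

    # Discontiguous output: write into a strided slice of a larger buffer.
    out_buf = torch.empty(4, 3, 12, device=device)
    out = out_buf[:, :, ::2]                    # shape (4, 3, 6), non-contiguous
    torch.bmm(a_t, b, out=out)
    torch.testing.assert_allclose(out, expected)

def check_bmm_device_mismatch():
    # Arguments on different devices should raise a clear RuntimeError rather
    # than crash or silently compute a wrong result.
    a = torch.randn(2, 3, 4, device="cpu")
    b = torch.randn(2, 4, 5, device="cuda")
    try:
        torch.bmm(a, b)
    except RuntimeError as e:
        print("raised as expected:", e)
```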

zasdfgbnm (Collaborator) commented

@anjali411 Could you please rebase? Looks like there are lots of flaky tests.

Comment thread aten/src/ATen/native/cuda/LinearAlgebra.cu Outdated
Comment thread aten/src/ATen/cuda/CUDABlas.cpp Outdated
anjali411 added a commit that referenced this pull request Nov 11, 2020
ghstack-source-id: 3469afd
Pull Request resolved: #42553
Comment thread test/test_torch.py Outdated
Comment thread test/test_torch.py Outdated
@skipCUDAIf(torch.version.cuda == "10.1", "flaky on CUDA 10.1")
@onlyOnCPUAndCUDA
@dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes())
@dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=AMPERE_OR_ROCM) +
Collaborator:
Please don't do so. We test all dtypes on purpose: if a dtype is supported, the op should run correctly; if it is not supported, it should raise an error.

Collaborator:
cc: @ngimel

anjali411 (Contributor, Author):
Synced offline: the CPU bmm and baddbmm have multiple code paths; some of them support bfloat16 and float16, and some don't. So depending on the input, half and bfloat16 may or may not be supported. https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/LinearAlgebra.cpp#L498

So @zasdfgbnm, @ngimel, and I agreed to add full support for torch.float16 and torch.bfloat16 in a follow-up PR and leave this one as is.
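
As a hedged illustration of that testing approach (not the actual test added in this PR), one could exercise every dtype and require either a correct result or a clean RuntimeError:

```python
# Sketch only: run bmm for each dtype; supported dtypes must give the right
# answer, unsupported ones must raise rather than crash or return garbage.
import torch

def exercise_bmm_all_dtypes(device="cpu"):
    dtypes = [torch.float32, torch.float64, torch.float16, torch.bfloat16,
              torch.complex64, torch.complex128]
    for dtype in dtypes:
        a = torch.ones(2, 3, 4, device=device, dtype=dtype)
        b = torch.ones(2, 4, 5, device=device, dtype=dtype)
        try:
            result = torch.bmm(a, b)
        except RuntimeError:
            print(f"{dtype}: unsupported, raised RuntimeError as expected")
            continue
        # Each output entry is a dot product of 4 ones, i.e. exactly 4.
        expected = torch.full((2, 3, 5), 4, device=device, dtype=dtype)
        assert torch.equal(result, expected), f"wrong result for {dtype}"
        print(f"{dtype}: supported and correct")
```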

anjali411 added a commit that referenced this pull request Nov 11, 2020
ghstack-source-id: 0bbc3aa
Pull Request resolved: #42553
anjali411 requested a review from zasdfgbnm on November 11, 2020 18:57
Comment thread test/test_torch.py Outdated
Comment thread test/test_torch.py Outdated
anjali411 added a commit that referenced this pull request Nov 12, 2020
ghstack-source-id: d7ff7cd
Pull Request resolved: #42553
anjali411 added a commit that referenced this pull request Nov 12, 2020
ghstack-source-id: c52815b
Pull Request resolved: #42553
zasdfgbnm (Collaborator) left a comment
LGTM! Thanks for working on this!

codecov bot commented Nov 12, 2020

Codecov Report

Merging #42553 (909b366) into gh/anjali411/46/base (4738672) will increase coverage by 0.00%.
The diff coverage is 0.00%.

@@                  Coverage Diff                  @@
##           gh/anjali411/46/base   #42553   +/-   ##
=====================================================
  Coverage                 81.22%   81.22%           
=====================================================
  Files                      1837     1837           
  Lines                    198087   198087           
=====================================================
+ Hits                     160893   160897    +4     
+ Misses                    37194    37190    -4     

facebook-github-bot (Contributor):
@anjali411 merged this pull request in e1ee3bf.

facebook-github-bot deleted the gh/anjali411/46/head branch on November 16, 2020 15:17
facebook-github-bot pushed a commit that referenced this pull request Nov 18, 2020
Summary:
Now that #42553 is merged, we can delete a bit of code from the tests and enable some of the skipped complex tests.

Unfortunately, `test_pinverse_complex_xfailed` and `test_symeig_complex_xfailed` had bugs, and it wasn't caught automatically that these tests xpass. Need to be careful next time with `unittest.expectedFailure`.

Pull Request resolved: #47910

Reviewed By: zhangguanheng66

Differential Revision: D25052130

Pulled By: mruberry

fbshipit-source-id: 29512995c024b882f9cb78b7bede77733d5762d0
/* LEVEL 3 BLAS FUNCTIONS */

#ifndef __HIP_PLATFORM_HCC__
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200
xwang233 (Collaborator):
Is this macro CUDA_VERSION >= 11200 intended? If you mean cuda 11.2, it should be 11020. I'm not sure if cuda 11.2 was a thing back in November 2020. 😅

Collaborator:
No harm done, workaround is good.

anjali411 (Contributor, Author):
@xwang233 no, my bad! We should fix that to avoid confusion in the future.
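
For reference, cuda.h encodes CUDA_VERSION as major * 1000 + minor * 10, so the guard for CUDA 11.2 should compare against 11020; a quick sanity check (the helper below is made up for this illustration):

```python
# Illustration of the CUDA_VERSION encoding: major * 1000 + minor * 10.
def cuda_version_macro(major: int, minor: int) -> int:
    return major * 1000 + minor * 10

assert cuda_version_macro(10, 1) == 10010   # CUDA 10.1
assert cuda_version_macro(11, 2) == 11020   # CUDA 11.2, the value the #if guard should use
assert cuda_version_macro(11, 2) != 11200   # 11200 would correspond to a nonexistent "CUDA 11.20"
```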

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
…ytorch#45737)

Summary:
This PR updates derivatives for a few functions so that `gradgradcheck` for `torch.cholesky` passes ([ref](pytorch#45267 (comment))).

Some tests (that call `bmm_cuda`) fail with `RuntimeError: _th_bmm_out not supported on CUDAType for ComplexDouble`
until PR pytorch#42553 is merged.

Ref. pytorch#33152

Pull Request resolved: pytorch#45737

Reviewed By: bdhirsh

Differential Revision: D24279917

Pulled By: anjali411

fbshipit-source-id: 7b696d2cfc2ef714332c2e3e5d207e257be67744
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
This is to satisfy the request at pytorch#42553 (comment). See also pytorch#47124

Pull Request resolved: pytorch#47079

Reviewed By: ejguan

Differential Revision: D24735356

Pulled By: ngimel

fbshipit-source-id: 122fceb4902658f350c2fd6f92455adadd0ec2a4
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
Pull Request resolved: pytorch#42553

Ports `torch.bmm` and `torch.baddbmm` from TH to ATen and adds support for complex dtypes. Also removes dead TH code for Level 2 functions.

Closes pytorch#24539

Test Plan: Imported from OSS

Reviewed By: ansley

Differential Revision: D24893511

Pulled By: anjali411

fbshipit-source-id: 0eba3f2aec99c48b3018a5264ee7789279cfab58
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
…7910)

Summary:
Now that pytorch#42553 is merged, we can delete a bit of code from the tests and enable some of the skipped complex tests.

Unfortunately, `test_pinverse_complex_xfailed` and `test_symeig_complex_xfailed` had bugs, and it wasn't caught automatically that these tests xpass. Need to be careful next time with `unittest.expectedFailure`.

Pull Request resolved: pytorch#47910

Reviewed By: zhangguanheng66

Differential Revision: D25052130

Pulled By: mruberry

fbshipit-source-id: 29512995c024b882f9cb78b7bede77733d5762d0