Added CUDA support for complex input for torch.cholesky_solve#47047
Added CUDA support for complex input for torch.cholesky_solve#47047IvanYashchuk wants to merge 21 commits intopytorch:masterfrom
Conversation
for complex dtype input
4c73f09 to
d2ffbf2
Compare
💊 CI failures summary and remediationsAs of commit 09d7262 (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group. This comment has been revised 43 times. |
Codecov Report
@@ Coverage Diff @@
## master #47047 +/- ##
=======================================
Coverage 80.79% 80.79%
=======================================
Files 1865 1865
Lines 201074 201074
=======================================
+ Hits 162456 162459 +3
+ Misses 38618 38615 -3 |
|
Hi @IvanYashchuk! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but we do not have a signature on file. In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks! |
| MAGMAQueue magma_queue(self.get_device()); | ||
|
|
||
| constexpr int64_t batch_limit = 262140; | ||
| int64_t batch_limit = self.is_complex() ? 65535 : 262140; |
There was a problem hiding this comment.
why do we have different batch limit for complex and non-complex dtypes? can you link me to where this is documented?
There was a problem hiding this comment.
I don't know whether it's documented somewhere, I determined this value via experiments.
There was a problem hiding this comment.
CUDA limits kernel launches to y and z grid dimension to 65535. Maybe for non-complex dtypes batching is implemented differently allowing 262140 batches.
There was a problem hiding this comment.
synced with @ngimel offline. We should check the magma manual, and better document this difference in the batch_limit since the original comments are uninformative.
There was a problem hiding this comment.
It's not documented in magma.
CUDA limits kernel launch configurations of y and z grid dimensions to 65535.
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
| Maximum x-dimension of a grid of thread blocks | 2^31-1 |
|---|---|
| Maximum y- or z-dimension of a grid of thread blocks | 65535 |
I haven't checked the source code for how batching is done for non-complex dtypes, but apparently, complex variants use z-dimension of a grid of thread blocks for batching.
There was a problem hiding this comment.
I spent a little time looking at the complex path and didn't figure it out, but I did see this:
if ( n > 2048 ) {
#ifndef MAGMA_NOWARNING
printf("=========================================================================================\n"
" WARNING batched routines are designed for small sizes. It might be better to use the\n"
" Native/Hybrid classical routines if you want good performance.\n"
"=========================================================================================\n");
#endif
}
in magma_cpotrf_lg_batched
There was a problem hiding this comment.
Yeah we should use cusolver for those, if we don't already.
There was a problem hiding this comment.
cc @heitorschueroff, @xwang233 Can you guys please create a tracking issue which linalg functions under which conditions use magma or cusolver or cublas, and which functions still need to be weaned off magma and switched to cusolver?
There was a problem hiding this comment.
Thanks, I'll create a tracking issue.
| A = root.tril() | ||
| return torch.cholesky_solve(b, A, upper) | ||
|
|
||
| gradcheck(func, [root, b, upper]) |
There was a problem hiding this comment.
@IvanYashchuk please move the autograd tests to common_methods_invocations.py
There was a problem hiding this comment.
I think common_methods_invocations.py does not allow specifying the input function to be tested, it allows specifying only the postprocessing function.
Finite differencing doesn't work correctly for torch.cholesky_solve directly, therefore
def func(A, b, upper):
if upper:
A = A.triu()
else:
A = A.tril()
return torch.cholesky_solve(b, A, upper)is tested instead.
There was a problem hiding this comment.
I see. I also synced with @mruberry offline and we came to the conclusion it's ok to add autograd tests in test_linalg.py.
facebook-github-bot
left a comment
There was a problem hiding this comment.
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
a55cd87 to
24979e6
Compare
|
@mruberry, I think we are ready to import this PR. |
facebook-github-bot
left a comment
There was a problem hiding this comment.
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
Sorry @IvanYashchuk, looks like this picked up a merge conflict. Would you rebase? |
|
Done. |
facebook-github-bot
left a comment
There was a problem hiding this comment.
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
…h#47047) Summary: `torch.cholesky_solve` now works for complex inputs on GPU. I moved the existing tests to `test_linalg.py` and modified them to test complex and float32 dtypes. Differentiation also works correctly with complex inputs now. Ref. pytorch#33152 Pull Request resolved: pytorch#47047 Reviewed By: ngimel Differential Revision: D24730020 Pulled By: mruberry fbshipit-source-id: 95402da5789c56e5a682019790985207fa28fa1f
torch.cholesky_solvenow works for complex inputs on GPU.I moved the existing tests to
test_linalg.pyand modified them to test complex and float32 dtypes.Differentiation also works correctly with complex inputs now.
Ref. #33152