support complex types for cumsum, cumprod #39063
kshitij12345 wants to merge 2 commits into pytorch:master

Conversation

@anjali411 Please review.
💊 CI failures summary and remediations
As of commit 23ebf53 (more details on the Dr. CI page):
ci.pytorch.org: 1 failed
This comment was automatically generated by Dr. CI and has been revised 5 times.
Is this effectively just viewing complex as float? If yes, it would work for cumsum, but it won't work for cumprod as-is, since multiplication of two complex numbers is not component-wise: (a+bi)*(c+di) = (ac-bd) + (ad+bc)i
Here we are allocating a shared-memory buffer, not viewing into the actual data.
We can't declare the shared memory as c10::complex because that declaration would require explicit initialization. So instead we allocate the equivalent amount of memory in terms of the base scalar type and view that as a buffer of c10::complex.
So in tensor_kernel_scan_innermost_dim_impl we are not treating the memory as independent base-type values but as actual c10::complex values.
So cumprod won't be an issue.
Plus, all the tests pass.
Could you explain why "we can't allocate the shared memory for the c10::complex as it requires explicit initialization"?
c10::complex has a default constructor, so shouldn't it be initialized to its default value automatically?
Edit: never mind, I was thinking of something different.
I am not quite sure why we can't allocate the shared memory for c10::complex either. @kshitij12345, can you explain how this is related to the explicit-initialization requirement of c10::complex?
A minimal repro of this problem:

```cuda
struct A {
    int a = 0;
    A() = default;
};

__global__ void f(A *p) {
    __shared__ A s[5];  // error: initializer not allowed for __shared__ variable
    s[0] = *p;
}
```
Offline sync with @zasdfgbnm:
__shared__ variables cannot have an initialization as part of their declaration, as explained here: http://www.ncsa.illinois.edu/People/kindr/projects/hpca/files/singapore_p2.pdf
Updating the complex constructor to not assign real and imag would resolve the problem.
@anjali411 @zasdfgbnm Thanks. I had actually misunderstood the error.
Thanks for the reference as well.
facebook-github-bot
left a comment
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@kshitij12345 can you please rebase?
Force-pushed: ebe2c9b to 23ebf53
facebook-github-bot
left a comment
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@anjali411 merged this pull request in 10e2126.
Summary: Adds complex support to `cumsum`, `cumprod`, and relevant test updates in `test_torch::tensor_op_tests`
Pull Request resolved: pytorch#39063
Differential Revision: D21771186
Pulled By: anjali411
fbshipit-source-id: 632916d4bdbd1c0941001898ab8146be2b7884fc