Skip to content

support complex types for cumsum, cumprod#39063

Closed
kshitij12345 wants to merge 2 commits intopytorch:masterfrom
kshitij12345:support/complex/cumsum-cumprod-cuda
Closed

support complex types for cumsum, cumprod#39063
kshitij12345 wants to merge 2 commits intopytorch:masterfrom
kshitij12345:support/complex/cumsum-cumprod-cuda

Conversation

@kshitij12345
Copy link
Copy Markdown
Collaborator

Adds complex support to cumsum, cumprod and relevant test update in test_torch::tensor_op_tests

@kshitij12345
Copy link
Copy Markdown
Collaborator Author

@anjali411 Please review.

@dr-ci
Copy link
Copy Markdown

dr-ci Bot commented May 27, 2020

💊 CI failures summary and remediations

As of commit 23ebf53 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.

See how this bot performed.

This comment has been revised 5 times.

@zou3519 zou3519 requested a review from anjali411 May 27, 2020 13:40
@zou3519 zou3519 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label May 27, 2020
@anjali411 anjali411 added the module: complex Related to complex number support in PyTorch label May 27, 2020
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this effectively just viewing complex as float? If yes, it would work for cumsum, however it won't work for cumprod as is, since multiplication for two complex numbers is not defined as a dot product. (a+bi)*(c+di) = ac-bd + (ad+bc)i

Copy link
Copy Markdown
Collaborator Author

@kshitij12345 kshitij12345 May 28, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we are allocating shared memory buffer and not viewing into the actual data.

Since we can't allocate the shared memory for the c10::complex as it requires explicit initialization. So instead we allocate the equivalent memory in terms of the base scalar type and view that as a buffer for c10::complex.

So we are not actually viewing it as independent base type memory in the tensor_kernel_scan_innermost_dim_impl but as actual c10::complex.
So cumprod won't be an issue.

Plus all the tests pass.

Copy link
Copy Markdown
Collaborator

@zasdfgbnm zasdfgbnm May 28, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why we can't allocate the shared memory for the c10::complex as it requires explicit initialization.?
c10::complex has a default constructor, and shouldn't it be initialized to its default value automatically ?

Edit: never mind, I was thinking something differently

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not quite sure why we can't allocate the shared memory for the c10::complex either. @kshitij12345 can you explain how is it related to explicit initialization requirement of c10::complex?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A minimum repro of this problem could be

struct A {
    int a = 0;
    A() = default;
};

__global__ void f(A *p) {
    __shared__ A s[5];  // error: error: initializer not allowed for __shared__ variable 
    s[0] = *p;
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

offline sync with @zasdfgbnm :
__shared__ variables cannot have an initialization as part of their declaration as explained here http://www.ncsa.illinois.edu/People/kindr/projects/hpca/files/singapore_p2.pdf

updating the complex constructor to not assign real and imag in constructor would resolve the problem.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@anjali411 @zasdfgbnm Thanks. I was actually misguided in understanding the error.

Thanks for the reference as well.

@anjali411 anjali411 requested a review from zasdfgbnm May 28, 2020 19:18
Copy link
Copy Markdown
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@anjali411
Copy link
Copy Markdown
Contributor

@kshitij12345 can you please rebase?

@kshitij12345 kshitij12345 force-pushed the support/complex/cumsum-cumprod-cuda branch from ebe2c9b to 23ebf53 Compare May 29, 2020 05:34
Copy link
Copy Markdown
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@kshitij12345 kshitij12345 deleted the support/complex/cumsum-cumprod-cuda branch May 29, 2020 16:53
@facebook-github-bot
Copy link
Copy Markdown
Contributor

@anjali411 merged this pull request in 10e2126.

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
Adds complex support to `cumsum`, `cumprod` and relevant test update in `test_torch::tensor_op_tests`
Pull Request resolved: pytorch#39063

Differential Revision: D21771186

Pulled By: anjali411

fbshipit-source-id: 632916d4bdbd1c0941001898ab8146be2b7884fc
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged module: complex Related to complex number support in PyTorch open source triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants