
Add complex support for torch.sum #38382

Closed
zasdfgbnm wants to merge 23 commits into gh/zasdfgbnm/81/base from gh/zasdfgbnm/81/head

Conversation

@zasdfgbnm (Collaborator) commented May 13, 2020

Stack from ghstack:

Differential Revision: D21600127

zasdfgbnm added a commit that referenced this pull request May 13, 2020
ghstack-source-id: 48a68fd
Pull Request resolved: #38382
Comment thread (code context):

#endif
}

__device__ __forceinline__ unsigned int ACTIVE_MASK()
@zasdfgbnm (Collaborator, Author) commented May 13, 2020

Moved from THC to ATen, with support for old CUDA version removed.
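With pre-CUDA-9 toolkit support dropped as part of the THC-to-ATen move, a helper like ACTIVE_MASK can presumably reduce to a thin wrapper over the intrinsic; a hedged sketch (not the verbatim ATen source):

```cuda
// Sketch only: with old-CUDA support removed, the helper can defer
// directly to __activemask() (CUDA >= 9), which returns a bitmask of
// the currently converged threads in the calling warp.
__device__ __forceinline__ unsigned int ACTIVE_MASK() {
  return __activemask();
}
```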

Comment thread (code context):

}

template <typename T>
__device__ __forceinline__ c10::complex<T> WARP_SHFL_DOWN(c10::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
@zasdfgbnm (Collaborator, Author) commented

This is the key to making it work for c10::complex.
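Since __shfl_down_sync only operates on scalar types, the c10::complex<T> overload presumably shuffles the real and imaginary parts separately and recombines them on the receiving lane; a sketch under that assumption:

```cuda
// Sketch: a complex warp shuffle built from two scalar shuffles.
// Each component is shuffled down by `delta` lanes independently,
// then the pair is reassembled into a c10::complex on the target lane.
template <typename T>
__device__ __forceinline__ c10::complex<T> WARP_SHFL_DOWN(
    c10::complex<T> value, unsigned int delta,
    int width = warpSize, unsigned int mask = 0xffffffff) {
  return c10::complex<T>(
      __shfl_down_sync(mask, value.real(), delta, width),
      __shfl_down_sync(mask, value.imag(), delta, width));
}
```

This lets the existing warp-level reduction templates instantiate for complex dtypes without any change to the reduction logic itself.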

@zasdfgbnm zasdfgbnm added the module: complex Related to complex number support in PyTorch label May 13, 2020
@zasdfgbnm zasdfgbnm requested a review from anjali411 May 13, 2020 06:43
@dr-ci bot commented May 13, 2020

💊 CI failures summary and remediations

As of commit d9f647f (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



zasdfgbnm added a commit that referenced this pull request May 13, 2020
ghstack-source-id: 46df8f4
Pull Request resolved: #38382
@anjali411 (Contributor) left a comment

Can you also add a test for backward here by adding sum to the whitelist?

zasdfgbnm added a commit that referenced this pull request May 13, 2020
ghstack-source-id: 8cbb803
Pull Request resolved: #38382
@zasdfgbnm (Collaborator, Author) commented

I added sum to the whitelist. Also, it looks like your test in #37959 is more complete. So, after this gets merged, we could rebase #37959 so that this PR will be tested by #37959?

@anjali411 (Contributor) commented May 13, 2020

> I added sum to the whitelist. Also, it looks like your test in #37959 is more complete. So, after this gets merged, we could rebase #37959 so that this PR will be tested by #37959?

I think we should add the sum-related tests in this PR so that it's fully tested before merge:

('sum', 'complex', _small_2d, lambda t, d: [], 1e-2, 1e-2, 1e-5, _complex_types, _cpu_types, False),
('sum', 'complex_dim', _small_3d, lambda t, d: [1], 1e-2, 1e-2, 1e-5, _complex_types, _cpu_types, False),
('sum', 'complex_neg_dim', _small_3d, lambda t, d: [-1], 1e-2, 1e-5, 1e-5, _complex_types, _cpu_types, False),
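For reference, the behavior these entries exercise is just elementwise complex accumulation along a dimension; a minimal plain-Python sketch (not torch, purely illustrative) of what summing complex values row-wise means:

```python
# Plain-Python sketch of complex summation semantics (no torch involved):
# summing complex numbers accumulates real and imaginary parts independently.

def complex_sum(values):
    """Sum an iterable of complex numbers."""
    total = complex(0, 0)
    for v in values:
        total += v
    return total

# Summing along the last dimension of a 2x3 "tensor" reduces each row
# to a single complex value.
data = [
    [1 + 2j, 3 - 1j, -2 + 0.5j],
    [0 + 1j, 1 + 1j,  2 - 3j],
]
row_sums = [complex_sum(row) for row in data]
print(row_sums)  # [(2+1.5j), (3-1j)]
```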

zasdfgbnm added a commit that referenced this pull request May 13, 2020
ghstack-source-id: ceb4ea4
Pull Request resolved: #38382
zasdfgbnm added a commit that referenced this pull request May 14, 2020
ghstack-source-id: e3a4979
Pull Request resolved: #38382
zasdfgbnm added a commit that referenced this pull request May 15, 2020
ghstack-source-id: fc8bb5e
Pull Request resolved: #38382
zasdfgbnm added a commit that referenced this pull request May 15, 2020
ghstack-source-id: 0b21bdd
Pull Request resolved: #38382
@zasdfgbnm (Collaborator, Author) commented
@pytorchbot retest this please

zasdfgbnm added a commit that referenced this pull request May 15, 2020
ghstack-source-id: b7a7413
Pull Request resolved: #38382
@anjali411 (Contributor) commented

Great job fixing the ROCm bugs @zasdfgbnm :D

zasdfgbnm added a commit that referenced this pull request May 15, 2020
ghstack-source-id: 7134f1f
Pull Request resolved: #38382
Comment thread on test/test_torch.py (code context):

_float_types_no_half = [torch.float, torch.double]

_complex_types = [torch.cfloat, torch.cdouble]
@anjali411 (Contributor) commented

Ha! Well, I added _complex_types here: https://github.com/pytorch/pytorch/pull/38400/files#diff-9996665f82f52030836eb8657057cfadR17295, but I don't want to block this PR because of this. Let's remove it in the subsequent PR.

@zasdfgbnm zasdfgbnm deleted the gh/zasdfgbnm/81/head branch May 16, 2020 02:52
@facebook-github-bot (Contributor) commented
@anjali411 merged this pull request in 83df3be.

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary: Pull Request resolved: pytorch#38382

Test Plan: Imported from OSS

Differential Revision: D21600127

Pulled By: anjali411

fbshipit-source-id: c5338ab10bdcebe4a281b03f78e6f2063186bc32

Labels: Merged, module: complex (Related to complex number support in PyTorch), open source


5 participants