Add complex support for torch.sum#38382
Add complex support for torch.sum#38382zasdfgbnm wants to merge 23 commits intogh/zasdfgbnm/81/basefrom
Conversation
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
| #endif | ||
| } | ||
|
|
||
| __device__ __forceinline__ unsigned int ACTIVE_MASK() |
There was a problem hiding this comment.
Moved from THC to ATen, with support for old CUDA version removed.
| } | ||
|
|
||
| template <typename T> | ||
| __device__ __forceinline__ c10::complex<T> WARP_SHFL_DOWN(c10::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff) |
There was a problem hiding this comment.
This is the key to making it work for c10::complex.
💊 CI failures summary and remediationsAs of commit d9f647f (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions on the GitHub issue tracker. This comment has been revised 61 times. |
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
I think we should add the sum related tests in this PR so that it's fully tested before merge: |
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
|
@pytorchbot retest this please |
[ghstack-poisoned]
|
great job fixing the rocm bugs @zasdfgbnm :D |
[ghstack-poisoned]
|
|
||
| _float_types_no_half = [torch.float, torch.double] | ||
|
|
||
| _complex_types = [torch.cfloat, torch.cdouble] |
There was a problem hiding this comment.
ha! well I added _complex types here https://github.com/pytorch/pytorch/pull/38400/files#diff-9996665f82f52030836eb8657057cfadR17295 but I don't wanna block this PR because of this. let's remove it in the subsequent PR
|
@anjali411 merged this pull request in 83df3be. |
Summary: Pull Request resolved: pytorch#38382 Test Plan: Imported from OSS Differential Revision: D21600127 Pulled By: anjali411 fbshipit-source-id: c5338ab10bdcebe4a281b03f78e6f2063186bc32
Stack from ghstack:
Differential Revision: D21600127