sum and roll on cuda for complex dtypes#37959
sum and roll on cuda for complex dtypes#37959anjali411 wants to merge 3 commits intogh/anjali411/15/basefrom
Conversation
[ghstack-poisoned]
💊 CI failures summary and remediationsAs of commit 97868ce (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patternsThe following CI failures do not appear to be due to upstream breakages:
|
[ghstack-poisoned]
[ghstack-poisoned]
| AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, in_tensor.scalar_type(), "roll_cuda", [&] { | ||
| AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, | ||
| in_tensor.scalar_type(), "roll_cuda", [&] { | ||
| using value_t = typename ztype<scalar_t>::value_t; |
There was a problem hiding this comment.
Why do we need a value_t here? CPU's ztype<scalar_t>::value_t is a noop for c10::complex
There was a problem hiding this comment.
yeah forgot to remove it after replacing AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3 with AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3
| AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, iter.dtype(), "sum_cuda", [&]() { | ||
| sum_kernel_impl<scalar_t>(iter); | ||
| AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(ScalarType::Bool, iter.dtype(), "sum_cuda", [&]() { | ||
| using value_t = typename ztype<scalar_t>::value_t; |
There was a problem hiding this comment.
I guess AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND and remove the ztype will just work.
There was a problem hiding this comment.
hmm there was an issue with __shfl_up_sync for c10::complex. I'll look into it more
| auto total_dims = in_tensor.dim(); | ||
|
|
||
| AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, in_tensor.scalar_type(), "roll_cuda", [&] { | ||
| AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, |
There was a problem hiding this comment.
This will conflict with #37977, whichever lands first, the other needs change.
|
Test failure looks real. |
|
|
Stack from ghstack:
Resolves #37925