Skip to content

Commit 0c936f9

Browse files
Mike Ruberry and facebook-github-bot
authored and committed
Revert D21449612: [pytorch][PR] Migrate AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3 to c10::complex
Test Plan: revert-hammer Differential Revision: D21449612 Original commit changeset: 236070946b9d fbshipit-source-id: 2de485ca18388a055f44d6caf18cf516b2288875
1 parent 0f60c8d commit 0c936f9

18 files changed

Lines changed: 74 additions & 30 deletions

aten/src/ATen/Dispatch.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,38 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
580580

581581
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
582582
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
583+
[&] { \
584+
switch (TYPE) { \
585+
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
586+
AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
587+
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
588+
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
589+
AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
590+
AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
591+
AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
592+
AT_PRIVATE_CASE_TYPE( \
593+
at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
594+
AT_PRIVATE_CASE_TYPE( \
595+
at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
596+
AT_PRIVATE_CASE_TYPE( \
597+
SCALARTYPE1, \
598+
decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE1>::t), \
599+
__VA_ARGS__) \
600+
AT_PRIVATE_CASE_TYPE( \
601+
SCALARTYPE2, \
602+
decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE2>::t), \
603+
__VA_ARGS__) \
604+
AT_PRIVATE_CASE_TYPE( \
605+
SCALARTYPE3, \
606+
decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE3>::t), \
607+
__VA_ARGS__) \
608+
default: \
609+
AT_ERROR(#NAME, " not implemented for '", TYPE, "'"); \
610+
} \
611+
}()
612+
613+
#define AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3( \
614+
SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
583615
[&] { \
584616
switch (TYPE) { \
585617
AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \

aten/src/ATen/native/BinaryOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ static Tensor wrapped_scalar_tensor(Scalar scalar) {
270270

271271
static void check_convert(Scalar scalar, ScalarType scalarType) {
272272
// Validate that is possible to convert scalar to tensor dtype without overflow
273-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, scalarType, "check_convert", [&]{
273+
AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, scalarType, "check_convert", [&]{
274274
scalar.to<scalar_t>();
275275
});
276276
}

aten/src/ATen/native/Scalar.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Scalar item(const Tensor& self) {
2020

2121
Scalar _local_scalar_dense_cpu(const Tensor& self) {
2222
Scalar r;
23-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
23+
AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3(
2424
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_cpu", [&] {
2525
scalar_t value = *self.data_ptr<scalar_t>();
2626
r = Scalar(value);

aten/src/ATen/native/cpu/BinaryOpsKernel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ void ge_kernel(TensorIterator& iter) {
386386

387387
void eq_kernel(TensorIterator& iter) {
388388
if (iter.dtype() == ScalarType::Bool) {
389-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "eq_cpu", [&]() {
389+
AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "eq_cpu", [&]() {
390390
cpu_kernel(iter,
391391
[](scalar_t a, scalar_t b) -> bool {
392392
return a == b;
@@ -408,7 +408,7 @@ void eq_kernel(TensorIterator& iter) {
408408

409409
void ne_kernel(TensorIterator& iter) {
410410
if (iter.dtype() == ScalarType::Bool) {
411-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "ne_cpu", [&]() {
411+
AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "ne_cpu", [&]() {
412412
cpu_kernel(iter,
413413
[](scalar_t a, scalar_t b) -> bool {
414414
return a != b;

aten/src/ATen/native/cpu/CopyKernel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ static void copy_kernel(TensorIterator& iter, bool non_blocking) {
4040
});
4141
}
4242
} else {
43-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, dtype, "copy_", [&] {
43+
AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, dtype, "copy_", [&] {
4444
using dest_t = scalar_t;
45-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, iter.dtype(1), "copy_", [&] {
45+
AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, iter.dtype(1), "copy_", [&] {
4646
// Note (@zasdfgbnm):
4747
//
4848
// The code below can not be simplified as

aten/src/ATen/native/cpu/IndexKernel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ void cpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
9898
}
9999

100100
void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
101-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
101+
AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
102102
iter.dtype(), "index_cpu", [&] {
103103
cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
104104
*(scalar_t*)dst = *(scalar_t*)(src + offset);
@@ -108,7 +108,7 @@ void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef inde
108108

109109
void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
110110
// NOTE: duplicate indices are only supported if accumulate is true.
111-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
111+
AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
112112
iter.dtype(), "index_put", [&] {
113113
if (accumulate) {
114114
bool use_parallel_for = ((iter.numel() >= internal::GRAIN_SIZE) && (at::get_num_threads() > 1));

aten/src/ATen/native/cpu/ReduceOpsKernel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ static void cumprod_cpu_kernel(Tensor& result, const Tensor& self, int64_t dim)
101101
}
102102

103103
static void sum_kernel_impl(TensorIterator& iter) {
104-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
104+
AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3(
105105
ScalarType::BFloat16, ScalarType::Half, ScalarType::Bool, iter.dtype(), "sum_cpu", [&] {
106106
binary_kernel_reduce_vec(
107107
iter, [=](scalar_t a, scalar_t b) -> scalar_t { return a + b; },

aten/src/ATen/native/cuda/BinaryArithmeticKernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
namespace at { namespace native {
1414

1515
void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
16-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() {
16+
AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3(kHalf, kBool, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() {
1717
auto alpha = alpha_scalar.to<scalar_t>();
1818
gpu_kernel_with_scalars(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
1919
return a + alpha * b;

aten/src/ATen/native/cuda/CUDAScalar.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@ namespace native {
1414
Scalar _local_scalar_dense_cuda(const Tensor& self) {
1515
Scalar r;
1616
#if HIP_VERSION >= 301
17-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
17+
AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3(
1818
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_cuda", [&] {
1919
scalar_t value;
2020
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
2121
AT_CUDA_CHECK(hipMemcpyWithStream(&value, self.data_ptr<scalar_t>(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
2222
r = Scalar(value);
2323
});
2424
#else
25-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
25+
AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3(
2626
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_cuda", [&] {
2727
scalar_t value;
2828
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

aten/src/ATen/native/cuda/CompareEQKernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
namespace at { namespace native {
1212

1313
void eq_kernel_cuda(TensorIterator& iter) {
14-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "eq_cuda", [&]() {
14+
AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "eq_cuda", [&]() {
1515
AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "eq_cuda", [&] {
1616
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
1717
return a == b;

0 commit comments

Comments (0)