|
4 | 4 | #include <ATen/native/cuda/zmath.cuh> |
5 | 5 | #include <ATen/native/TensorIterator.h> |
6 | 6 | #include <ATen/native/BinaryOps.h> |
7 | | -#include <c10/macros/Macros.h> |
8 | | - |
9 | 7 |
|
10 | 8 | // NOTE: CUDA on Windows requires that the enclosing function |
11 | 9 | // of a __device__ lambda not have internal linkage. |
12 | 10 |
|
13 | 11 | namespace at { namespace native { |
14 | 12 |
|
// Elementwise out = a + alpha * b over all dtypes (including bool, half,
// bfloat16, and complex). The dispatch name covers subtraction too, because
// sub_kernel_cuda reuses this kernel with a negated alpha.
void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      kHalf, kBool, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() {
        // Convert alpha once on the host; the device lambda captures it by value.
        const auto alpha = alpha_scalar.to<scalar_t>();
        gpu_kernel_with_scalars(
            iter, [alpha] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
              return a + alpha * b;
            });
      });
}
23 | | - |
// Elementwise subtraction, implemented via the add kernel:
// a - alpha * b == a + (-alpha) * b.
static void sub_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
  const Scalar negated_alpha = -alpha_scalar;
  add_kernel_cuda(iter, negated_alpha);
}
27 | | - |
28 | 13 | void div_kernel_cuda(TensorIterator& iter) { |
29 | 14 | if (!isIntegralType(iter.common_dtype(), /*includeBool*/ false) && iter.is_cpu_scalar(2)) { |
30 | 15 | // optimization for floating-point types: if the second operand is a CPU |
@@ -62,33 +47,7 @@ void mul_kernel_cuda(TensorIterator& iter) { |
62 | 47 | } |
63 | 48 | } |
64 | 49 |
|
// Elementwise remainder with the divisor's sign (Python semantics) for both
// integral and floating-point dtypes.
void remainder_kernel_cuda(TensorIterator& iter) {
  if (!isIntegralType(iter.dtype(), /*includeBool*/ false)) {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "remainder_cuda", [&]() {
      gpu_kernel_with_scalars(
          iter,
          [] GPU_LAMBDA(scalar_t a, scalar_t b)
              __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
            auto mod = ::fmod(a, b);
            // fmod takes the dividend's sign; when the nonzero result
            // disagrees with the divisor's sign, shift it by b.
            if ((mod != 0) && ((b < 0) != (mod < 0))) {
              mod += b;
            }
            return mod;
          });
    });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "remainder_cuda", [&]() {
      gpu_kernel_with_scalars(
          iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
            scalar_t rem = a % b;
            // C++ % truncates toward zero; adjust nonzero results so the
            // remainder carries the divisor's sign.
            if ((rem != 0) && ((rem < 0) != (b < 0))) {
              rem += b;
            }
            return rem;
          });
    });
  }
}
87 | | - |
// Register the CUDA implementations with the device-generic dispatch stubs
// declared in ATen/native/BinaryOps.h.
REGISTER_DISPATCH(add_stub, &add_kernel_cuda);
REGISTER_DISPATCH(sub_stub, &sub_kernel_cuda);
REGISTER_DISPATCH(div_stub, &div_kernel_cuda);
REGISTER_DISPATCH(mul_stub, &mul_kernel_cuda);
REGISTER_DISPATCH(remainder_stub, &remainder_kernel_cuda);
93 | 52 |
|
94 | 53 | }} // namespace at::native |
0 commit comments