Skip to content

Commit 899a075

Browse files
gchanan authored and facebook-github-bot committed
Split up BinaryArithmeticKernel.cu to speed up compilation time. (#38263)
Summary: Pull Request resolved: #38263 On my machine, total compilation time went from 4m8s to 2m22s (the longest-compiling of the split files). Test Plan: Imported from OSS Differential Revision: D21508985 Pulled By: gchanan fbshipit-source-id: 2917cd5f30c6b31229053cada93c95e3a27ab29a
1 parent d86de91 commit 899a075

5 files changed

Lines changed: 66 additions & 43 deletions

File tree

aten/src/ATen/native/TensorIterator.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <ATen/ExpandUtils.h>
55
#include <ATen/Parallel.h>
66
#include <ATen/native/TypeProperties.h>
7+
#include <ATen/MemoryOverlap.h>
78

89
namespace at {
910

aten/src/ATen/native/TensorIterator.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
#include <c10/util/TypeCast.h>
77
#include <ATen/core/Range.h>
88
#include <bitset>
9-
#include <c10/util/Optional.h>
10-
#include <ATen/MemoryOverlap.h>
119
#include <ATen/NamedTensorUtils.h>
1210
#include <ATen/Parallel.h>
1311

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include <ATen/Dispatch.h>
2+
#include <ATen/native/DispatchStub.h>
3+
#include <ATen/native/cuda/Loops.cuh>
4+
#include <ATen/native/BinaryOps.h>
5+
6+
// NOTE: CUDA on Windows requires that the enclosing function
7+
// of a __device__ lambda not have internal linkage.
8+
9+
namespace at { namespace native {
10+
11+
void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
12+
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() {
13+
auto alpha = alpha_scalar.to<scalar_t>();
14+
gpu_kernel_with_scalars(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
15+
return a + alpha * b;
16+
});
17+
});
18+
}
19+
20+
static void sub_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
21+
add_kernel_cuda(iter, -alpha_scalar);
22+
}
23+
24+
REGISTER_DISPATCH(add_stub, &add_kernel_cuda);
25+
REGISTER_DISPATCH(sub_stub, &sub_kernel_cuda);
26+
27+
}} // namespace at::native

aten/src/ATen/native/cuda/BinaryArithmeticKernel.cu renamed to aten/src/ATen/native/cuda/BinaryMulDivKernel.cu

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,12 @@
44
#include <ATen/native/cuda/zmath.cuh>
55
#include <ATen/native/TensorIterator.h>
66
#include <ATen/native/BinaryOps.h>
7-
#include <c10/macros/Macros.h>
8-
97

108
// NOTE: CUDA on Windows requires that the enclosing function
119
// of a __device__ lambda not have internal linkage.
1210

1311
namespace at { namespace native {
1412

15-
void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
16-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, iter.common_dtype(), "add_cuda/sub_cuda", [&]() {
17-
auto alpha = alpha_scalar.to<scalar_t>();
18-
gpu_kernel_with_scalars(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
19-
return a + alpha * b;
20-
});
21-
});
22-
}
23-
24-
static void sub_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
25-
add_kernel_cuda(iter, -alpha_scalar);
26-
}
27-
2813
void div_kernel_cuda(TensorIterator& iter) {
2914
if (!isIntegralType(iter.common_dtype(), /*includeBool*/ false) && iter.is_cpu_scalar(2)) {
3015
// optimization for floating-point types: if the second operand is a CPU
@@ -62,33 +47,7 @@ void mul_kernel_cuda(TensorIterator& iter) {
6247
}
6348
}
6449

65-
void remainder_kernel_cuda(TensorIterator& iter) {
66-
if (isIntegralType(iter.dtype(), /*includeBool*/ false)) {
67-
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "remainder_cuda", [&]() {
68-
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
69-
scalar_t r = a % b;
70-
if ((r != 0) && ((r < 0) != (b < 0))) {
71-
r += b;
72-
}
73-
return r;
74-
});
75-
});
76-
} else {
77-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "remainder_cuda", [&]() {
78-
gpu_kernel_with_scalars(iter,
79-
[]GPU_LAMBDA(scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
80-
auto mod = ::fmod(a, b);
81-
if ((mod != 0) && ((b < 0) != (mod < 0))) mod += b;
82-
return mod;
83-
});
84-
});
85-
}
86-
}
87-
88-
REGISTER_DISPATCH(add_stub, &add_kernel_cuda);
89-
REGISTER_DISPATCH(sub_stub, &sub_kernel_cuda);
9050
REGISTER_DISPATCH(div_stub, &div_kernel_cuda);
9151
REGISTER_DISPATCH(mul_stub, &mul_kernel_cuda);
92-
REGISTER_DISPATCH(remainder_stub, &remainder_kernel_cuda);
9352

9453
}} // namespace at::native
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#include <ATen/Dispatch.h>
2+
#include <ATen/native/DispatchStub.h>
3+
#include <ATen/native/cuda/Loops.cuh>
4+
#include <ATen/native/cuda/zmath.cuh>
5+
#include <ATen/native/TensorIterator.h>
6+
#include <ATen/native/BinaryOps.h>
7+
8+
// NOTE: CUDA on Windows requires that the enclosing function
9+
// of a __device__ lambda not have internal linkage.
10+
11+
namespace at { namespace native {
12+
13+
void remainder_kernel_cuda(TensorIterator& iter) {
14+
if (isIntegralType(iter.dtype(), /*includeBool*/ false)) {
15+
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "remainder_cuda", [&]() {
16+
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
17+
scalar_t r = a % b;
18+
if ((r != 0) && ((r < 0) != (b < 0))) {
19+
r += b;
20+
}
21+
return r;
22+
});
23+
});
24+
} else {
25+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "remainder_cuda", [&]() {
26+
gpu_kernel_with_scalars(iter,
27+
[]GPU_LAMBDA(scalar_t a, scalar_t b) __ubsan_ignore_float_divide_by_zero__ -> scalar_t {
28+
auto mod = ::fmod(a, b);
29+
if ((mod != 0) && ((b < 0) != (mod < 0))) mod += b;
30+
return mod;
31+
});
32+
});
33+
}
34+
}
35+
36+
REGISTER_DISPATCH(remainder_stub, &remainder_kernel_cuda);
37+
38+
}} // namespace at::native

0 commit comments

Comments
 (0)