Commit ead6fe4

malfet authored and csarofeen committed
Use explicit templates in gpu_kernel_with_scalars (pytorch#40992)
Summary: This trick should have no effect on performance, but it reduces the size of kernels that use the template by 10%. For example, sizeof(BinaryMulDivKernel.cu.o) compiled by the CUDA-10.1 toolchain for sm_75 was 4.2Mb before the change and 3.8Mb after.

Pull Request resolved: pytorch#40992
Differential Revision: D22398733
Pulled By: malfet
fbshipit-source-id: 6576f4da00dc5fc2575b2313577f52c6571d5e6f
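The trick is to replace the capturing GPU lambdas inside gpu_kernel_with_scalars with the named functor templates AUnaryFunctor and BUnaryFunctor shown in the diff below; each binds one operand of a binary op and exposes the result as a unary op. For readers unfamiliar with the pattern, here is a minimal host-side sketch of that partial application. BindFirst and its parameters are hypothetical names for this note only; the real functors additionally mark operator() as __device__ and derive their argument types via function_traits rather than taking them as template parameters.

// build: g++ -std=c++14 bind_first.cpp && ./a.out
#include <cstdio>

// Hypothetical sketch of the commit's pattern: a named functor that binds
// the first argument of a binary callable, yielding a unary callable.
template <typename func_t, typename arg1_t>
struct BindFirst {
  BindFirst(func_t f_, arg1_t a_) : f(f_), a(a_) {}
  // Call the wrapped binary op with the bound first argument.
  template <typename arg2_t>
  auto operator()(arg2_t b) const { return f(a, b); }
 private:
  func_t f;
  arg1_t a;
};

int main() {
  auto mul = [](float x, float y) { return x * y; };
  // Bind the scalar 0.5f, leaving a unary "halve" op.
  BindFirst<decltype(mul), float> half(mul, 0.5f);
  std::printf("%f\n", half(8.0f));  // prints 4.000000
  return 0;
}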
1 parent 2a04414 commit ead6fe4

1 file changed: aten/src/ATen/native/cuda/Loops.cuh (36 additions & 13 deletions)
@@ -55,7 +55,6 @@ namespace at { namespace native {
 
 template <typename func_t>
 void gpu_kernel(TensorIterator& iter, const func_t& f) {
-  ASSERT_HOST_DEVICE_LAMBDA(func_t);
 
   for (int arg = 0; arg < iter.ntensors(); arg++) {
     TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
@@ -75,6 +74,36 @@ void gpu_kernel(TensorIterator& iter, const func_t& f) {
   gpu_kernel_impl(iter, f);
 }
 
+template<typename func_t>
+struct AUnaryFunctor {
+  using traits = function_traits<func_t>;
+  using arg1_t = typename traits::template arg<0>::type;
+  using arg2_t = typename traits::template arg<1>::type;
+  using return_t = typename traits::result_type;
+  __device__ return_t operator()(arg2_t b) const {
+    return f(a, b);
+  }
+  AUnaryFunctor(func_t f_, arg1_t a_): f(f_), a(a_) {}
+  private:
+    func_t f;
+    arg1_t a;
+};
+
+template<typename func_t>
+struct BUnaryFunctor {
+  using traits = function_traits<func_t>;
+  using arg1_t = typename traits::template arg<0>::type;
+  using arg2_t = typename traits::template arg<1>::type;
+  using return_t = typename traits::result_type;
+  __device__ return_t operator()(arg1_t a) const {
+    return f(a, b);
+  }
+  BUnaryFunctor(func_t f_, arg2_t b_): f(f_), b(b_) {}
+  private:
+    func_t f;
+    arg2_t b;
+};
+
 template <typename func_t>
 void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
   ASSERT_HOST_DEVICE_LAMBDA(func_t);
@@ -85,23 +114,17 @@ void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
     traits::arity == 2,
     "gpu_kernel_with_scalars only supports two input arguments");
 
+  using arg1_t = typename traits::template arg<0>::type;
+  using arg2_t = typename traits::template arg<1>::type;
   if (iter.is_cpu_scalar(1)) {
-    using arg1_t = typename traits::template arg<0>::type;
-    using arg2_t = typename traits::template arg<1>::type;
-    auto a = iter.scalar_value<arg1_t>(1);
+    AUnaryFunctor<func_t> af(f, iter.scalar_value<arg1_t>(1));
     iter.remove_operand(1);
     const OptionalDeviceGuard device_guard(device_of(iter.tensor(1)));
-    gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
-      return f(a, b);
-    });
+    gpu_kernel(iter, af);
   } else if (iter.is_cpu_scalar(2)) {
-    using arg1_t = typename traits::template arg<0>::type;
-    using arg2_t = typename traits::template arg<1>::type;
-    auto b = iter.scalar_value<arg2_t>(2);
+    BUnaryFunctor<func_t> bf(f, iter.scalar_value<arg2_t>(2));
     iter.remove_operand(2);
-    gpu_kernel(iter, [=]GPU_LAMBDA(arg1_t a) {
-      return f(a, b);
-    });
+    gpu_kernel(iter, bf);
   } else {
     gpu_kernel(iter, f);
   }
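Both new functors recover their argument and result types through function_traits. That facility is defined elsewhere in ATen; as a reference for what the diff assumes, here is a hypothetical minimal reimplementation providing the three members used above: arity, result_type, and arg<N>::type.

// build: g++ -std=c++14 function_traits_sketch.cpp
#include <cstddef>
#include <tuple>
#include <type_traits>

// Hypothetical minimal stand-in for ATen's function_traits; it unwraps a
// functor or lambda via the type of its call operator.
template <typename T>
struct function_traits : function_traits<decltype(&T::operator())> {};

// Matches the const call operator of lambdas and const functors.
template <typename C, typename R, typename... Args>
struct function_traits<R (C::*)(Args...) const> {
  static constexpr std::size_t arity = sizeof...(Args);
  using result_type = R;
  template <std::size_t N>
  struct arg {
    using type = typename std::tuple_element<N, std::tuple<Args...>>::type;
  };
};

int main() {
  auto mul = [](float x, float y) { return x * y; };
  using traits = function_traits<decltype(mul)>;
  static_assert(traits::arity == 2, "binary op, as gpu_kernel_with_scalars requires");
  static_assert(std::is_same<traits::arg<0>::type, float>::value, "first arg type");
  static_assert(std::is_same<traits::result_type, float>::value, "result type");
  return 0;
}

A note on the design: two functors are needed because the CPU scalar may sit in either operand slot. AUnaryFunctor binds the first argument and accepts arg2_t, while BUnaryFunctor binds the second and accepts arg1_t, so each branch of gpu_kernel_with_scalars can hand gpu_kernel a plain unary functor.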
