@@ -66,6 +66,27 @@ static constexpr int launch_bound2 = 4;
 
 namespace at { namespace native {
 
+template <int N>
+static OffsetCalculator<N> make_input_offset_calculator(const TensorIterator& iter) {
+  // the array size cannot be 0; this happens when N == 0
+  constexpr int array_size = std::max<int>(N, 1);
+  TORCH_INTERNAL_ASSERT(N == iter.ntensors() - 1);
+  std::array<const int64_t*, array_size> strides;
+  int64_t element_sizes[array_size];
+  for (int i = 0; i < N; i++) {
+    strides[i] = iter.strides(i + 1).data();
+    element_sizes[i] = iter.element_size(i + 1);
+  }
+  return OffsetCalculator<N>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
+}
+
+static OffsetCalculator<1> make_output_offset_calculator(const TensorIterator& iter) {
+  std::array<const int64_t*, 1> strides;
+  strides[0] = iter.strides(0).data();
+  int64_t element_size = iter.element_size(0);
+  return OffsetCalculator<1>(iter.ndim(), iter.shape().data(), strides.data(), &element_size);
+}
+
 // NOTE: @zasdfgbnm is currently working on rewriting the gpu loops.
 // Some of the old code has been moved to namespace legacy, and
 // new code will be put into namespace modern. These two namespaces
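
// For context on the helpers added above: OffsetCalculator<N> maps a linear
// element index to N byte offsets, one per tensor, by peeling the index apart
// along the iterator's shape and scaling by each tensor's byte strides. A
// minimal sketch of the idea (simplified and with assumed names; the real
// implementation uses fast integer division and device-friendly arrays):
template <int NARGS>
struct OffsetCalculatorSketch {
  int dims;                  // number of dimensions being iterated
  int64_t sizes[8];          // shape, fastest-varying dimension first (8 is an assumed bound)
  int64_t strides[8][NARGS]; // byte strides: strides[dim][tensor]

  void get(int64_t idx, int64_t (&offsets)[NARGS]) const {
    for (int arg = 0; arg < NARGS; arg++) {
      offsets[arg] = 0;
    }
    for (int dim = 0; dim < dims; dim++) {
      int64_t along_dim = idx % sizes[dim]; // coordinate in this dimension
      idx /= sizes[dim];
      for (int arg = 0; arg < NARGS; arg++) {
        offsets[arg] += along_dim * strides[dim][arg];
      }
    }
  }
};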
@@ -175,32 +196,37 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
 template <int vec_size, typename func_t, typename array_t>
 C10_LAUNCH_BOUNDS_1(num_threads)
 __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
+  using traits = function_traits<func_t>;
   int remaining = N - block_work_size * blockIdx.x;
 
   if (remaining < block_work_size) { // if this block handles the remainder, just do a naive unrolled loop
-    elementwise_kernel_helper(f, typename memory::policies::unroll<array_t>(data, remaining));
+    auto input_calc = TrivialOffsetCalculator<traits::arity>();
+    auto output_calc = TrivialOffsetCalculator<1>();
+    auto policy = memory::policies::unroll<array_t, decltype(input_calc), decltype(output_calc)>(data, remaining, input_calc, output_calc);
+    elementwise_kernel_helper(f, policy);
   } else { // if this block has a full `block_work_size` of data to handle, use vectorized memory access
-    elementwise_kernel_helper(f, typename memory::policies::template vectorized<vec_size, array_t>(data));
+    elementwise_kernel_helper(f, memory::policies::vectorized<vec_size, array_t>(data));
   }
 }
 
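// The branch above is the heart of the fast path: a full block can assume
// block_work_size valid elements and use vector loads, while the remainder
// block falls back to the unroll policy with bounds checks. A self-contained
// sketch of the vector-load trick the vectorized policy relies on (names here
// are illustrative, not the ones in MemoryAccess.cuh):
template <typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector_sketch {
  scalar_t val[vec_size];
};

template <int vec_size, typename scalar_t>
__device__ inline void load_vector_sketch(scalar_t (&dst)[vec_size], const scalar_t* src) {
  using vec_t = aligned_vector_sketch<scalar_t, vec_size>;
  vec_t v = *reinterpret_cast<const vec_t*>(src); // one wide, aligned transaction
  #pragma unroll
  for (int i = 0; i < vec_size; i++) {
    dst[i] = v.val[i];
  }
}
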
-template <typename func_t, typename array_t>
+template <typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
 C10_LAUNCH_BOUNDS_1(num_threads)
-__global__ void unrolled_elementwise_kernel(int N, func_t f, array_t data) {
+__global__ void unrolled_elementwise_kernel(int N, func_t f, array_t data, inp_calc_t ic, out_calc_t oc) {
   int remaining = N - block_work_size * blockIdx.x;
-  elementwise_kernel_helper(f, typename memory::policies::unroll<array_t>(data, remaining));
+  elementwise_kernel_helper(f, memory::policies::unroll<array_t, inp_calc_t, out_calc_t>(data, remaining, ic, oc));
 }
 
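// TrivialOffsetCalculator, used by the kernels above for the contiguous case,
// is the degenerate counterpart of the make_*_offset_calculator helpers:
// every tensor's offset is simply the linear index. A sketch of the idea
// (member and method names assumed, not copied from the real header):
template <int NARGS>
struct TrivialOffsetCalculatorSketch {
  __device__ at::detail::Array<uint32_t, NARGS> get(uint32_t linear_idx) const {
    at::detail::Array<uint32_t, NARGS> offsets;
    #pragma unroll
    for (int arg = 0; arg < NARGS; arg++) {
      offsets[arg] = linear_idx; // contiguous: element offset == linear index
    }
    return offsets;
  }
};
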
-// TODO (@zasdfgbnm): this function assume trivial 1d and no dynamic casting
+// this function assumes trivial 1d and no dynamic casting
 template <typename func_t, typename array_t>
-static void launch_kernel(int64_t N, const func_t& f, array_t data) {
-  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
-  if (N == 0) {
-    return;
-  }
+static inline void launch_vectorized_kernel(int64_t N, const func_t& f, array_t data) {
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+  using traits = function_traits<func_t>;
   int64_t grid = (N + block_work_size - 1) / block_work_size;
   auto stream = at::cuda::getCurrentCUDAStream();
   int vec_size = memory::can_vectorize_up_to<func_t>(data);
+  auto input_calc = TrivialOffsetCalculator<traits::arity>();
+  auto output_calc = TrivialOffsetCalculator<1>();
+
   switch (vec_size) {
     case 4:
       vectorized_elementwise_kernel<4, func_t, array_t><<<grid, num_threads, 0, stream>>>(N, f, data);
@@ -209,14 +235,23 @@ static void launch_kernel(int64_t N, const func_t& f, array_t data) {
       vectorized_elementwise_kernel<2, func_t, array_t><<<grid, num_threads, 0, stream>>>(N, f, data);
       break;
     case 1:
-      unrolled_elementwise_kernel<func_t, array_t><<<grid, num_threads, 0, stream>>>(N, f, data);
+      unrolled_elementwise_kernel<func_t, array_t><<<grid, num_threads, 0, stream>>>(N, f, data, input_calc, output_calc);
       break;
     default:
       TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
   }
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
+template <typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
+static inline void launch_unrolled_kernel(int64_t N, const func_t& f, array_t data, inp_calc_t ic, out_calc_t oc) {
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+  int64_t grid = (N + block_work_size - 1) / block_work_size;
+  auto stream = at::cuda::getCurrentCUDAStream();
+  unrolled_elementwise_kernel<func_t, array_t><<<grid, num_threads, 0, stream>>>(N, f, data, ic, oc);
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
 } // namespace modern
 
 
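// Both launchers compute the grid as ceil(N / block_work_size); the
// vectorized one additionally asks memory::can_vectorize_up_to for the widest
// vector width the data pointers' alignment allows. A sketch of that
// alignment check under simplifying assumptions (one scalar type for all
// tensors, <cstdint> available; the real trait also inspects the functor's
// argument types):
template <typename scalar_t, int ntensors>
static int can_vectorize_up_to_sketch(char* const (&pointers)[ntensors]) {
  int vec_size = 4; // try the widest supported width first
  for (int i = 0; i < ntensors; i++) {
    auto address = reinterpret_cast<std::uintptr_t>(pointers[i]);
    // Halve the width until this pointer is aligned for a vector that wide.
    while (vec_size > 1 && address % (vec_size * sizeof(scalar_t)) != 0) {
      vec_size /= 2;
    }
  }
  return vec_size;
}
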
@@ -234,12 +269,29 @@ void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
     data[i] = (char*)iter.data_ptr(i);
   }
 
+  int64_t numel = iter.numel();
+
+  bool contiguous = iter.is_contiguous();
+  bool dynamic_casting = needs_dynamic_casting<func_t>::check(iter);
+
+  if (contiguous && !dynamic_casting) {
+    modern::launch_vectorized_kernel(numel, f, data);
+    return;
+  }
+
+  if (!dynamic_casting) {
+    // !contiguous
+    auto input_offset_calculator = make_input_offset_calculator<traits::arity>(iter);
+    auto output_offset_calculator = make_output_offset_calculator(iter);
+    modern::launch_unrolled_kernel(numel, f, data, input_offset_calculator, output_offset_calculator);
+    return;
+  }
+
   at::detail::Array<ScalarType, ntensors> dtypes;
   for (int i = 0; i < ntensors; i++) {
     dtypes[i] = iter.tensor(i).scalar_type();
   }
 
-  int64_t numel = iter.numel();
   if (iter.is_trivial_1d()) {
     auto inner_strides = iter.get_inner_strides();
     at::detail::Array<int, ntensors> strides;
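
// With the early returns added above, the legacy path below now only runs
// when dynamic casting is required. needs_dynamic_casting asks whether any
// operand's runtime dtype disagrees with the functor's compile-time types; a
// sketch of that check, written out for a binary functor only and assuming
// c10::CppTypeToScalarType for the type-to-dtype mapping (the real trait
// recurses over every arity):
template <typename func_t>
static bool needs_dynamic_casting_sketch(TensorIterator& iter) {
  using traits = function_traits<func_t>;
  using out_t = typename traits::result_type;
  using in0_t = typename traits::template arg<0>::type;
  using in1_t = typename traits::template arg<1>::type;
  // Casting is needed if the output or either input dtype differs from the
  // functor's static signature.
  return iter.dtype(0) != c10::CppTypeToScalarType<out_t>::value
      || iter.dtype(1) != c10::CppTypeToScalarType<in0_t>::value
      || iter.dtype(2) != c10::CppTypeToScalarType<in1_t>::value;
}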
@@ -253,8 +305,6 @@ void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
       arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
       c10::cast_and_store<arg0_t>(dtypes[0], out, result);
     });
-  } else if (iter.has_contiguous_first_dim()) {
-    modern::launch_kernel(numel, f, data);
   } else {
     legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
       arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
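// For reference, a typical call site that exercises these paths (an
// illustrative example, not part of this patch):
//
//   auto iter = TensorIterator::binary_op(out, a, b);
//   gpu_kernel(iter, [] GPU_LAMBDA (float x, float y) -> float {
//     return x + y;
//   });
//
// Contiguous operands whose dtypes match the lambda take
// launch_vectorized_kernel; non-contiguous operands without casting take
// launch_unrolled_kernel with the real offset calculators; only dtype
// mismatches fall through to the legacy kernels.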