Commit 3657a7b

Update on "[quant][graphmode] quantization support for aten::add"
Summary: This is only for Tensor-Tensor add; we'll need to support Tensor-Scalar add in a different way, since that does not need an observer.

Test Plan: python test/test_jit.py

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D20519607](https://our.internmc.facebook.com/intern/diff/D20519607)

[ghstack-poisoned]
2 parents 6055d6a + 90c0cac commit 3657a7b

11 files changed

Lines changed: 763 additions & 447 deletions

aten/src/ATen/cuda/detail/OffsetCalculator.cuh

Lines changed: 30 additions & 5 deletions
@@ -17,10 +17,15 @@ constexpr int MAX_DIMS = 25;
 
 template <int NARGS, typename index_t = uint32_t>
 struct OffsetCalculator {
-  // The offset for each argument (in bytes). Wrapper around fixed-size array.
-  using offset_type = at::detail::Array<index_t, NARGS>;
+  // The offset for each argument. Wrapper around fixed-size array.
+  // On CUDA, zero sized array is not allowed, so when we are handling nullary
+  // operators, we need to create a size 1 offset to avoid compiler failure.
+  // This size 1 offset is just a placeholder, and we will not use it.
+  using offset_type = at::detail::Array<index_t, std::max<int>(NARGS, 1)>;
 
-  OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides) : dims(dims) {
+  // if element_sizes is nullptr, then the strides will be in bytes, otherwise
+  // the strides will be in # of elements.
+  OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides, const int64_t* element_sizes=nullptr) : dims(dims) {
     TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims");
     for (int i = 0; i < MAX_DIMS; ++i) {
       if (i < dims) {
@@ -29,7 +34,8 @@ struct OffsetCalculator {
         sizes_[i] = IntDivider<index_t>(1);
       }
       for (int arg = 0; arg < NARGS; arg++) {
-        strides_[i][arg] = i < dims ? strides[arg][i] : 0;
+        int64_t element_size = (element_sizes == nullptr ? 1LL : element_sizes[arg]);
+        strides_[i][arg] = i < dims ? strides[arg][i] / element_size : 0;
       }
     }
   }
@@ -60,5 +66,24 @@ struct OffsetCalculator {
 
   int dims;
   IntDivider<index_t> sizes_[MAX_DIMS];
-  index_t strides_[MAX_DIMS][NARGS];
+  index_t strides_[MAX_DIMS][std::max<int>(NARGS, 1)];
+};
+
+template <int NARGS, typename index_t = uint32_t>
+struct TrivialOffsetCalculator {
+  // The offset for each argument. Wrapper around fixed-size array.
+  // The offsets are in # of elements, not in bytes.
+  // On CUDA, zero sized array is not allowed, so when we are handling nullary
+  // operators, we need to create a size 1 offset to avoid compiler failure.
+  // This size 1 offset is just a placeholder, and we will not use it.
+  using offset_type = at::detail::Array<index_t, std::max<int>(NARGS, 1)>;
+
+  C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
+    offset_type offsets;
+    #pragma unroll
+    for (int arg = 0; arg < NARGS; arg++) {
+      offsets[arg] = linear_idx;
+    }
+    return offsets;
+  }
 };
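For orientation, here is a minimal standalone analogue of what the reworked calculator computes (a hypothetical illustration, not the ATen code): a linear index is decomposed dimension by dimension into one offset per argument, using strides expressed in elements, which is what the new element_sizes parameter makes possible. TrivialOffsetCalculator is the degenerate contiguous case where every offset is simply the linear index.

// simple_offset_calculator.cpp -- hypothetical standalone sketch, not the ATen class.
#include <array>
#include <cstdint>
#include <cstdio>

template <int NARGS>
struct SimpleOffsetCalculator {
  int dims;
  std::array<uint32_t, 8> sizes;                       // shape, innermost dimension first
  std::array<std::array<uint32_t, NARGS>, 8> strides;  // per-argument strides, in elements

  std::array<uint32_t, NARGS> get(uint32_t linear_idx) const {
    std::array<uint32_t, NARGS> offsets{};
    for (int dim = 0; dim < dims; ++dim) {
      uint32_t idx_along_dim = linear_idx % sizes[dim];
      linear_idx /= sizes[dim];
      for (int arg = 0; arg < NARGS; ++arg) {
        offsets[arg] += idx_along_dim * strides[dim][arg];
      }
    }
    return offsets;
  }
};

int main() {
  // One argument, a contiguous 2x3 tensor: element strides are {1, 3}.
  SimpleOffsetCalculator<1> calc;
  calc.dims = 2;
  calc.sizes = {3, 2, 1, 1, 1, 1, 1, 1};
  calc.strides[0][0] = 1;
  calc.strides[1][0] = 3;
  for (uint32_t i = 0; i < 6; ++i) {
    // For a contiguous layout this prints offset == i, which is exactly what
    // TrivialOffsetCalculator::get returns without doing any arithmetic.
    std::printf("linear %u -> element offset %u\n", i, calc.get(i)[0]);
  }
  return 0;
}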

aten/src/ATen/native/TensorConversions.cpp

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ static inline Tensor to_impl(const Tensor& self, const TensorOptions& options, b
   if (self.is_non_overlapping_and_dense()) {
     // Copy all strides
     auto r = at::empty_strided(self.sizes(), self.strides(), options.memory_format(c10::nullopt));
-    r.copy_(self);
+    r.copy_(self, non_blocking);
     return r;
   } else {
     memory_format = self.suggest_memory_format();
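An illustrative (hypothetical) use of the behavior this one-line fix restores, written with the C++ frontend: on the is_non_overlapping_and_dense() fast path, Tensor::to previously dropped the non_blocking flag, so a pinned-memory host-to-device copy was always synchronizing.

#include <torch/torch.h>

int main() {
  // A pinned CPU tensor can be copied to the GPU asynchronously.
  auto cpu = torch::randn({1024, 1024},
                          torch::dtype(torch::kFloat).pinned_memory(true));
  // With the fix, non_blocking is forwarded to r.copy_(self, non_blocking) on the
  // dense fast path instead of silently falling back to a blocking copy.
  auto gpu = cpu.to(torch::kCUDA, torch::kFloat, /*non_blocking=*/true);
  return 0;
}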

aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 65 additions & 15 deletions
@@ -66,6 +66,27 @@ static constexpr int launch_bound2 = 4;
 
 namespace at { namespace native {
 
+template<int N>
+static OffsetCalculator<N> make_input_offset_calculator(const TensorIterator& iter) {
+  // array size can not be 0, this happens when N == 0
+  constexpr int array_size = std::max<int>(N, 1);
+  TORCH_INTERNAL_ASSERT(N == iter.ntensors() - 1);
+  std::array<const int64_t*, array_size> strides;
+  int64_t element_sizes[array_size];
+  for (int i = 0; i < N; i++) {
+    strides[i] = iter.strides(i + 1).data();
+    element_sizes[i] = iter.element_size(i + 1);
+  }
+  return OffsetCalculator<N>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
+}
+
+static OffsetCalculator<1> make_output_offset_calculator(const TensorIterator& iter) {
+  std::array<const int64_t*, 1> strides;
+  strides[0] = iter.strides(0).data();
+  int64_t element_size = iter.element_size(0);
+  return OffsetCalculator<1>(iter.ndim(), iter.shape().data(), strides.data(), &element_size);
+}
+
 // NOTE: @zasdfgbnm is currently working on rewriting the gpu loops.
 // Some of the old codes has been moved to namespace legacy, and
 // new codes will be put into namespace modern. These two namespaces
@@ -175,32 +196,37 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
 template<int vec_size, typename func_t, typename array_t>
 C10_LAUNCH_BOUNDS_1(num_threads)
 __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
+  using traits = function_traits<func_t>;
   int remaining = N - block_work_size * blockIdx.x;
 
   if (remaining < block_work_size) { // if this block handles the reminder, just do a naive unrolled loop
-    elementwise_kernel_helper(f, typename memory::policies::unroll<array_t>(data, remaining));
+    auto input_calc = TrivialOffsetCalculator<traits::arity>();
+    auto output_calc = TrivialOffsetCalculator<1>();
+    auto policy = memory::policies::unroll<array_t, decltype(input_calc), decltype(output_calc)>(data, remaining, input_calc, output_calc);
+    elementwise_kernel_helper(f, policy);
   } else { // if this block has a full `block_work_size` data to handle, use vectorized memory access
-    elementwise_kernel_helper(f, typename memory::policies::template vectorized<vec_size, array_t>(data));
+    elementwise_kernel_helper(f, memory::policies::vectorized<vec_size, array_t>(data));
   }
 }
 
-template<typename func_t, typename array_t>
+template<typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
 C10_LAUNCH_BOUNDS_1(num_threads)
-__global__ void unrolled_elementwise_kernel(int N, func_t f, array_t data) {
+__global__ void unrolled_elementwise_kernel(int N, func_t f, array_t data, inp_calc_t ic, out_calc_t oc) {
   int remaining = N - block_work_size * blockIdx.x;
-  elementwise_kernel_helper(f, typename memory::policies::unroll<array_t>(data, remaining));
+  elementwise_kernel_helper(f, memory::policies::unroll<array_t, inp_calc_t, out_calc_t>(data, remaining, ic, oc));
 }
 
-// TODO (@zasdfgbnm): this function assume trivial 1d and no dynamic casting
+// this function assume trivial 1d and no dynamic casting
 template<typename func_t, typename array_t>
-static void launch_kernel(int64_t N, const func_t& f, array_t data) {
-  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
-  if (N == 0) {
-    return;
-  }
+static inline void launch_vectorized_kernel(int64_t N, const func_t& f, array_t data) {
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+  using traits = function_traits<func_t>;
   int64_t grid = (N + block_work_size - 1) / block_work_size;
   auto stream = at::cuda::getCurrentCUDAStream();
   int vec_size = memory::can_vectorize_up_to<func_t>(data);
+  auto input_calc = TrivialOffsetCalculator<traits::arity>();
+  auto output_calc = TrivialOffsetCalculator<1>();
+
   switch (vec_size) {
     case 4:
       vectorized_elementwise_kernel<4, func_t, array_t><<<grid, num_threads, 0, stream>>>(N, f, data);
@@ -209,14 +235,23 @@ static void launch_kernel(int64_t N, const func_t& f, array_t data) {
       vectorized_elementwise_kernel<2, func_t, array_t><<<grid, num_threads, 0, stream>>>(N, f, data);
       break;
     case 1:
-      unrolled_elementwise_kernel<func_t, array_t><<<grid, num_threads, 0, stream>>>(N, f, data);
+      unrolled_elementwise_kernel<func_t, array_t><<<grid, num_threads, 0, stream>>>(N, f, data, input_calc, output_calc);
      break;
    default:
      TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
  }
  AT_CUDA_CHECK(cudaGetLastError());
 }
 
+template<typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
+static inline void launch_unrolled_kernel(int64_t N, const func_t& f, array_t data, inp_calc_t ic, out_calc_t oc) {
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+  int64_t grid = (N + block_work_size - 1) / block_work_size;
+  auto stream = at::cuda::getCurrentCUDAStream();
+  unrolled_elementwise_kernel<func_t, array_t><<<grid, num_threads, 0, stream>>>(N, f, data, ic, oc);
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
 } // namespace modern
 
 
@@ -234,12 +269,29 @@ void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
     data[i] = (char*)iter.data_ptr(i);
   }
 
+  int64_t numel = iter.numel();
+
+  bool contiguous = iter.is_contiguous();
+  bool dynamic_casting = needs_dynamic_casting<func_t>::check(iter);
+
+  if (contiguous && !dynamic_casting) {
+    modern::launch_vectorized_kernel(numel, f, data);
+    return;
+  }
+
+  if (!dynamic_casting) {
+    // !contiguous
+    auto input_offset_calculator = make_input_offset_calculator<traits::arity>(iter);
+    auto output_offset_calculator = make_output_offset_calculator(iter);
+    modern::launch_unrolled_kernel(numel, f, data, input_offset_calculator, output_offset_calculator);
+    return;
+  }
+
   at::detail::Array<ScalarType, ntensors> dtypes;
   for (int i = 0; i < ntensors; i++) {
     dtypes[i] = iter.tensor(i).scalar_type();
   }
 
-  int64_t numel = iter.numel();
   if (iter.is_trivial_1d()) {
     auto inner_strides = iter.get_inner_strides();
     at::detail::Array<int, ntensors> strides;
@@ -253,8 +305,6 @@ void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
       arg0_t result = legacy::invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
       c10::cast_and_store<arg0_t>(dtypes[0], out, result);
     });
-  } else if (iter.has_contiguous_first_dim()) {
-    modern::launch_kernel(numel, f, data);
   } else {
     legacy::launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
       arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
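Taken together, the hunks above change how gpu_kernel_impl selects a loop. A condensed, hypothetical sketch of the new decision order follows (the names mirror the diff, but this is not the ATen source):

#include <cstdio>

// Which of the three loop flavors gpu_kernel_impl now reaches, in the order the
// checks appear in the diff above. `contiguous` stands for iter.is_contiguous()
// and `dynamic_casting` for needs_dynamic_casting<func_t>::check(iter).
enum class LoopPath { Vectorized, Unrolled, Legacy };

inline LoopPath pick_loop_path(bool contiguous, bool dynamic_casting) {
  if (contiguous && !dynamic_casting) {
    return LoopPath::Vectorized;  // modern::launch_vectorized_kernel
  }
  if (!dynamic_casting) {
    return LoopPath::Unrolled;    // modern::launch_unrolled_kernel + OffsetCalculators
  }
  return LoopPath::Legacy;        // legacy::launch_kernel with cast_and_store
}

int main() {
  std::printf("contiguous, no cast -> %d (vectorized)\n", (int)pick_loop_path(true, false));
  std::printf("strided, no cast    -> %d (unrolled)\n", (int)pick_loop_path(false, false));
  std::printf("needs dynamic cast  -> %d (legacy)\n", (int)pick_loop_path(false, true));
  return 0;
}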

aten/src/ATen/native/cuda/DistributionTemplates.h

Lines changed: 12 additions & 2 deletions
@@ -1,9 +1,9 @@
 #pragma once
 
 #include <ATen/Dispatch.h>
-#include <ATen/native/cuda/Loops.cuh>
 #include <ATen/native/TensorIterator.h>
 #include <c10/util/Half.h>
+#include <ATen/cuda/detail/OffsetCalculator.cuh>
 
 #include <curand.h>
 #include <curand_kernel.h>
@@ -77,6 +77,16 @@ __global__ void distribution_elementwise_grid_stride_kernel(int numel,
   }
 }
 
+template<int N>
+static OffsetCalculator<N> make_offset_calculator(const at::TensorIterator& iter) {
+  AT_ASSERT(N == iter.ntensors());
+  std::array<const int64_t*, N> strides;
+  for (int i = 0; i < N; i++) {
+    strides[i] = iter.strides(i).data();
+  }
+  return OffsetCalculator<N>(iter.ndim(), iter.shape().data(), strides.data());
+}
+
 /**
  * distribution_nullary_kernel is analogous to gpu_kernel in
  * ATen/native/cuda/Loops.cuh. Like gpu_kernel, it uses
@@ -144,7 +154,7 @@ void distribution_nullary_kernel(at::TensorIterator& iter,
       }
     );
   } else {
-    auto offset_calc = at::native::legacy::make_offset_calculator<1>(iter);
+    auto offset_calc = make_offset_calculator<1>(iter);
     distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
       numel,
       rng_engine_inputs,
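For context, here is a hedged sketch of how an offset calculator built this way is consumed by the strided branch of a grid-stride distribution kernel (simplified and hypothetical; the real kernel writes a transformed random value rather than a constant). Because no element_sizes are passed to this OffsetCalculator, the offsets it returns are byte offsets:

// Hypothetical device-side consumer of an OffsetCalculator built as above.
// offset_calc.get(li) yields one byte offset per tensor for linear index li.
template <typename scalar_t, typename offset_calc_t>
__global__ void fill_strided_output(int numel, scalar_t value, char* out_data,
                                    offset_calc_t offset_calc) {
  for (int li = blockIdx.x * blockDim.x + threadIdx.x; li < numel;
       li += gridDim.x * blockDim.x) {
    auto offsets = offset_calc.get(li);   // offsets[0] is in bytes here
    scalar_t* out = reinterpret_cast<scalar_t*>(out_data + offsets[0]);
    *out = value;                         // the real kernel writes transform(rand)
  }
}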

aten/src/ATen/native/cuda/MemoryAccess.cuh

Lines changed: 33 additions & 14 deletions
@@ -5,6 +5,7 @@
 #include <c10/util/Exception.h>
 #include <c10/macros/Macros.h>
 #include <ATen/detail/FunctionTraits.h>
+#include <ATen/cuda/detail/OffsetCalculator.cuh>
 
 // References:
 // https://devblogs.nvidia.com/cuda-pro-tip-increase-performance-with-vectorized-memory-access/
@@ -44,8 +45,11 @@ struct static_unroll<func, end, end> {
   static inline C10_HOST_DEVICE void with_args(Args... args) {}
 };
 
+// helper structs to be used with static_unroll to load arguments
+// one by one
+
 template<int arg_index>
-struct load_with_policy {
+struct vectorized_load_helper {
   template <typename args_t, typename policy_t>
   static __device__ void apply(policy_t &self, args_t *args, int idx) {
     using arg_t = std::tuple_element_t<arg_index, args_t>;
@@ -57,6 +61,18 @@ struct load_with_policy {
   }
 };
 
+template<int arg_index>
+struct unroll_load_helper {
+  template <typename args_t, typename policy_t, typename offset_t>
+  static __device__ void apply(policy_t &self, args_t *args, offset_t offset, int j) {
+    using arg_t = std::tuple_element_t<arg_index, args_t>;
+    // `data` hold the data_ptr for tensors [output, input0, input1, ...], so we
+    // need a +1 offset to get the input
+    auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + offset[arg_index];
+    std::get<arg_index>(args[j]) = *ptr;
+  }
+};
+
 } // namespace detail
 
 // aligned vector generates vectorized load/store on CUDA
@@ -69,37 +85,37 @@ namespace policies {
 
 // Assumption:
 // all tensors are contiguous, that is: stride == sizeof(type) for all tensors
-template<typename data_t>
+template<typename data_t, typename inp_calc_t, typename out_calc_t>
 struct unroll {
 
   data_t data;
   int remaining;
+  inp_calc_t input_offset_calculator;
+  out_calc_t output_offset_calculator;
 
-  __device__ unroll(data_t data, int remaining): data(data), remaining(remaining) {}
+  __device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc):
+    data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {}
 
   __device__ inline bool check_inbounds(int thread_work_elem) {
     return ((threadIdx.x + thread_work_elem*num_threads) < remaining);
   }
 
-  template<typename accessor_t, typename scalar_t>
-  __device__ inline void load_single_arg(accessor_t to, scalar_t *from) {
+  template<typename args_t>
+  __device__ inline void load(args_t *args, int idx) {
+    constexpr int arity = std::tuple_size<args_t>::value;
     int thread_idx = threadIdx.x;
     #pragma unroll
     for (int i = 0; i < thread_work_size; i++) {
       if (thread_idx >= remaining) {
         return;
       }
-      to(i) = from[thread_idx];
+      int linear_idx = thread_idx + block_work_size * idx;
+      auto offset = input_offset_calculator.get(linear_idx);
+      detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, i);
       thread_idx += num_threads;
     }
   }
 
-  template<typename args_t>
-  __device__ inline void load(args_t *args, int idx) {
-    constexpr int arity = std::tuple_size<args_t>::value;
-    detail::static_unroll<detail::load_with_policy, arity>::with_args(*this, args, idx);
-  }
-
   template<typename scalar_t>
   __device__ inline void store(scalar_t *from, int idx) {
     int thread_idx = threadIdx.x;
@@ -109,7 +125,10 @@ struct unroll {
       if (thread_idx >= remaining) {
        return;
      }
-      to[thread_idx] = from[i];
+      int linear_idx = thread_idx + block_work_size * idx;
+      int offset = output_offset_calculator.get(linear_idx)[0];
+      scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + offset;
+      *to = from[i];
      thread_idx += num_threads;
    }
  }
@@ -153,7 +172,7 @@ struct vectorized {
   template<typename args_t>
   __device__ inline void load(args_t *args, int idx) {
     constexpr int arity = std::tuple_size<args_t>::value;
-    detail::static_unroll<detail::load_with_policy, arity>::with_args(*this, args, idx);
+    detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx);
   }
 
   template<typename scalar_t>
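To see how the reworked unroll policy is consumed, here is a hedged sketch that mirrors elementwise_kernel_helper from Loops.cuh (simplified; it assumes the ATen internals function_traits, thread_work_size and c10::guts::apply): each block loads its arguments through policy.load, which now routes every element through input_offset_calculator, applies f, and stores the results through policy.store via output_offset_calculator.

// Sketch of the kernel body that drives the policies above (not the verbatim
// ATen helper). The policy object decides how a linear index becomes an address:
// unroll goes through its offset calculators, vectorized assumes contiguity.
template <typename func_t, typename policy_t>
__device__ inline void apply_with_policy(func_t f, policy_t policy) {
  using traits = function_traits<func_t>;
  using return_t = typename traits::result_type;
  using args_t = typename traits::ArgsTuple;

  int idx = blockIdx.x;
  return_t results[thread_work_size];
  args_t args[thread_work_size];

  policy.load(args, idx);        // gather inputs (per-element offsets for unroll)
  #pragma unroll
  for (int i = 0; i < thread_work_size; i++) {
    if (policy.check_inbounds(i)) {
      results[i] = c10::guts::apply(f, args[i]);
    }
  }
  policy.store(results, idx);    // scatter outputs through the output calculator
}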
