@@ -102,13 +102,16 @@ struct NaiveScheduler {
   }
 
   template <typename FUNC>
-  __device__ __forceinline__ static void execute(int subwarps_per_block, FUNC fn) {
+  __device__ __forceinline__ static void execute(int subwarps_per_block, int hidden_size_num_groups, FUNC fn) {
     const int local_group_id = threadIdx.x / THREADS_PER_SUBWARP;
     const int lane_id = threadIdx.x % THREADS_PER_SUBWARP;
    const int block_group_id = blockIdx.x * subwarps_per_block;
     const int group_id = block_group_id + local_group_id;
 
-    fn(group_id, lane_id);
+    const int token_idx = group_id / hidden_size_num_groups;
+    const int group_start_hidden_idx = group_id % hidden_size_num_groups;
+
+    fn(token_idx, group_start_hidden_idx, lane_id);
   }
 };
 
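The scheduler now owns the flat-index decomposition: each subwarp's group_id is split into a token index and a group offset within the hidden dimension, and the callback receives the pair instead of re-deriving it. A minimal sketch of the mapping, with illustrative sizes (hidden_size = 7168 and group_size = 128 are assumptions for this example, not values from the commit):

    // hidden_size_num_groups = 7168 / 128 = 56 quantization groups per token,
    // so group_id 0..55 belongs to token 0, 56..111 to token 1, and so on.
    // For group_id = 130:
    //   token_idx              = 130 / 56 = 2
    //   group_start_hidden_idx = 130 % 56 = 18
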
@@ -125,93 +128,92 @@ __global__ void per_token_group_quant_8bit_kernel(
     scale_packed_t* __restrict__ output_s,
     const int group_size,
     const int subwarps_per_block,
-    // TODO can remove?
-    const int scale_hidden_size = 0,
-    const int scale_hidden_stride = 0) {
+    const int hidden_size_num_groups,
+    const int scale_hidden_stride) {
   using dst_dtype_info = DtypeInfo<DST_DTYPE>;
   using scale_element_t = std::conditional_t<SCALE_UE8M0, uint8_t, float>;
   static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
 
-  SCHEDULER::execute(subwarps_per_block, [&](int group_id, int lane_id) {
-    constexpr uint32_t INPUT_PRIMARY_VEC_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(T);
-    constexpr uint32_t INPUT_PRIMARY_INT4_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(int4);
-
-    int4 input_primary_int4[INPUT_PRIMARY_INT4_SIZE];
-    T* input_primary_vec = reinterpret_cast<T*>(input_primary_int4);
-    static_assert(sizeof(input_primary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_primary_int4));
+  SCHEDULER::execute(
+      subwarps_per_block, hidden_size_num_groups, [&](int token_idx, int group_start_hidden_idx, int lane_id) {
+        constexpr uint32_t INPUT_PRIMARY_VEC_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(T);
+        constexpr uint32_t INPUT_PRIMARY_INT4_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(int4);
 
-#pragma unroll
-    for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) {
-      input_primary_int4[j] = ld_global_nc(
-          reinterpret_cast<const int4*>(input + group_id * group_size + lane_id * INPUT_PRIMARY_VEC_SIZE) + j);
-    }
+        // TODO consider stride
+        const int group_id = token_idx * hidden_size_num_groups + group_start_hidden_idx;
 
-    scale_element_t* scale_output;
-    if constexpr (IS_COLUMN_MAJOR) {
-      constexpr int scale_token_stride = 1;
-      constexpr int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
-
-      // TODO unify w/ other places?
-      const int scale_hidden_size_unpacked = scale_hidden_size * num_elems_per_pack;
-      const int token_idx = group_id / scale_hidden_size_unpacked;
-      const int group_start_hidden_idx = group_id % scale_hidden_size_unpacked;
-
-      const int hidden_idx_packed = group_start_hidden_idx / num_elems_per_pack;
-      const int pack_idx = group_start_hidden_idx % num_elems_per_pack;
-      scale_output = reinterpret_cast<scale_element_t*>(output_s) +
-                     (hidden_idx_packed * scale_hidden_stride * num_elems_per_pack +
-                      token_idx * scale_token_stride * num_elems_per_pack + pack_idx);
-    } else {
-      static_assert(!SCALE_UE8M0);
-      scale_output = output_s + group_id;
-    }
+        int4 input_primary_int4[INPUT_PRIMARY_INT4_SIZE];
+        T* input_primary_vec = reinterpret_cast<T*>(input_primary_int4);
+        static_assert(sizeof(input_primary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_primary_int4));
 
-    float local_absmax = LOCAL_ABSMAX_ABS;
+#pragma unroll
+        for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) {
+          input_primary_int4[j] = ld_global_nc(
+              reinterpret_cast<const int4*>(input + group_id * group_size + lane_id * INPUT_PRIMARY_VEC_SIZE) + j);
+        }
+
+        scale_element_t* scale_output;
+        if constexpr (IS_COLUMN_MAJOR) {
+          constexpr int scale_token_stride = 1;
+          constexpr int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
+
+          const int hidden_idx_packed = group_start_hidden_idx / num_elems_per_pack;
+          const int pack_idx = group_start_hidden_idx % num_elems_per_pack;
+          scale_output = reinterpret_cast<scale_element_t*>(output_s) +
+                         (hidden_idx_packed * scale_hidden_stride * num_elems_per_pack +
+                          token_idx * scale_token_stride * num_elems_per_pack + pack_idx);
+        } else {
+          static_assert(!SCALE_UE8M0);
+          scale_output = output_s + group_id;
+        }
+
+        float local_absmax = LOCAL_ABSMAX_ABS;
 
 #pragma unroll
-    for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) {
-      float val = static_cast<float>(input_primary_vec[j]);
-      float abs_val = fabsf(val);
-      local_absmax = fmaxf(local_absmax, abs_val);
-    }
+        for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) {
+          float val = static_cast<float>(input_primary_vec[j]);
+          float abs_val = fabsf(val);
+          local_absmax = fmaxf(local_absmax, abs_val);
+        }
 
-    local_absmax = GroupReduceMax<THREADS_PER_SUBWARP>(local_absmax, lane_id);
+        local_absmax = GroupReduceMax<THREADS_PER_SUBWARP>(local_absmax, lane_id);
 
-    float y_scale, y_scale_inv;
-    calculate_fp8_scales<SCALE_UE8M0, dst_dtype_info>(local_absmax, y_scale, y_scale_inv);
-    float2 y_scale_repeated = {y_scale, y_scale};
+        float y_scale, y_scale_inv;
+        calculate_fp8_scales<SCALE_UE8M0, dst_dtype_info>(local_absmax, y_scale, y_scale_inv);
+        float2 y_scale_repeated = {y_scale, y_scale};
 
-    if (lane_id == 0) {
-      *scale_output = extract_required_scale_format<SCALE_UE8M0>(y_scale_inv);
-    }
+        if (lane_id == 0) {
+          *scale_output = extract_required_scale_format<SCALE_UE8M0>(y_scale_inv);
+        }
 
-    int4 output_buf;
-    static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE * sizeof(DST_DTYPE));
+        int4 output_buf;
+        static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE * sizeof(DST_DTYPE));
 
-    if constexpr (std::is_same_v<DST_DTYPE, c10::Float8_e4m3fn>) {
-      const auto output_buf_ptr = reinterpret_cast<__nv_fp8x2_storage_t*>(&output_buf);
-      static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE / 2 * sizeof(__nv_fp8x2_storage_t));
-      static_assert(INPUT_PRIMARY_VEC_SIZE % 2 == 0);
+        if constexpr (std::is_same_v<DST_DTYPE, c10::Float8_e4m3fn>) {
+          const auto output_buf_ptr = reinterpret_cast<__nv_fp8x2_storage_t*>(&output_buf);
+          static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE / 2 * sizeof(__nv_fp8x2_storage_t));
+          static_assert(INPUT_PRIMARY_VEC_SIZE % 2 == 0);
 
 #pragma unroll
-      for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; j += 2) {
-        float2 inputx2 = {static_cast<float>(input_primary_vec[j]), static_cast<float>(input_primary_vec[j + 1])};
-        float2 outputx2 = __fmul2_rn(inputx2, y_scale_repeated);
-        output_buf_ptr[j / 2] = __nv_cvt_float2_to_fp8x2(outputx2, __NV_SATFINITE, __NV_E4M3);
-      }
-    } else {
-      const auto output_buf_ptr = reinterpret_cast<DST_DTYPE*>(&output_buf);
+          for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; j += 2) {
+            float2 inputx2 = {static_cast<float>(input_primary_vec[j]), static_cast<float>(input_primary_vec[j + 1])};
+            float2 outputx2 = __fmul2_rn(inputx2, y_scale_repeated);
+            output_buf_ptr[j / 2] = __nv_cvt_float2_to_fp8x2(outputx2, __NV_SATFINITE, __NV_E4M3);
+          }
+        } else {
+          const auto output_buf_ptr = reinterpret_cast<DST_DTYPE*>(&output_buf);
 
 #pragma unroll
-      for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) {
-        float val = static_cast<float>(input_primary_vec[j]);
-        float q_val = fminf(fmaxf(val * y_scale, dst_dtype_info::MIN), dst_dtype_info::MAX);
-        output_buf_ptr[j] = DST_DTYPE(q_val);
-      }
-    }
-
-    st_global(reinterpret_cast<int4*>(output_q + group_id * group_size + lane_id * INPUT_PRIMARY_VEC_SIZE), output_buf);
-  });
+          for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) {
+            float val = static_cast<float>(input_primary_vec[j]);
+            float q_val = fminf(fmaxf(val * y_scale, dst_dtype_info::MIN), dst_dtype_info::MAX);
+            output_buf_ptr[j] = DST_DTYPE(q_val);
+          }
+        }
+
+        st_global(
+            reinterpret_cast<int4*>(output_q + group_id * group_size + lane_id * INPUT_PRIMARY_VEC_SIZE), output_buf);
+      });
 }
 
 int compute_subwarps_per_block(int num_groups) {
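The column-major branch is otherwise unchanged: with token_idx and group_start_hidden_idx supplied by the scheduler, the kernel no longer needs scale_hidden_size to recover them, so that parameter and its TODOs disappear. The packed-scale offset arithmetic, restated as a host-side sketch (the function name and the concrete pack factor of 4 are illustrative; the kernel derives num_elems_per_pack from sizeof(scale_packed_t) / sizeof(scale_element_t)):

    // Element offset into output_s for a column-major, UE8M0-packed scale tensor.
    int scale_elem_offset(int token_idx, int group_start_hidden_idx, int scale_hidden_stride) {
      constexpr int scale_token_stride = 1;  // consecutive tokens are adjacent in memory
      constexpr int num_elems_per_pack = 4;  // e.g. four uint8 scales per 32-bit word
      const int hidden_idx_packed = group_start_hidden_idx / num_elems_per_pack;
      const int pack_idx = group_start_hidden_idx % num_elems_per_pack;
      return hidden_idx_packed * scale_hidden_stride * num_elems_per_pack +
             token_idx * scale_token_stride * num_elems_per_pack + pack_idx;
    }
    // e.g. scale_elem_offset(2, 18, 64) = 4 * 64 * 4 + 2 * 4 + 2 = 1034
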
@@ -253,7 +255,7 @@ void sgl_per_token_group_quant_8bit(
   auto dst_type = output_q.scalar_type();
 
   const bool is_column_major = output_s.stride(-2) < output_s.stride(-1);
-  const int scale_hidden_size = output_s.size(-1);
+  const int hidden_size_num_groups = output_q.size(-1) / group_size;
   const int scale_hidden_stride = output_s.stride(-1);
 
 #define LAUNCH_KERNEL_INNER(SCHEDULER, T, DST_DTYPE, output_s_dtype, ...) \
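Deriving the group count from output_q is the point of this host-side change: the old scale_hidden_size = output_s.size(-1) counted packed scale words, which is why the kernel had to multiply by num_elems_per_pack before dividing group_id. hidden_size_num_groups = output_q.size(-1) / group_size counts quantization groups directly and is independent of how the scales are packed. Illustrative shapes (assumed, not from this commit):

    // hidden_size = 7168, group_size = 128:
    //   output_q: [num_tokens, 7168]  -> hidden_size_num_groups = 7168 / 128 = 56
    //   output_s (UE8M0, packed 4x):  size(-1) = 56 / 4 = 14 packed words, so the
    //   old code had to compute 14 * num_elems_per_pack to recover 56.
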
@@ -267,7 +269,7 @@ void sgl_per_token_group_quant_8bit(
       static_cast<output_s_dtype*>(output_s.data_ptr()), \
       group_size, \
       subwarps_per_block, \
-      scale_hidden_size, \
+      hidden_size_num_groups, \
       scale_hidden_stride); \
   } while (0)
 
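At the launch site the macro simply forwards the new parameter. For intuition, the grid math implied by NaiveScheduler, with illustrative numbers (the value returned by compute_subwarps_per_block is assumed; its body is outside this diff):

    // num_tokens = 4, hidden_size_num_groups = 56:
    //   num_groups         = 4 * 56 = 224          // one quantization group per subwarp
    //   subwarps_per_block = compute_subwarps_per_block(224);  // say it returns 8
    //   num_blocks         = (224 + 8 - 1) / 8 = 28
    // Inside each block, blockIdx.x * subwarps_per_block + local_group_id
    // reproduces the flat group_id that execute() then splits per token.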