Skip to content

Commit c7f3d13

Browse files
committed
more
1 parent e14d6d8 commit c7f3d13

1 file changed

Lines changed: 75 additions & 73 deletions

File tree

sgl-kernel/csrc/gemm/per_token_group_quant_8bit.cu

Lines changed: 75 additions & 73 deletions
Original file line number | Diff line number | Diff line change
@@ -102,13 +102,16 @@ struct NaiveScheduler {
102102
}
103103

104104
template <typename FUNC>
105-
__device__ __forceinline__ static void execute(int subwarps_per_block, FUNC fn) {
105+
__device__ __forceinline__ static void execute(int subwarps_per_block, int hidden_size_num_groups, FUNC fn) {
106106
const int local_group_id = threadIdx.x / THREADS_PER_SUBWARP;
107107
const int lane_id = threadIdx.x % THREADS_PER_SUBWARP;
108108
const int block_group_id = blockIdx.x * subwarps_per_block;
109109
const int group_id = block_group_id + local_group_id;
110110

111-
fn(group_id, lane_id);
111+
const int token_idx = group_id / hidden_size_num_groups;
112+
const int group_start_hidden_idx = group_id % hidden_size_num_groups;
113+
114+
fn(token_idx, group_start_hidden_idx, lane_id);
112115
}
113116
};
114117

@@ -125,93 +128,92 @@ __global__ void per_token_group_quant_8bit_kernel(
125128
scale_packed_t* __restrict__ output_s,
126129
const int group_size,
127130
const int subwarps_per_block,
128-
// TODO can remove?
129-
const int scale_hidden_size = 0,
130-
const int scale_hidden_stride = 0) {
131+
const int hidden_size_num_groups,
132+
const int scale_hidden_stride) {
131133
using dst_dtype_info = DtypeInfo<DST_DTYPE>;
132134
using scale_element_t = std::conditional_t<SCALE_UE8M0, uint8_t, float>;
133135
static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
134136

135-
SCHEDULER::execute(subwarps_per_block, [&](int group_id, int lane_id) {
136-
constexpr uint32_t INPUT_PRIMARY_VEC_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(T);
137-
constexpr uint32_t INPUT_PRIMARY_INT4_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(int4);
138-
139-
int4 input_primary_int4[INPUT_PRIMARY_INT4_SIZE];
140-
T* input_primary_vec = reinterpret_cast<T*>(input_primary_int4);
141-
static_assert(sizeof(input_primary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_primary_int4));
137+
SCHEDULER::execute(
138+
subwarps_per_block, hidden_size_num_groups, [&](int token_idx, int group_start_hidden_idx, int lane_id) {
139+
constexpr uint32_t INPUT_PRIMARY_VEC_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(T);
140+
constexpr uint32_t INPUT_PRIMARY_INT4_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(int4);
142141

143-
#pragma unroll
144-
for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) {
145-
input_primary_int4[j] = ld_global_nc(
146-
reinterpret_cast<const int4*>(input + group_id * group_size + lane_id * INPUT_PRIMARY_VEC_SIZE) + j);
147-
}
142+
// TODO consider stride
143+
const int group_id = token_idx * hidden_size_num_groups + group_start_hidden_idx;
148144

149-
scale_element_t* scale_output;
150-
if constexpr (IS_COLUMN_MAJOR) {
151-
constexpr int scale_token_stride = 1;
152-
constexpr int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
153-
154-
// TODO unify w/ other places?
155-
const int scale_hidden_size_unpacked = scale_hidden_size * num_elems_per_pack;
156-
const int token_idx = group_id / scale_hidden_size_unpacked;
157-
const int group_start_hidden_idx = group_id % scale_hidden_size_unpacked;
158-
159-
const int hidden_idx_packed = group_start_hidden_idx / num_elems_per_pack;
160-
const int pack_idx = group_start_hidden_idx % num_elems_per_pack;
161-
scale_output = reinterpret_cast<scale_element_t*>(output_s) +
162-
(hidden_idx_packed * scale_hidden_stride * num_elems_per_pack +
163-
token_idx * scale_token_stride * num_elems_per_pack + pack_idx);
164-
} else {
165-
static_assert(!SCALE_UE8M0);
166-
scale_output = output_s + group_id;
167-
}
145+
int4 input_primary_int4[INPUT_PRIMARY_INT4_SIZE];
146+
T* input_primary_vec = reinterpret_cast<T*>(input_primary_int4);
147+
static_assert(sizeof(input_primary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_primary_int4));
168148

169-
float local_absmax = LOCAL_ABSMAX_ABS;
149+
#pragma unroll
150+
for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) {
151+
input_primary_int4[j] = ld_global_nc(
152+
reinterpret_cast<const int4*>(input + group_id * group_size + lane_id * INPUT_PRIMARY_VEC_SIZE) + j);
153+
}
154+
155+
scale_element_t* scale_output;
156+
if constexpr (IS_COLUMN_MAJOR) {
157+
constexpr int scale_token_stride = 1;
158+
constexpr int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
159+
160+
const int hidden_idx_packed = group_start_hidden_idx / num_elems_per_pack;
161+
const int pack_idx = group_start_hidden_idx % num_elems_per_pack;
162+
scale_output = reinterpret_cast<scale_element_t*>(output_s) +
163+
(hidden_idx_packed * scale_hidden_stride * num_elems_per_pack +
164+
token_idx * scale_token_stride * num_elems_per_pack + pack_idx);
165+
} else {
166+
static_assert(!SCALE_UE8M0);
167+
scale_output = output_s + group_id;
168+
}
169+
170+
float local_absmax = LOCAL_ABSMAX_ABS;
170171

171172
#pragma unroll
172-
for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) {
173-
float val = static_cast<float>(input_primary_vec[j]);
174-
float abs_val = fabsf(val);
175-
local_absmax = fmaxf(local_absmax, abs_val);
176-
}
173+
for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) {
174+
float val = static_cast<float>(input_primary_vec[j]);
175+
float abs_val = fabsf(val);
176+
local_absmax = fmaxf(local_absmax, abs_val);
177+
}
177178

178-
local_absmax = GroupReduceMax<THREADS_PER_SUBWARP>(local_absmax, lane_id);
179+
local_absmax = GroupReduceMax<THREADS_PER_SUBWARP>(local_absmax, lane_id);
179180

180-
float y_scale, y_scale_inv;
181-
calculate_fp8_scales<SCALE_UE8M0, dst_dtype_info>(local_absmax, y_scale, y_scale_inv);
182-
float2 y_scale_repeated = {y_scale, y_scale};
181+
float y_scale, y_scale_inv;
182+
calculate_fp8_scales<SCALE_UE8M0, dst_dtype_info>(local_absmax, y_scale, y_scale_inv);
183+
float2 y_scale_repeated = {y_scale, y_scale};
183184

184-
if (lane_id == 0) {
185-
*scale_output = extract_required_scale_format<SCALE_UE8M0>(y_scale_inv);
186-
}
185+
if (lane_id == 0) {
186+
*scale_output = extract_required_scale_format<SCALE_UE8M0>(y_scale_inv);
187+
}
187188

188-
int4 output_buf;
189-
static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE * sizeof(DST_DTYPE));
189+
int4 output_buf;
190+
static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE * sizeof(DST_DTYPE));
190191

191-
if constexpr (std::is_same_v<DST_DTYPE, c10::Float8_e4m3fn>) {
192-
const auto output_buf_ptr = reinterpret_cast<__nv_fp8x2_storage_t*>(&output_buf);
193-
static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE / 2 * sizeof(__nv_fp8x2_storage_t));
194-
static_assert(INPUT_PRIMARY_VEC_SIZE % 2 == 0);
192+
if constexpr (std::is_same_v<DST_DTYPE, c10::Float8_e4m3fn>) {
193+
const auto output_buf_ptr = reinterpret_cast<__nv_fp8x2_storage_t*>(&output_buf);
194+
static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE / 2 * sizeof(__nv_fp8x2_storage_t));
195+
static_assert(INPUT_PRIMARY_VEC_SIZE % 2 == 0);
195196

196197
#pragma unroll
197-
for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; j += 2) {
198-
float2 inputx2 = {static_cast<float>(input_primary_vec[j]), static_cast<float>(input_primary_vec[j + 1])};
199-
float2 outputx2 = __fmul2_rn(inputx2, y_scale_repeated);
200-
output_buf_ptr[j / 2] = __nv_cvt_float2_to_fp8x2(outputx2, __NV_SATFINITE, __NV_E4M3);
201-
}
202-
} else {
203-
const auto output_buf_ptr = reinterpret_cast<DST_DTYPE*>(&output_buf);
198+
for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; j += 2) {
199+
float2 inputx2 = {static_cast<float>(input_primary_vec[j]), static_cast<float>(input_primary_vec[j + 1])};
200+
float2 outputx2 = __fmul2_rn(inputx2, y_scale_repeated);
201+
output_buf_ptr[j / 2] = __nv_cvt_float2_to_fp8x2(outputx2, __NV_SATFINITE, __NV_E4M3);
202+
}
203+
} else {
204+
const auto output_buf_ptr = reinterpret_cast<DST_DTYPE*>(&output_buf);
204205

205206
#pragma unroll
206-
for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) {
207-
float val = static_cast<float>(input_primary_vec[j]);
208-
float q_val = fminf(fmaxf(val * y_scale, dst_dtype_info::MIN), dst_dtype_info::MAX);
209-
output_buf_ptr[j] = DST_DTYPE(q_val);
210-
}
211-
}
212-
213-
st_global(reinterpret_cast<int4*>(output_q + group_id * group_size + lane_id * INPUT_PRIMARY_VEC_SIZE), output_buf);
214-
});
207+
for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) {
208+
float val = static_cast<float>(input_primary_vec[j]);
209+
float q_val = fminf(fmaxf(val * y_scale, dst_dtype_info::MIN), dst_dtype_info::MAX);
210+
output_buf_ptr[j] = DST_DTYPE(q_val);
211+
}
212+
}
213+
214+
st_global(
215+
reinterpret_cast<int4*>(output_q + group_id * group_size + lane_id * INPUT_PRIMARY_VEC_SIZE), output_buf);
216+
});
215217
}
216218

217219
int compute_subwarps_per_block(int num_groups) {
@@ -253,7 +255,7 @@ void sgl_per_token_group_quant_8bit(
253255
auto dst_type = output_q.scalar_type();
254256

255257
const bool is_column_major = output_s.stride(-2) < output_s.stride(-1);
256-
const int scale_hidden_size = output_s.size(-1);
258+
const int hidden_size_num_groups = output_q.size(-1) / group_size;
257259
const int scale_hidden_stride = output_s.stride(-1);
258260

259261
#define LAUNCH_KERNEL_INNER(SCHEDULER, T, DST_DTYPE, output_s_dtype, ...) \
@@ -267,7 +269,7 @@ void sgl_per_token_group_quant_8bit(
267269
static_cast<output_s_dtype*>(output_s.data_ptr()), \
268270
group_size, \
269271
subwarps_per_block, \
270-
scale_hidden_size, \
272+
hidden_size_num_groups, \
271273
scale_hidden_stride); \
272274
} while (0)
273275

0 commit comments

Comments
 (0)