Skip to content

Commit b4fcd98

Browse files
jeffraReza YazdaniRezaYazdaniAminabadi
authored
Inference PP changes for neox (#1899)
Co-authored-by: Reza Yazdani <reyazda@microsoft.com> Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
1 parent 32d9797 commit b4fcd98

File tree

17 files changed

+783
-260
lines changed

17 files changed

+783
-260
lines changed

csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu

Lines changed: 244 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,103 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query,
9191
}
9292
#endif
9393
}
94+
__global__ void apply_rotary_pos_emb1(float* mixed_query,
95+
float* key_layer,
96+
unsigned rotary_dim,
97+
unsigned seq_len,
98+
unsigned seq_offset,
99+
unsigned num_heads,
100+
unsigned head_size,
101+
unsigned total_count)
102+
{
103+
cg::thread_block b = cg::this_thread_block();
104+
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
105+
106+
int id = threadIdx.x;
107+
int gid = id >> 5;
108+
int lane = id & 0x1f;
109+
110+
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
111+
unsigned offset = head_id * head_size;
112+
113+
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
114+
115+
if (head_id < total_count) {
116+
while (lane < rotary_dim) {
117+
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
118+
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
119+
float q = mixed_query[offset + lane];
120+
float k = key_layer[offset + lane];
121+
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
122+
float q_rot = (q * rotary_sign);
123+
float k_rot = (k * rotary_sign);
124+
q_rot = g.shfl_xor(q_rot, 1);
125+
k_rot = g.shfl_xor(k_rot, 1);
126+
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
127+
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
128+
129+
mixed_query[offset + lane] = q;
130+
key_layer[offset + lane] = k;
131+
132+
lane += WARP_SIZE;
133+
}
134+
}
135+
}
136+
__global__ void apply_rotary_pos_emb1(__half* mixed_query,
137+
__half* key_layer,
138+
unsigned rotary_dim,
139+
unsigned seq_len,
140+
unsigned seq_offset,
141+
unsigned num_heads,
142+
unsigned head_size,
143+
unsigned total_count)
144+
{
145+
#if __CUDA_ARCH__ >= 700
146+
cg::thread_block b = cg::this_thread_block();
147+
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
148+
149+
int id = threadIdx.x;
150+
int gid = id >> 5;
151+
int lane = id & 0x1f;
152+
153+
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
154+
unsigned offset = head_id * head_size;
155+
156+
constexpr unsigned mask[32] = {
157+
0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, 0x10 | 0x10000,
158+
0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000, 0x100 | 0x100000, 0x200 | 0x200000,
159+
0x400 | 0x400000, 0x800 | 0x800000, 0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4,
160+
0x8000 | 0x8, 0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
161+
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800, 0x1000000,
162+
0x2000000, 0x4000000, 0x8000000, 0x10000000, 0x20000000,
163+
0x40000000, 0x80000000};
164+
165+
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
166+
unsigned half_dim = rotary_dim >> 1;
167+
if (head_id < total_count) {
168+
while (lane < rotary_dim) {
169+
float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim;
170+
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
171+
float q = (float)mixed_query[offset + lane];
172+
float k = (float)key_layer[offset + lane];
173+
float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0);
174+
float q_rot = (q * rotary_sign);
175+
float k_rot = (k * rotary_sign);
176+
auto q_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], q_rot, lane + half_dim)
177+
: __shfl_sync(mask[lane], q_rot, lane - half_dim);
178+
auto k_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], k_rot, lane + half_dim)
179+
: __shfl_sync(mask[lane], k_rot, lane - half_dim);
180+
q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
181+
k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
182+
183+
mixed_query[offset + lane] = (__half)q;
184+
key_layer[offset + lane] = (__half)k;
185+
186+
lane += WARP_SIZE;
187+
}
188+
}
189+
#endif
190+
}
94191

95192
template <typename T>
96193
void launch_apply_rotary_pos_emb(T* mixed_query,
@@ -101,14 +198,19 @@ void launch_apply_rotary_pos_emb(T* mixed_query,
101198
unsigned offset,
102199
unsigned num_heads,
103200
unsigned batch,
201+
bool rotate_half,
202+
bool rotate_every_two,
104203
cudaStream_t stream)
105204
{
106205
int total_count = batch * num_heads * seq_len;
107206
dim3 block_dims(1024);
108207
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
109-
110-
apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
111-
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
208+
if (rotate_every_two)
209+
apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
210+
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
211+
else if (rotate_half)
212+
apply_rotary_pos_emb1<<<grid_dims, block_dims, 0, stream>>>(
213+
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
112214
}
113215

114216
template void launch_apply_rotary_pos_emb<float>(float*,
@@ -119,6 +221,8 @@ template void launch_apply_rotary_pos_emb<float>(float*,
119221
unsigned,
120222
unsigned,
121223
unsigned,
224+
bool,
225+
bool,
122226
cudaStream_t);
123227
template void launch_apply_rotary_pos_emb<__half>(__half*,
124228
__half*,
@@ -128,4 +232,141 @@ template void launch_apply_rotary_pos_emb<__half>(__half*,
128232
unsigned,
129233
unsigned,
130234
unsigned,
235+
bool,
236+
bool,
131237
cudaStream_t);
238+
/*
239+
__global__ void apply_rotary_pos_emb(float* mixed_query,
240+
float* key_layer,
241+
unsigned rotary_dim,
242+
unsigned seq_len,
243+
unsigned seq_offset,
244+
unsigned num_heads,
245+
unsigned head_size,
246+
unsigned total_count)
247+
{
248+
cg::thread_block b = cg::this_thread_block();
249+
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
250+
251+
int id = threadIdx.x;
252+
int gid = id >> 5;
253+
int lane = id & 0x1f;
254+
255+
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
256+
unsigned offset = head_id * head_size;
257+
258+
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
259+
260+
if (head_id < total_count) {
261+
while (lane < rotary_dim) {
262+
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
263+
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
264+
float q = mixed_query[offset + lane];
265+
float k = key_layer[offset + lane];
266+
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
267+
float q_rot = (q * rotary_sign);
268+
float k_rot = (k * rotary_sign);
269+
q_rot = g.shfl_xor(q_rot, 1);
270+
k_rot = g.shfl_xor(k_rot, 1);
271+
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
272+
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
273+
274+
mixed_query[offset + lane] = q;
275+
key_layer[offset + lane] = k;
276+
277+
lane += WARP_SIZE;
278+
}
279+
}
280+
}
281+
282+
__global__ void apply_rotary_pos_emb(__half* mixed_query,
283+
__half* key_layer,
284+
unsigned rotary_dim,
285+
unsigned seq_len,
286+
unsigned seq_offset,
287+
unsigned num_heads,
288+
unsigned head_size,
289+
unsigned total_count)
290+
{
291+
#if __CUDA_ARCH__ >= 700
292+
cg::thread_block b = cg::this_thread_block();
293+
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
294+
295+
int id = threadIdx.x;
296+
int gid = id >> 5;
297+
int lane = id & 0x1f;
298+
299+
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
300+
unsigned offset = head_id * head_size;
301+
constexpr unsigned mask[32] = {0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000,
302+
0x10 | 0x10000, 0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000,
303+
0x100 | 0x100000, 0x200 | 0x200000, 0x400 | 0x400000, 0x800 | 0x800000,
304+
0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4, 0x8000 | 0x8,
305+
0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
306+
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800,
307+
0x1000000, 0x2000000, 0x4000000, 0x8000000,
308+
0x10000000, 0x20000000, 0x40000000, 0x80000000};
309+
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
310+
311+
if (head_id < total_count) {
312+
while (lane < rotary_dim) {
313+
//float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
314+
float inv_freq = (float)((lane % (rotary_dim >> 1)) * 2) / (float)rotary_dim;
315+
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
316+
float q = (float)mixed_query[offset + lane];
317+
float k = (float)key_layer[offset + lane];
318+
float rotary_sign = (lane > 11 ? -1.0 : 1.0);
319+
float q_rot = (q * rotary_sign);
320+
float k_rot = (k * rotary_sign);
321+
auto q_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], q_rot, lane + 12) : __shfl_sync(mask[lane],
322+
q_rot, lane - 12);//g.shfl_xor(q_rot, 12); auto k_rot_tmp = lane < 12 ? __shfl_sync(mask[lane],
323+
k_rot, lane + 12) : __shfl_sync(mask[lane], k_rot, lane - 12);//g.shfl_xor(k_rot, 12); q = q *
324+
cosf(inv_freq) + q_rot_tmp * sinf(inv_freq); k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
325+
326+
mixed_query[offset + lane] = (__half)q;
327+
key_layer[offset + lane] = (__half)k;
328+
329+
lane += WARP_SIZE;
330+
}
331+
}
332+
#endif
333+
}
334+
335+
template <typename T>
336+
void launch_apply_rotary_pos_emb(T* mixed_query,
337+
T* key_layer,
338+
unsigned head_size,
339+
unsigned seq_len,
340+
unsigned rotary_dim,
341+
unsigned offset,
342+
unsigned num_heads,
343+
unsigned batch,
344+
cudaStream_t stream)
345+
{
346+
int total_count = batch * num_heads * seq_len;
347+
dim3 block_dims(1024);
348+
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
349+
350+
apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
351+
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
352+
}
353+
354+
template void launch_apply_rotary_pos_emb<float>(float*,
355+
float*,
356+
unsigned,
357+
unsigned,
358+
unsigned,
359+
unsigned,
360+
unsigned,
361+
unsigned,
362+
cudaStream_t);
363+
template void launch_apply_rotary_pos_emb<__half>(__half*,
364+
__half*,
365+
unsigned,
366+
unsigned,
367+
unsigned,
368+
unsigned,
369+
unsigned,
370+
unsigned,
371+
cudaStream_t);
372+
*/

0 commit comments

Comments
 (0)