@@ -91,6 +91,103 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query,
9191 }
9292#endif
9393}
94+ __global__ void apply_rotary_pos_emb1 (float * mixed_query,
95+ float * key_layer,
96+ unsigned rotary_dim,
97+ unsigned seq_len,
98+ unsigned seq_offset,
99+ unsigned num_heads,
100+ unsigned head_size,
101+ unsigned total_count)
102+ {
103+ cg::thread_block b = cg::this_thread_block ();
104+ cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
105+
106+ int id = threadIdx .x ;
107+ int gid = id >> 5 ;
108+ int lane = id & 0x1f ;
109+
110+ unsigned head_id = blockIdx .x * MAX_WARP_NUM + gid;
111+ unsigned offset = head_id * head_size;
112+
113+ unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
114+
115+ if (head_id < total_count) {
116+ while (lane < rotary_dim) {
117+ float inv_freq = (float )((lane / 2 ) * 2 ) / (float )rotary_dim;
118+ inv_freq = 1.0 / powf (10000.0 , inv_freq) * (float )seq_id;
119+ float q = mixed_query[offset + lane];
120+ float k = key_layer[offset + lane];
121+ float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0 );
122+ float q_rot = (q * rotary_sign);
123+ float k_rot = (k * rotary_sign);
124+ q_rot = g.shfl_xor (q_rot, 1 );
125+ k_rot = g.shfl_xor (k_rot, 1 );
126+ q = q * cosf (inv_freq) + q_rot * sinf (inv_freq);
127+ k = k * cosf (inv_freq) + k_rot * sinf (inv_freq);
128+
129+ mixed_query[offset + lane] = q;
130+ key_layer[offset + lane] = k;
131+
132+ lane += WARP_SIZE;
133+ }
134+ }
135+ }
136+ __global__ void apply_rotary_pos_emb1 (__half* mixed_query,
137+ __half* key_layer,
138+ unsigned rotary_dim,
139+ unsigned seq_len,
140+ unsigned seq_offset,
141+ unsigned num_heads,
142+ unsigned head_size,
143+ unsigned total_count)
144+ {
145+ #if __CUDA_ARCH__ >= 700
146+ cg::thread_block b = cg::this_thread_block ();
147+ cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
148+
149+ int id = threadIdx .x ;
150+ int gid = id >> 5 ;
151+ int lane = id & 0x1f ;
152+
153+ unsigned head_id = blockIdx .x * MAX_WARP_NUM + gid;
154+ unsigned offset = head_id * head_size;
155+
156+ constexpr unsigned mask[32 ] = {
157+ 0x1 | 0x1000 , 0x2 | 0x2000 , 0x4 | 0x4000 , 0x8 | 0x8000 , 0x10 | 0x10000 ,
158+ 0x20 | 0x20000 , 0x40 | 0x40000 , 0x80 | 0x80000 , 0x100 | 0x100000 , 0x200 | 0x200000 ,
159+ 0x400 | 0x400000 , 0x800 | 0x800000 , 0x1000 | 0x1 , 0x2000 | 0x2 , 0x4000 | 0x4 ,
160+ 0x8000 | 0x8 , 0x10000 | 0x10 , 0x20000 | 0x20 , 0x40000 | 0x40 , 0x80000 | 0x80 ,
161+ 0x100000 | 0x100 , 0x200000 | 0x200 , 0x400000 | 0x400 , 0x800000 | 0x800 , 0x1000000 ,
162+ 0x2000000 , 0x4000000 , 0x8000000 , 0x10000000 , 0x20000000 ,
163+ 0x40000000 , 0x80000000 };
164+
165+ unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
166+ unsigned half_dim = rotary_dim >> 1 ;
167+ if (head_id < total_count) {
168+ while (lane < rotary_dim) {
169+ float inv_freq = (float )((lane % half_dim) * 2 ) / (float )rotary_dim;
170+ inv_freq = 1.0 / powf (10000.0 , inv_freq) * (float )seq_id;
171+ float q = (float )mixed_query[offset + lane];
172+ float k = (float )key_layer[offset + lane];
173+ float rotary_sign = (lane > (half_dim - 1 ) ? -1.0 : 1.0 );
174+ float q_rot = (q * rotary_sign);
175+ float k_rot = (k * rotary_sign);
176+ auto q_rot_tmp = lane < half_dim ? __shfl_sync (mask[lane], q_rot, lane + half_dim)
177+ : __shfl_sync (mask[lane], q_rot, lane - half_dim);
178+ auto k_rot_tmp = lane < half_dim ? __shfl_sync (mask[lane], k_rot, lane + half_dim)
179+ : __shfl_sync (mask[lane], k_rot, lane - half_dim);
180+ q = q * cosf (inv_freq) + q_rot_tmp * sinf (inv_freq);
181+ k = k * cosf (inv_freq) + k_rot_tmp * sinf (inv_freq);
182+
183+ mixed_query[offset + lane] = (__half)q;
184+ key_layer[offset + lane] = (__half)k;
185+
186+ lane += WARP_SIZE;
187+ }
188+ }
189+ #endif
190+ }
94191
95192template <typename T>
96193void launch_apply_rotary_pos_emb (T* mixed_query,
@@ -101,14 +198,19 @@ void launch_apply_rotary_pos_emb(T* mixed_query,
101198 unsigned offset,
102199 unsigned num_heads,
103200 unsigned batch,
201+ bool rotate_half,
202+ bool rotate_every_two,
104203 cudaStream_t stream)
105204{
106205 int total_count = batch * num_heads * seq_len;
107206 dim3 block_dims (1024 );
108207 dim3 grid_dims ((total_count - 1 ) / MAX_WARP_NUM + 1 ); // (batch_size);
109-
110- apply_rotary_pos_emb<<<grid_dims, block_dims, 0 , stream>>> (
111- mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
208+ if (rotate_every_two)
209+ apply_rotary_pos_emb<<<grid_dims, block_dims, 0 , stream>>> (
210+ mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
211+ else if (rotate_half)
212+ apply_rotary_pos_emb1<<<grid_dims, block_dims, 0 , stream>>> (
213+ mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
112214}
113215
114216template void launch_apply_rotary_pos_emb<float >(float *,
@@ -119,6 +221,8 @@ template void launch_apply_rotary_pos_emb<float>(float*,
119221 unsigned ,
120222 unsigned ,
121223 unsigned ,
224+ bool ,
225+ bool ,
122226 cudaStream_t);
123227template void launch_apply_rotary_pos_emb<__half>(__half*,
124228 __half*,
@@ -128,4 +232,141 @@ template void launch_apply_rotary_pos_emb<__half>(__half*,
128232 unsigned ,
129233 unsigned ,
130234 unsigned ,
235+ bool ,
236+ bool ,
131237 cudaStream_t);
238+ /*
239+ __global__ void apply_rotary_pos_emb(float* mixed_query,
240+ float* key_layer,
241+ unsigned rotary_dim,
242+ unsigned seq_len,
243+ unsigned seq_offset,
244+ unsigned num_heads,
245+ unsigned head_size,
246+ unsigned total_count)
247+ {
248+ cg::thread_block b = cg::this_thread_block();
249+ cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
250+
251+ int id = threadIdx.x;
252+ int gid = id >> 5;
253+ int lane = id & 0x1f;
254+
255+ unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
256+ unsigned offset = head_id * head_size;
257+
258+ unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
259+
260+ if (head_id < total_count) {
261+ while (lane < rotary_dim) {
262+ float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
263+ inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
264+ float q = mixed_query[offset + lane];
265+ float k = key_layer[offset + lane];
266+ float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
267+ float q_rot = (q * rotary_sign);
268+ float k_rot = (k * rotary_sign);
269+ q_rot = g.shfl_xor(q_rot, 1);
270+ k_rot = g.shfl_xor(k_rot, 1);
271+ q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
272+ k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
273+
274+ mixed_query[offset + lane] = q;
275+ key_layer[offset + lane] = k;
276+
277+ lane += WARP_SIZE;
278+ }
279+ }
280+ }
281+
282+ __global__ void apply_rotary_pos_emb(__half* mixed_query,
283+ __half* key_layer,
284+ unsigned rotary_dim,
285+ unsigned seq_len,
286+ unsigned seq_offset,
287+ unsigned num_heads,
288+ unsigned head_size,
289+ unsigned total_count)
290+ {
291+ #if __CUDA_ARCH__ >= 700
292+ cg::thread_block b = cg::this_thread_block();
293+ cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
294+
295+ int id = threadIdx.x;
296+ int gid = id >> 5;
297+ int lane = id & 0x1f;
298+
299+ unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
300+ unsigned offset = head_id * head_size;
301+ constexpr unsigned mask[32] = {0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000,
302+ 0x10 | 0x10000, 0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000,
303+ 0x100 | 0x100000, 0x200 | 0x200000, 0x400 | 0x400000, 0x800 | 0x800000,
304+ 0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4, 0x8000 | 0x8,
305+ 0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
306+ 0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800,
307+ 0x1000000, 0x2000000, 0x4000000, 0x8000000,
308+ 0x10000000, 0x20000000, 0x40000000, 0x80000000};
309+ unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
310+
311+ if (head_id < total_count) {
312+ while (lane < rotary_dim) {
313+ //float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
314+ float inv_freq = (float)((lane % (rotary_dim >> 1)) * 2) / (float)rotary_dim;
315+ inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
316+ float q = (float)mixed_query[offset + lane];
317+ float k = (float)key_layer[offset + lane];
318+ float rotary_sign = (lane > 11 ? -1.0 : 1.0);
319+ float q_rot = (q * rotary_sign);
320+ float k_rot = (k * rotary_sign);
321+ auto q_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], q_rot, lane + 12) : __shfl_sync(mask[lane],
322+ q_rot, lane - 12);//g.shfl_xor(q_rot, 12); auto k_rot_tmp = lane < 12 ? __shfl_sync(mask[lane],
323+ k_rot, lane + 12) : __shfl_sync(mask[lane], k_rot, lane - 12);//g.shfl_xor(k_rot, 12); q = q *
324+ cosf(inv_freq) + q_rot_tmp * sinf(inv_freq); k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
325+
326+ mixed_query[offset + lane] = (__half)q;
327+ key_layer[offset + lane] = (__half)k;
328+
329+ lane += WARP_SIZE;
330+ }
331+ }
332+ #endif
333+ }
334+
335+ template <typename T>
336+ void launch_apply_rotary_pos_emb(T* mixed_query,
337+ T* key_layer,
338+ unsigned head_size,
339+ unsigned seq_len,
340+ unsigned rotary_dim,
341+ unsigned offset,
342+ unsigned num_heads,
343+ unsigned batch,
344+ cudaStream_t stream)
345+ {
346+ int total_count = batch * num_heads * seq_len;
347+ dim3 block_dims(1024);
348+ dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
349+
350+ apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
351+ mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
352+ }
353+
354+ template void launch_apply_rotary_pos_emb<float>(float*,
355+ float*,
356+ unsigned,
357+ unsigned,
358+ unsigned,
359+ unsigned,
360+ unsigned,
361+ unsigned,
362+ cudaStream_t);
363+ template void launch_apply_rotary_pos_emb<__half>(__half*,
364+ __half*,
365+ unsigned,
366+ unsigned,
367+ unsigned,
368+ unsigned,
369+ unsigned,
370+ unsigned,
371+ cudaStream_t);
372+ */
0 commit comments