@@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16(
6363 constexpr int frag_m = ncols == 8 ? 32 : 16 ;
6464 constexpr int frag_n = ncols == 8 ? 8 : 16 ;
6565 static_assert (D % frag_m == 0 , " If ncols == 8 then D % frag_m must be 0." );
66+ #if defined(GGML_USE_HIP)
67+ typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16 , _Float16, wmma::row_major> frag_a_K;
68+ typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16 , _Float16, wmma::col_major> frag_a_V;
69+ typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16 , _Float16, wmma::col_major> frag_b;
70+ typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16 , KQ_acc_t> frag_c_KQ;
71+ typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16 , _Float16> frag_c_VKQ;
72+ #else
6673 typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16 , half, wmma::row_major> frag_a_K;
6774 typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16 , half, wmma::col_major> frag_a_V;
6875 typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16 , half, wmma::col_major> frag_b;
6976 typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16 , KQ_acc_t> frag_c_KQ;
7077 typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16 , half> frag_c_VKQ;
78+ #endif
7179
7280 constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
7381 constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16(
126134
127135 __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
128136 half2 * VKQ2 = (half2 *) VKQ;
137+
138+ #if defined(GGML_USE_HIP)
139+ const _Float16 * K_h_f16 = reinterpret_cast <const _Float16 *>(K_h);
140+ const _Float16 * V_h_f16 = reinterpret_cast <const _Float16 *>(V_h);
141+ _Float16 * KQ_f16 = reinterpret_cast <_Float16 *>(KQ);
142+ _Float16 * VKQ_f16 = reinterpret_cast <_Float16 *>(VKQ);
143+ #else
144+ const half * K_h_f16 = K_h;
145+ const half * V_h_f16 = V_h;
146+ half * KQ_f16 = KQ;
147+ half * VKQ_f16 = VKQ;
148+ #endif
149+
129150#pragma unroll
130151 for (int j0 = 0 ; j0 < ncols; j0 += nwarps) {
131152 const int j = j0 + threadIdx .y ;
@@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16(
160181 for (int i0 = 0 ; i0 < D; i0 += 16 ) {
161182#pragma unroll
162183 for (int j0 = 0 ; j0 < ncols; j0 += frag_n) {
163- wmma::load_matrix_sync (Q_b[i0/16 ][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
184+ wmma::load_matrix_sync (Q_b[i0/16 ][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded);
164185 }
165186 }
166187
@@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16(
180201#pragma unroll
181202 for (int k_KQ_0 = 0 ; k_KQ_0 < D; k_KQ_0 += 16 ) {
182203 frag_a_K K_a;
183- wmma::load_matrix_sync (K_a, K_h + int64_t (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx .y )*stride_KV + k_KQ_0, stride_KV);
204+ wmma::load_matrix_sync (K_a, K_h_f16 + int64_t (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx .y )*stride_KV + k_KQ_0, stride_KV);
184205#pragma unroll
185206 for (int j = 0 ; j < ncols/frag_n; ++j) {
186207 wmma::mma_sync (KQ_c[j], K_a, Q_b[k_KQ_0/16 ][j], KQ_c[j]);
@@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16(
310331 const int k = k0 + (threadIdx .y % VKQ_ratio)*16 ;
311332 wmma::load_matrix_sync (
312333 KQ_b[k0/(VKQ_ratio*16 )][j0/frag_n],
313- KQ + j0*(kqar*kqs_padded) + k,
334+ KQ_f16 + j0*(kqar*kqs_padded) + k,
314335 kqar*kqs_padded);
315336 }
316337 }
@@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16(
328349 const int k = k0 + (threadIdx .y % VKQ_ratio)*16 ;
329350
330351 frag_a_V v_a;
331- wmma::load_matrix_sync (v_a, V_h + int64_t (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx .y /VKQ_ratio), stride_KV);
352+ wmma::load_matrix_sync (v_a, V_h_f16 + int64_t (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx .y /VKQ_ratio), stride_KV);
332353#pragma unroll
333354 for (int j = 0 ; j < ncols/frag_n; ++j) {
334355 wmma::mma_sync (VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16 )][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16(
344365#pragma unroll
345366 for (int j0 = 0 ; j0 < ncols; j0 += frag_n) {
346367 wmma::store_matrix_sync (
347- KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx .y /VKQ_ratio),
368+ KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx .y /VKQ_ratio),
348369 VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
349370 D_padded, wmma::mem_col_major);
350371 }
0 commit comments