Skip to content

Commit 6845f7f

Browse files
authored
Add a workaround for compilation with ROCWMMA_FATTN and gfx9 (#19461)
There is an upstream problem [1] with AMD's LLVM 22 fork and rocWMMA 2.2.0 causing compilation issues on devices without native fp16 support (CDNA devices). The specialized types aren't resolved properly:

```
/opt/rocm/include/rocwmma/internal/mfma_impl.hpp:2549:37: error: ambiguous partial specializations of 'amdgcn_mfma<__half, __half, __half, 16, 16, 16>'
 2549 |     using ARegsT = typename Impl::ARegsT;
```

Add a workaround to explicitly declare the types and cast when compiling with HIP and ROCWMMA_FATTN [2]. When this is actually fixed upstream, version guards can be added so the workaround is applied only where necessary.

Link: ROCm/rocm-libraries#4398 [1]
Link: #19269 [2]

Signed-off-by: Mario Limonciello <mario.limonciello@amd.com>
1 parent fa16e51 commit 6845f7f

1 file changed

Lines changed: 26 additions & 5 deletions

File tree

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 26 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16(
6363
constexpr int frag_m = ncols == 8 ? 32 : 16;
6464
constexpr int frag_n = ncols == 8 ? 8 : 16;
6565
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
66+
#if defined(GGML_USE_HIP)
67+
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::row_major> frag_a_K;
68+
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_a_V;
69+
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_b;
70+
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
71+
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, _Float16> frag_c_VKQ;
72+
#else
6673
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
6774
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
6875
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
6976
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
7077
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
78+
#endif
7179

7280
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
7381
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16(
126134

127135
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
128136
half2 * VKQ2 = (half2 *) VKQ;
137+
138+
#if defined(GGML_USE_HIP)
139+
const _Float16 * K_h_f16 = reinterpret_cast<const _Float16 *>(K_h);
140+
const _Float16 * V_h_f16 = reinterpret_cast<const _Float16 *>(V_h);
141+
_Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ);
142+
_Float16 * VKQ_f16 = reinterpret_cast<_Float16 *>(VKQ);
143+
#else
144+
const half * K_h_f16 = K_h;
145+
const half * V_h_f16 = V_h;
146+
half * KQ_f16 = KQ;
147+
half * VKQ_f16 = VKQ;
148+
#endif
149+
129150
#pragma unroll
130151
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
131152
const int j = j0 + threadIdx.y;
@@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16(
160181
for (int i0 = 0; i0 < D; i0 += 16) {
161182
#pragma unroll
162183
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
163-
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
184+
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded);
164185
}
165186
}
166187

@@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16(
180201
#pragma unroll
181202
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
182203
frag_a_K K_a;
183-
wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
204+
wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
184205
#pragma unroll
185206
for (int j = 0; j < ncols/frag_n; ++j) {
186207
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16(
310331
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
311332
wmma::load_matrix_sync(
312333
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
313-
KQ + j0*(kqar*kqs_padded) + k,
334+
KQ_f16 + j0*(kqar*kqs_padded) + k,
314335
kqar*kqs_padded);
315336
}
316337
}
@@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16(
328349
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
329350

330351
frag_a_V v_a;
331-
wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
352+
wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
332353
#pragma unroll
333354
for (int j = 0; j < ncols/frag_n; ++j) {
334355
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16(
344365
#pragma unroll
345366
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
346367
wmma::store_matrix_sync(
347-
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
368+
KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
348369
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
349370
D_padded, wmma::mem_col_major);
350371
}

0 commit comments

Comments (0)