@@ -343,16 +343,16 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
343343// N_DST, so this is another explicit assumption of the implementation.
344344template <typename block_q_type, int nr, int nsg, int nw>
345345void mul_vec_q_n_f32 (device const void * src0, device const float * src1, device float * dst,
346- int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, uint gqa,
346+ int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
347347 uint3 tgpig, uint tiisg, uint sgitg) {
348348 const int nb = ne00/QK4_0;
349349 const int r0 = tgpig.x ;
350350 const int r1 = tgpig.y ;
351351 const int im = tgpig.z ;
352352 const int first_row = (r0 * nsg + sgitg) * nr;
353- const uint offset0 = first_row * nb + im/gqa*(ne02/QK4_0 );
353+ const uint offset0 = first_row * nb + im/gqa*(nb*ne0 );
354354 device const block_q_type * x = (device const block_q_type *) src0 + offset0;
355- device const float * y = (device const float *) src1 + r1*ne10 + im*ne12 ;
355+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1 ;
356356 float yl[16 ]; // src1 vector cache
357357 float sumf[nr]={0 .f };
358358
@@ -383,7 +383,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
383383 for (int row = 0 ; row < nr; ++row) {
384384 const float tot = simd_sum (sumf[row]);
385385 if (tiisg == 0 && first_row + row < ne01) {
386- dst[r1*ne0 + im*ne12 + first_row + row] = tot;
386+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
387387 }
388388 }
389389}
@@ -398,11 +398,12 @@ kernel void kernel_mul_mat_q4_0_f32(
398398 constant int64_t & ne10[[buffer(9 )]],
399399 constant int64_t & ne12[[buffer(11 )]],
400400 constant int64_t & ne0[[buffer(15 )]],
401+ constant int64_t & ne1[[buffer(16 )]],
401402 constant uint & gqa[[buffer(17 )]],
402403 uint3 tgpig[[threadgroup_position_in_grid]],
403404 uint tiisg[[thread_index_in_simdgroup]],
404405 uint sgitg[[simdgroup_index_in_threadgroup]]) {
405- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg);
406+ mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1, gqa,tgpig,tiisg,sgitg);
406407}
407408
408409kernel void kernel_mul_mat_q4_1_f32 (
@@ -415,11 +416,12 @@ kernel void kernel_mul_mat_q4_1_f32(
415416 constant int64_t & ne10[[buffer(9 )]],
416417 constant int64_t & ne12[[buffer(11 )]],
417418 constant int64_t & ne0[[buffer(15 )]],
419+ constant int64_t & ne1[[buffer(16 )]],
418420 constant uint & gqa[[buffer(17 )]],
419421 uint3 tgpig[[threadgroup_position_in_grid]],
420422 uint tiisg[[thread_index_in_simdgroup]],
421423 uint sgitg[[simdgroup_index_in_threadgroup]]) {
422- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg);
424+ mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1, gqa,tgpig,tiisg,sgitg);
423425}
424426
425427kernel void kernel_mul_mat_f16_f32 (
@@ -800,6 +802,7 @@ kernel void kernel_mul_mat_q2_K_f32(
800802 constant int64_t & ne10[[buffer(9 )]],
801803 constant int64_t & ne12[[buffer(11 )]],
802804 constant int64_t & ne0[[buffer(15 )]],
805+ constant int64_t & ne1[[buffer(16 )]],
803806 constant uint & gqa[[buffer(17 )]],
804807 uint3 tgpig[[threadgroup_position_in_grid]],
805808 uint tiisg[[thread_index_in_simdgroup]],
@@ -812,9 +815,9 @@ kernel void kernel_mul_mat_q2_K_f32(
812815
813816 const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
814817 const int ib_row = first_row * nb;
815- const uint offset0 = r2/gqa*(ne02/QK_K );
818+ const uint offset0 = r2/gqa*(nb*ne0 );
816819 device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
817- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12 ;
820+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1 ;
818821 float yl[32 ];
819822 float sumf[N_DST]={0 .f }, all_sum;
820823
@@ -927,7 +930,7 @@ kernel void kernel_mul_mat_q2_K_f32(
927930 for (int row = 0 ; row < N_DST; ++row) {
928931 all_sum = simd_sum (sumf[row]);
929932 if (tiisg == 0 ) {
930- dst[r1*ne0 + r2*ne12 + first_row + row] = all_sum;
933+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
931934 }
932935 }
933936}
@@ -943,6 +946,7 @@ kernel void kernel_mul_mat_q3_K_f32(
943946 constant int64_t & ne10[[buffer(9 )]],
944947 constant int64_t & ne12[[buffer(11 )]],
945948 constant int64_t & ne0[[buffer(15 )]],
949+ constant int64_t & ne1[[buffer(16 )]],
946950 constant uint & gqa[[buffer(17 )]],
947951 uint3 tgpig[[threadgroup_position_in_grid]],
948952 uint tiisg[[thread_index_in_simdgroup]],
@@ -955,9 +959,9 @@ kernel void kernel_mul_mat_q3_K_f32(
955959 const int64_t r2 = tgpig.z ;
956960
957961 const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2 ;
958- const uint offset0 = r2/gqa*(ne02/QK_K );
962+ const uint offset0 = r2/gqa*(nb*ne0 );
959963 device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
960- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02 ;
964+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1 ;
961965
962966 float yl[16 ];
963967
@@ -1045,7 +1049,7 @@ kernel void kernel_mul_mat_q3_K_f32(
10451049 const float sumf = (sumf1[row] - 32 .f *sumf2[row]) / (1 << shift);
10461050 const float tot = simd_sum (sumf);
10471051 if (tiisg == 0 ) {
1048- dst[r1*ne0 + r2*ne12 + first_row + row] = tot;
1052+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
10491053 }
10501054 }
10511055}
@@ -1060,6 +1064,7 @@ kernel void kernel_mul_mat_q3_K_f32(
10601064 constant int64_t & ne10[[buffer(9 )]],
10611065 constant int64_t & ne12[[buffer(11 )]],
10621066 constant int64_t & ne0[[buffer(15 )]],
1067+ constant int64_t & ne1[[buffer(16 )]],
10631068 constant uint & gqa[[buffer(17 )]],
10641069 uint3 tgpig[[threadgroup_position_in_grid]],
10651070 uint tiisg[[thread_index_in_simdgroup]],
@@ -1072,9 +1077,9 @@ kernel void kernel_mul_mat_q3_K_f32(
10721077 const int64_t r2 = tgpig.z ;
10731078
10741079 const int row = 2 * r0 + sgitg;
1075- const uint offset0 = r2/gqa*(ne02/QK_K );
1080+ const uint offset0 = r2/gqa*(nb*ne0 );
10761081 device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
1077- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02 ;
1082+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1 ;
10781083 const int ix = tiisg/4 ;
10791084 const int il = 4 * (tiisg%4 );// 0, 4, 8, 12
10801085 const int im = il/8 ; // 0, 0, 1, 1
@@ -1113,7 +1118,7 @@ kernel void kernel_mul_mat_q3_K_f32(
11131118
11141119 const float tot = simd_sum (sumf);
11151120 if (tiisg == 0 ) {
1116- dst[r1*ne0 + r2*ne12 + row] = tot;
1121+ dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
11171122 }
11181123
11191124}
@@ -1130,6 +1135,7 @@ kernel void kernel_mul_mat_q4_K_f32(
11301135 constant int64_t & ne10[[buffer(9 )]],
11311136 constant int64_t & ne12[[buffer(11 )]],
11321137 constant int64_t & ne0[[buffer(15 )]],
1138+ constant int64_t & ne1[[buffer(16 )]],
11331139 constant uint & gqa[[buffer(17 )]],
11341140 uint3 tgpig[[threadgroup_position_in_grid]],
11351141 uint tiisg[[thread_index_in_simdgroup]],
@@ -1150,9 +1156,9 @@ kernel void kernel_mul_mat_q4_K_f32(
11501156 const int r2 = tgpig.z ;
11511157 const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
11521158 const int ib_row = first_row * nb;
1153- const uint offset0 = r2/gqa*(ne02/QK_K );
1159+ const uint offset0 = r2/gqa*(nb*ne0 );
11541160 device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
1155- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12 ;
1161+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1 ;
11561162 float yl[16 ];
11571163 float yh[16 ];
11581164 float sumf[N_DST]={0 .f }, all_sum;
@@ -1219,7 +1225,7 @@ kernel void kernel_mul_mat_q4_K_f32(
12191225 for (int row = 0 ; row < N_DST; ++row) {
12201226 all_sum = simd_sum (sumf[row]);
12211227 if (tiisg == 0 ) {
1222- dst[r1*ne0 + r2*ne12 + first_row + row] = all_sum;
1228+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
12231229 }
12241230 }
12251231}
@@ -1234,6 +1240,7 @@ kernel void kernel_mul_mat_q4_K_f32(
12341240 constant int64_t & ne10[[buffer(9 )]],
12351241 constant int64_t & ne12[[buffer(11 )]],
12361242 constant int64_t & ne0[[buffer(15 )]],
1243+ constant int64_t & ne1[[buffer(16 )]],
12371244 constant uint & gqa[[buffer(17 )]],
12381245 uint3 tgpig[[threadgroup_position_in_grid]],
12391246 uint tiisg[[thread_index_in_simdgroup]],
@@ -1248,9 +1255,9 @@ kernel void kernel_mul_mat_q4_K_f32(
12481255 const int r2 = tgpig.z ;
12491256 const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
12501257 const int ib_row = first_row * nb;
1251- const uint offset0 = r2/gqa*(ne02/QK_K );
1258+ const uint offset0 = r2/gqa*(nb*ne0 );
12521259 device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
1253- device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12 ;
1260+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1 ;
12541261 float yl[8 ];
12551262 float yh[8 ];
12561263 float sumf[N_DST]={0 .f }, all_sum;
@@ -1306,7 +1313,7 @@ kernel void kernel_mul_mat_q4_K_f32(
13061313 for (int row = 0 ; row < N_DST; ++row) {
13071314 all_sum = simd_sum (sumf[row]);
13081315 if (tiisg == 0 ) {
1309- dst[r1*ne0+ r2*ne12 + first_row + row] = all_sum;
1316+ dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
13101317 }
13111318 }
13121319}
@@ -1322,6 +1329,7 @@ kernel void kernel_mul_mat_q5_K_f32(
13221329 constant int64_t & ne10[[buffer(9 )]],
13231330 constant int64_t & ne12[[buffer(11 )]],
13241331 constant int64_t & ne0[[buffer(15 )]],
1332+ constant int64_t & ne1[[buffer(16 )]],
13251333 constant uint & gqa[[buffer(17 )]],
13261334 uint3 tgpig[[threadgroup_position_in_grid]],
13271335 uint tiisg[[thread_index_in_simdgroup]],
@@ -1334,9 +1342,9 @@ kernel void kernel_mul_mat_q5_K_f32(
13341342 const int r2 = tgpig.z ;
13351343
13361344 const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2 ;
1337- const uint offset0 = r2/gqa*(ne02/QK_K );
1345+ const uint offset0 = r2/gqa*(nb*ne0 );
13381346 device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
1339- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12 ;
1347+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1 ;
13401348
13411349 float sumf[2 ]={0 .f };
13421350
@@ -1470,7 +1478,7 @@ kernel void kernel_mul_mat_q5_K_f32(
14701478 for (int row = 0 ; row < 2 ; ++row) {
14711479 const float tot = simd_sum (sumf[row]);
14721480 if (tiisg == 0 ) {
1473- dst[r1*ne0 + r2*ne12 + first_row + row] = tot;
1481+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
14741482 }
14751483 }
14761484
@@ -1486,6 +1494,7 @@ kernel void kernel_mul_mat_q6_K_f32(
14861494 constant int64_t & ne10[[buffer(9 )]],
14871495 constant int64_t & ne12[[buffer(11 )]],
14881496 constant int64_t & ne0[[buffer(15 )]],
1497+ constant int64_t & ne1[[buffer(16 )]],
14891498 constant uint & gqa[[buffer(17 )]],
14901499 uint3 tgpig[[threadgroup_position_in_grid]],
14911500 uint tiisg[[thread_index_in_simdgroup]],
@@ -1503,9 +1512,9 @@ kernel void kernel_mul_mat_q6_K_f32(
15031512 const int r2 = tgpig.z ;
15041513
15051514 const int row = 2 * r0 + sgitg;
1506- const uint offset0 = r2/gqa*(ne02/QK_K );
1515+ const uint offset0 = r2/gqa*(nb*ne0 );
15071516 device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
1508- device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12 ;
1517+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1 ;
15091518
15101519 float sumf = 0 ;
15111520
@@ -1571,7 +1580,7 @@ kernel void kernel_mul_mat_q6_K_f32(
15711580
15721581 const float tot = simd_sum (sumf);
15731582 if (tiisg == 0 ) {
1574- dst[r1*ne0 + r2*ne12 + row] = tot;
1583+ dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
15751584 }
15761585}
15771586
@@ -1835,7 +1844,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
18351844 uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
18361845 device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
18371846 device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
1838- + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne12 ;
1847+ + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1 ;
18391848
18401849 for (int loop_k = 0 ; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
18411850 // load data and store to threadgroup memory
@@ -1880,7 +1889,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
18801889
18811890 if ((r0 + 1 ) * BLOCK_SIZE_M <= ne0 && (r1 + 1 ) * BLOCK_SIZE_N <= ne1) {
18821891 device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1 ) \
1883- + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1 )) * ne0 + im*ne12 ;
1892+ + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1 )) * ne0 + im*ne1*ne0 ;
18841893 for (int i = 0 ; i < 8 ; i++) {
18851894 simdgroup_store (c_res[i], C + 8 * (i%4 ) + 8 * ne0 * (i/4 ), ne0);
18861895 }
@@ -1893,7 +1902,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
18931902 }
18941903
18951904 threadgroup_barrier (mem_flags::mem_threadgroup);
1896- device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne12 ;
1905+ device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0 ;
18971906 if (sgitg==0 ) {
18981907 for (int i = 0 ; i < n_rows; i++) {
18991908 for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
0 commit comments