Skip to content

Commit c3a1128

Browse files
Optimize Metal Tensor API usage for matmul2d
Separates the Metal Tensor API (matmul2d) path in kernel_mul_mm into its own standalone kernel, gated by GGML_METAL_HAS_TENSOR. The legacy simdgroup_matrix kernel is preserved under #else. Previously both paths were interleaved via #ifdef blocks within a single kernel, forcing the tensor path to share the legacy kernel's data layout and threadgroup memory scheme. Splitting the kernel enabled memory and dispatch optimizations that weren't possible when the two paths shared code structure.
1 parent 9f102a1 commit c3a1128

6 files changed

Lines changed: 175 additions & 109 deletions

File tree

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 16 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -668,15 +668,21 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_
668668
return res;
669669
}
670670

671-
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
671+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op, bool has_tensor) {
672672
char base[256];
673673
char name[256];
674674

675675
const ggml_type tsrc0 = op->src[0]->type;
676676
const ggml_type tsrc1 = op->src[1]->type;
677677

678678
const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
679-
const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0;
679+
680+
constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y;
681+
constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X;
682+
683+
const bool bc_out = has_tensor
684+
? (op->ne[0] % NRA != 0 || op->ne[1] % NRB != 0)
685+
: (op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0);
680686

681687
snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
682688
snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
@@ -693,8 +699,14 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta
693699
ggml_metal_cv_free(cv);
694700
}
695701

696-
// when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
697-
res.smem = bc_out ? 8192 : 4096 + 2048;
702+
if (has_tensor) {
703+
constexpr size_t NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y; // 64
704+
const size_t smem_a = NRA * N_MM_NK_TOTAL * sizeof(ggml_fp16_t);
705+
res.smem = smem_a;
706+
} else {
707+
res.smem = bc_out ? 8192 : (4096 + 2048);
708+
}
709+
698710

699711
return res;
700712
}

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -128,7 +128,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv
128128
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op);
129129
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
130130
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
131-
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
131+
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op, bool has_tensor);
132132
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
133133
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
134134
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 4 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -690,7 +690,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) {
690690
" auto tB = B.slice((int)tgid.x, 0); \n"
691691
" \n"
692692
" matmul2d< \n"
693-
" matmul2d_descriptor(8, 8, dynamic_extent), \n"
693+
" matmul2d_descriptor(16, 16, dynamic_extent), \n"
694694
" execution_simdgroups<4>> mm; \n"
695695
" \n"
696696
" auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
@@ -699,7 +699,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) {
699699
" auto sB = tB.slice(0, 0); \n"
700700
" mm.run(sB, sA, cT); \n"
701701
" \n"
702-
" auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
702+
" auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(16, 16)); \n"
703703
" \n"
704704
" cT.store(tC); \n"
705705
"}";
@@ -740,7 +740,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) {
740740
" auto tB = B.slice((int)tgid.x, 0); \n"
741741
" \n"
742742
" matmul2d< \n"
743-
" matmul2d_descriptor(8, 8, dynamic_extent), \n"
743+
" matmul2d_descriptor(16, 16, dynamic_extent), \n"
744744
" execution_simdgroups<4>> mm; \n"
745745
" \n"
746746
" auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
@@ -749,7 +749,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) {
749749
" auto sB = tB.slice(0, 0); \n"
750750
" mm.run(sB, sA, cT); \n"
751751
" \n"
752-
" auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
752+
" auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(16, 16)); \n"
753753
" \n"
754754
" cT.store(tC); \n"
755755
"}";

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 11 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -8,6 +8,17 @@
88
//
99
// TODO: for optimal performance, become function of the device and work size
1010

11+
12+
#define SZ_SIMDGROUP 16
13+
#define N_MM_NK 2
14+
#define N_MM_NK_TOTAL (SZ_SIMDGROUP * N_MM_NK)
15+
16+
#define N_MM_BLOCK_X 4
17+
#define N_MM_BLOCK_Y 2
18+
#define N_MM_SIMD_GROUP_X 2
19+
#define N_MM_SIMD_GROUP_Y 2
20+
#define N_THREADS_PER_SIMDGROUP 32
21+
1122
#define N_R0_Q4_0 4
1223
#define N_SG_Q4_0 2
1324

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 13 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -2155,7 +2155,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
21552155
// default: break;
21562156
//}
21572157

2158-
auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
2158+
auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op, props_dev->has_tensor);
21592159

21602160
ggml_metal_kargs_mul_mm args = {
21612161
/*.ne00 =*/ ne00,
@@ -2183,7 +2183,18 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
21832183
const size_t smem = pipeline.smem;
21842184

21852185
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2186-
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
2186+
2187+
if (props_dev->has_tensor) {
2188+
ggml_metal_encoder_dispatch_threadgroups(enc,
2189+
(ne11 + (SZ_SIMDGROUP * N_MM_SIMD_GROUP_X * N_MM_BLOCK_X) - 1) /
2190+
(SZ_SIMDGROUP * N_MM_SIMD_GROUP_X * N_MM_BLOCK_X),
2191+
(ne01 + (SZ_SIMDGROUP * N_MM_SIMD_GROUP_Y * N_MM_BLOCK_Y) - 1) /
2192+
(SZ_SIMDGROUP * N_MM_SIMD_GROUP_Y * N_MM_BLOCK_Y),
2193+
ne12 * ne13, N_THREADS_PER_SIMDGROUP * N_MM_SIMD_GROUP_X, N_MM_SIMD_GROUP_Y, 1);
2194+
} else {
2195+
ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31) / 32), ((ne01 + 63) / 64), ne12 * ne13, 128, 1, 1);
2196+
}
2197+
21872198
} else {
21882199
auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
21892200

0 commit comments

Comments (0)