ggml-cpu: FA add GEMM microkernel#19422

Merged
am17an merged 10 commits into ggml-org:master from am17an:opt-fa-micro-gemm
Feb 15, 2026

Conversation

@am17an
Contributor

@am17an am17an commented Feb 7, 2026

This PR contains the following improvements for the tiled FA kernel:

  • Add a SIMD GEMM for float32 in the tiled FA kernel.
  • Tune tile sizes for larger contexts.
  • Remove the condition that KV depth % KV_TILE_SZ == 0.

Future work would be adding an f16 version for hardware that supports it.

Results on a 64-core EPYC server; similar speed-ups with 16/32 cores.

| Model | Test | t/s master | t/s opt-fa-micro-gemm | Speedup |
| --- | --- | --- | --- | --- |
| gpt-oss 20B MXFP4 MoE | pp512 | 234.75 | 237.96 | 1.01 |
| gpt-oss 20B MXFP4 MoE | pp512@d1024 | 201.78 | 219.03 | 1.09 |
| gpt-oss 20B MXFP4 MoE | pp512@d2048 | 190.70 | 207.59 | 1.09 |
| gpt-oss 20B MXFP4 MoE | pp512@d4096 | 167.51 | 195.65 | 1.17 |
| gpt-oss 20B MXFP4 MoE | pp512@d8192 | 131.76 | 163.66 | 1.24 |
| gpt-oss 20B MXFP4 MoE | pp512@d16384 | 88.44 | 130.36 | 1.47 |
| gpt-oss 20B MXFP4 MoE | pp512@d32768 | 56.46 | 91.62 | 1.62 |
| llama 8B Q4_K_M | pp512 | 184.26 | 180.40 | 0.98 |
| llama 8B Q4_K_M | pp512@d1024 | 170.39 | 182.02 | 1.07 |
| llama 8B Q4_K_M | pp512@d2048 | 155.95 | 162.77 | 1.04 |
| llama 8B Q4_K_M | pp512@d4096 | 124.17 | 132.95 | 1.07 |
| llama 8B Q4_K_M | pp512@d8192 | 90.52 | 103.16 | 1.14 |
| llama 8B Q4_K_M | pp512@d16384 | 56.15 | 79.88 | 1.42 |
| llama 8B Q4_K_M | pp512@d32768 | 31.33 | 49.53 | 1.58 |

AI disclosure: I wrote the register-blocked microkernel for AVX2, but let AI handle the rest of the kernel.

@am17an am17an requested a review from ggerganov as a code owner February 7, 2026 19:02
@github-actions github-actions bot added the ggml changes relating to the ggml tensor library for machine learning label Feb 7, 2026
@ggerganov
Member

Hm, it looks like the "CPU reference" mechanism that we implemented in #19209 does not actually work as I thought it would. I was thinking that we can run:

test-backend-ops -b CPU

And it would compare the reference vs non-reference CPU implementation. But this is not the case, because the ggml_backend_cpu_set_use_ref() applies to both backend1 and backend2 - they are the same backends.

How do you test the non-reference implementation against the reference?

@am17an
Contributor Author

am17an commented Feb 13, 2026

That's what I use, and it correctly fails if there is a bug. For example, on master, if I do something like this:

diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index ed4535020..6223b202a 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -8247,7 +8247,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
 
             if (s > M) {
                 ms = expf(M - s);
-                M = s;
+                //M = s;
                 ggml_vec_scale_f32(DV, VKQ32, ms);
             } else {
                 vs = expf(s - M);

I get failures of the form

[FLASH_ATTN_EXT] ERR = 0.135610349 > 0.000500000 FLASH_ATTN_EXT(hsk=40,hsv=40,nh=4,nr23=[4,1],kv=512,nb=1,mask=0,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]): FAIL

when running test-backend-ops -b CPU

When I print pointers in make_test_cases_eval, I get different pointers for backend and backend_cpu

@ggerganov
Member

Ah correct - this works as expected.

The problem is that the tiled version is never exercised at the moment: use_tiled is always false, because in all tests the batch size is less than Q_TILE_SZ.

@am17an
Contributor Author

am17an commented Feb 13, 2026

Ah I see, you're right. So I can change in test-backend-ops:

-                                            for (int nb : { 1, 3, 32, 35, }) {
+                                            for (int nb : { 1, 3, 32, 65, }) {

To exercise this path, though we would have to be careful when tuning Q_TILE_SZ

Comment on lines +15 to +18
#if defined(__AVX512F__) || defined (__ARM_NEON__)
static constexpr int GEMM_RM = 6;
static constexpr int GEMM_RN = 4; // 24+4+1 = 29/32
#elif defined(__AVX2__) || defined(__AVX__)
Member

On M2 Ultra, I get better results using:

# GGML_F32_EPR = 4

GEMM_RM = 4
GEMM_RN = 4

Here is before and after:

  • 6x4

| Model | Test | t/s master | t/s pr/19422 | Speedup |
| --- | --- | --- | --- | --- |
| gpt-oss 20B MXFP4 MoE | pp512@d1024 | 166.40 | 177.12 | 1.06 |
| gpt-oss 20B MXFP4 MoE | pp512@d4096 | 134.16 | 155.12 | 1.16 |
| qwen2 1.5B F16 | pp512@d1024 | 298.68 | 307.97 | 1.03 |
| qwen2 1.5B F16 | pp512@d4096 | 226.77 | 247.86 | 1.09 |
| qwen2 3B F16 | pp512@d1024 | 143.64 | 151.70 | 1.06 |
| qwen2 3B F16 | pp512@d4096 | 112.81 | 131.39 | 1.16 |

  • 4x4

| Model | Test | t/s master | t/s pr/19422 | Speedup |
| --- | --- | --- | --- | --- |
| gpt-oss 20B MXFP4 MoE | pp512@d1024 | 167.38 | 178.00 | 1.06 |
| gpt-oss 20B MXFP4 MoE | pp512@d4096 | 134.28 | 160.09 | 1.19 |
| qwen2 1.5B F16 | pp512@d1024 | 299.04 | 313.51 | 1.05 |
| qwen2 1.5B F16 | pp512@d4096 | 226.24 | 259.23 | 1.15 |
| qwen2 3B F16 | pp512@d1024 | 144.24 | 152.88 | 1.06 |
| qwen2 3B F16 | pp512@d4096 | 113.59 | 135.72 | 1.19 |

Contributor Author

Thanks for testing. Can you try one with a very high depth, like 16384? That's where it would be really clear.

Contributor Author

Also, for me, after this PR FA=1 is always faster than FA=0, at least for PP. For TG the results are better when there are more threads; I guess our GEMV implementation can be improved.

Member

Here is a comparison of 6x4 (current) vs 4x4 (new):

| Model | Test | t/s 8debab3 | t/s dce1b0911 | Speedup |
| --- | --- | --- | --- | --- |
| gpt-oss 20B MXFP4 MoE | pp512@d1024 | 176.48 | 178.33 | 1.01 |
| gpt-oss 20B MXFP4 MoE | pp512@d4096 | 156.05 | 159.91 | 1.02 |
| gpt-oss 20B MXFP4 MoE | pp512@d16384 | 106.38 | 113.16 | 1.06 |
| qwen2 1.5B F16 | pp512@d1024 | 307.25 | 313.73 | 1.02 |
| qwen2 1.5B F16 | pp512@d4096 | 247.65 | 258.67 | 1.04 |
| qwen2 1.5B F16 | pp512@d16384 | 136.64 | 153.64 | 1.12 |
| qwen2 3B F16 | pp512@d1024 | 151.11 | 152.89 | 1.01 |
| qwen2 3B F16 | pp512@d4096 | 130.56 | 133.45 | 1.02 |
| qwen2 3B F16 | pp512@d16384 | 85.33 | 91.88 | 1.08 |

I.e. it's also better at 16k context.

@am17an am17an merged commit 684b361 into ggml-org:master Feb 15, 2026
78 checks passed
@am17an am17an deleted the opt-fa-micro-gemm branch February 15, 2026 05:39
Comment on lines +26 to +29
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Waggressive-loop-optimizations"
#endif
Member

I noticed the compile warnings in the CI. Are we confident these are false-positives and safe to ignore?

Contributor Author

Yes, it was complaining about the loop iteration overflowing when ii is 2^58.

// These are in units of GGML_F32_EPR
#if defined(__AVX512F__) || defined (__ARM_NEON__)
static constexpr int GEMM_RM = 4;
static constexpr int GEMM_RN = 4; // 16+4+1 = 21/32
Contributor

you can try

static constexpr int GEMM_RM = 6;
static constexpr int GEMM_RN = 4;   // 24 + 4 + 1 = 29 ...

That is what AMD uses.

Contributor Author

It was this before #19422 (comment), but it changed for ARM_NEON; I will create a separate branch for AVX512.

Contributor

#if defined(__AVX512F__)
    static constexpr int GEMM_RM = 4;
    static constexpr int GEMM_RN = 6; // 24+4+2 = 30/32  (+2 for pre-load)
#elif defined (__ARM_NEON__)
    static constexpr int GEMM_RM = 4;
    static constexpr int GEMM_RN = 4; // 16+4+1 = 21/32

Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + jj + r * KN);
}
for (int i = 0; i < RM; i++) {
GGML_F32_VEC p = GGML_F32_VEC_SET1(A[(ii + i) * K + kk]);
Contributor

@Djip007 Djip007 Feb 15, 2026

It is nice for x86 to do it like that.

But on ARM (NEON) I remember there is an FMA op with a lane broadcast, regC += regB * regA[i], so it is possible to load a full register for A. For FP16, with 32 NEON registers, you could have:

  • RN = 1
  • RM = 16 // => 1 vector load

For FP32 it may be 1x8 / 2x8 / 1x16.

[edit] but it may need some transpose of A for that.

#define GGML_FA_TILE_Q 32
#define GGML_FA_TILE_KV 16
#define GGML_FA_TILE_Q 64
#define GGML_FA_TILE_KV 64
Contributor

Does this not need to be adjusted with the GEMM_RM/GEMM_RN size?

But I don't know if it is related to GEMM_RM or GEMM_RN (or both?).

Contributor Author

It should be adjusted according to the ISA available; I just tuned on Zen 2 (AMD Rome), since that's the only one I have available. I think the scratch space should be close to the L1 cache size, so maybe that's one factor to tune this by.

Contributor

Yes, but I think there is something else: it is a block of blocks, so

GGML_FA_TILE_Q  = 16*GEMM_RM;  // or 16*GEMM_RN
GGML_FA_TILE_KV = 16*GEMM_RM;  // or 16*GEMM_RN

so we always use the most efficient uGEMM block size?

Note: I did not take the time to look at whether GGML_FA_TILE_Q and GGML_FA_TILE_KV are related to GEMM_RM or GEMM_RN.

And yes for the L1/L2/L3 cache sizes, but to do best we need to block on K, not only on NxM.

Contributor Author

@am17an am17an Feb 15, 2026

Right, I guess for AVX2 it works out because GEMM_RM=4 and 4*16 = 64. Let me try to tune it according to this and see the difference.

Contributor

What bench command did you use?

Contributor Author

-m llama_8b_q4_0.gguf,gpt-oss-20b.mxfp4 -fa 1 -d 0,8196,16348 -t 16,32,64 -n 0 -p 512

Contributor

Is no -ctk f32 -ctv f32 needed?

Contributor Author

No, because it converts to f32 from f16.

}
for (; jj + KN <= N; jj += KN) {
simd_gemm_ukernel<1, 1>(C, A, B, K, N, ii, jj);
}
Contributor

For some gain (or not...): tinyblas uses an encoding 0xMN to have a switch over "all" possible MxN, so we could use some <2,GEMM_RN> ...

int64_t mc, nc, mp, np;

but yes, it needs more dev / time / ...

@Djip007
Contributor

Djip007 commented Feb 16, 2026

I did some benchmarks on a Zen 5 x16 (an AI Max+ 395):

  • AOCL_BF16 uses AMD's highly optimized AOCL GEMM, full BF16.
  • FA-0_BF16 uses the standard tinyblas with KV as BF16, no FA.
  • FA-0_FP16 uses the standard tinyblas with KV as FP16, no FA.

The next columns use FA with FP16 KV; I tested several tile sizes, listed as GGML_FA_TILE_Q/GGML_FA_TILE_KV GEMM_RMxGEMM_RN:
model test AOCL_BF16 t/s FA-0_BF16 t/s FA-0_FP16 t/s 64/64 6x4 t/s 64/64 4x4 t/s 48/64 6x4 t/s 480/128 6x4 t/s 300x64 6x4 t/s 384/64 4x4 t/s
llama 8B BF16 pp512 428.27 ± 6.89 277.67 ± 0.20 268.62 ± 0.12 300.69 ± 0.20 298.92 ± 0.06 295.09 ± 0.13 293.63 ± 1.13 300.13 ± 0.28 295.73 ± 0.13
llama 8B BF16 pp512 @ d2048 348.87 ± 0.13 225.60 ± 0.24 215.64 ± 0.07 242.19 ± 0.44 224.33 ± 0.39 253.65 ± 0.55 244.22 ± 0.16
llama 8B BF16 pp512 @ d4096 289.80 ± 1.57 185.74 ± 0.33 177.57 ± 0.82 204.15 ± 0.20 206.08 ± 0.13 196.40 ± 0.32 184.88 ± 0.83 220.96 ± 0.46 208.83 ± 0.56
llama 8B BF16 pp512 @ d8192 231.30 ± 1.11 141.81 ± 0.51 133.66 ± 0.06 125.32 ± 0.21 127.95 ± 0.10 124.05 ± 0.75 136.39 ± 0.25 171.49 ± 0.14 156.20 ± 0.01
llama 8B BF16 pp512 @ d16384 161.54 ± 0.89 97.82 ± 0.40 86.70 ± 0.21 75.05 ± 0.08 78.68 ± 0.01 76.93 ± 0.47 90.11 ± 0.82 116.39 ± 0.19 105.02 ± 0.13
llama 8B BF16 pp512 @ d32768 98.66 ± 0.80 53.82 ± 0.10 48.78 ± 0.00 43.92 ± 0.11 50.61 ± 0.16 72.62 ± 0.21 62.83 ± 0.24

For now, on this Zen 5 with llama 8B, the fastest I have is:

GGML_FA_TILE_Q 300
GGML_FA_TILE_KV 64
GEMM_RM 6
GEMM_RN 4

If I can find some time, I'd like to "hack" your FA/GEMM for BF16 support...

Do you think (or @ggerganov) it is possible to store the KV cache packed, so we don't need to do a full repack in FA?

@ggerganov
Member

@Djip007 The primary constraint for FA modifications is that the code remains easy to understand and test.

Do you think (or @ggerganov) it is possible to store the KV cache packed, so we don't need to do a full repack in FA?

Not sure, haven't thought about this. My feeling is even if it is possible, it wouldn't be worth the extra complexity.

@Djip007
Contributor

Djip007 commented Feb 17, 2026

My feeling is even if it is possible, it wouldn't be worth the extra complexity.

I completely agree with that, and I saw your reservations about the testing part.

I haven't thought about it for very long, so it might not be possible:

We could extend the extra buffers to handle KV in addition to the weight cases. This would allow us to keep the simple case in the basic CPU backend and add optimized versions of FA, as was done for GEMM. We might also need to extend the tests to work in these cases.

It might not be a good idea to discuss this on this pull request; should I start a new discussion on this topic?

And it may not be possible, or very complicated 😎

liparetejas pushed a commit to liparetejas/llama.cpp that referenced this pull request Feb 23, 2026
* ggml-cpu: FA add GEMM microkernel

* add guard for sizeless vector types

* fix case where DV % GGML_F32_EPR !=0

* move memset out of the loop

* move another memset out of the loop

* use RM=4 for arm

* simd_gemm: convert everything to int

* convert everything to size_t to avoid warnings

* fixup

* add pragma for ignoring aggressive loop optimizations
bartowski1182 pushed a commit to bartowski1182/llama.cpp that referenced this pull request Mar 2, 2026
ArberSephirotheca pushed a commit to ArberSephirotheca/llama.cpp that referenced this pull request Mar 3, 2026