ggml-cpu: FA add GEMM microkernel #19422
Conversation
Hm, it looks like the "CPU reference" mechanism that we implemented in #19209 does not actually work as I thought it would. I was thinking that we could run `test-backend-ops -b CPU` and it would compare the reference vs the non-reference CPU implementation. But this is not the case. How do you test the non-reference implementation against the reference?
That's what I use, and it correctly fails if there is some bug. For example, on master, if I do something like this:

```diff
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index ed4535020..6223b202a 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -8247,7 +8247,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
                 if (s > M) {
                     ms = expf(M - s);
-                    M = s;
+                    //M = s;
                     ggml_vec_scale_f32(DV, VKQ32, ms);
                 } else {
                     vs = expf(s - M);
```

I get failures of the form … when running … When I print pointers in …
Ah correct - this works as expected. The problem is that the tiled version is never exercised atm. |
Ah I see, you're right. So I can change … in … To exercise this path, though, I would have to be careful when tuning `Q_TILE_SZ`.
```cpp
#if defined(__AVX512F__) || defined (__ARM_NEON__)
static constexpr int GEMM_RM = 6;
static constexpr int GEMM_RN = 4; // 24+4+1 = 29/32
#elif defined(__AVX2__) || defined(__AVX__)
```
On M2 Ultra, I get better results using:

```
# GGML_F32_EPR = 4
GEMM_RM = 4
GEMM_RN = 4
```

Here is before and after:
- 6x4
| Model | Test | t/s master | t/s pr/19422 | Speedup |
|---|---|---|---|---|
| gpt-oss 20B MXFP4 MoE | pp512@d1024 | 166.40 | 177.12 | 1.06 |
| gpt-oss 20B MXFP4 MoE | pp512@d4096 | 134.16 | 155.12 | 1.16 |
| qwen2 1.5B F16 | pp512@d1024 | 298.68 | 307.97 | 1.03 |
| qwen2 1.5B F16 | pp512@d4096 | 226.77 | 247.86 | 1.09 |
| qwen2 3B F16 | pp512@d1024 | 143.64 | 151.70 | 1.06 |
| qwen2 3B F16 | pp512@d4096 | 112.81 | 131.39 | 1.16 |
- 4x4
| Model | Test | t/s master | t/s pr/19422 | Speedup |
|---|---|---|---|---|
| gpt-oss 20B MXFP4 MoE | pp512@d1024 | 167.38 | 178.00 | 1.06 |
| gpt-oss 20B MXFP4 MoE | pp512@d4096 | 134.28 | 160.09 | 1.19 |
| qwen2 1.5B F16 | pp512@d1024 | 299.04 | 313.51 | 1.05 |
| qwen2 1.5B F16 | pp512@d4096 | 226.24 | 259.23 | 1.15 |
| qwen2 3B F16 | pp512@d1024 | 144.24 | 152.88 | 1.06 |
| qwen2 3B F16 | pp512@d4096 | 113.59 | 135.72 | 1.19 |
Thanks for testing. Can you try one with a very high depth like 16384? That's where the difference would be really clear.
Also, for me, after this PR FA=1 is always faster than FA=0, at least for PP. For TG the results are better if there are more threads. I guess our GEMV implementation can be improved.
Here is a comparison of 6x4 (current) vs 4x4 (new):
| Model | Test | t/s 8debab3 | t/s dce1b0911 | Speedup |
|---|---|---|---|---|
| gpt-oss 20B MXFP4 MoE | pp512@d1024 | 176.48 | 178.33 | 1.01 |
| gpt-oss 20B MXFP4 MoE | pp512@d4096 | 156.05 | 159.91 | 1.02 |
| gpt-oss 20B MXFP4 MoE | pp512@d16384 | 106.38 | 113.16 | 1.06 |
| qwen2 1.5B F16 | pp512@d1024 | 307.25 | 313.73 | 1.02 |
| qwen2 1.5B F16 | pp512@d4096 | 247.65 | 258.67 | 1.04 |
| qwen2 1.5B F16 | pp512@d16384 | 136.64 | 153.64 | 1.12 |
| qwen2 3B F16 | pp512@d1024 | 151.11 | 152.89 | 1.01 |
| qwen2 3B F16 | pp512@d4096 | 130.56 | 133.45 | 1.02 |
| qwen2 3B F16 | pp512@d16384 | 85.33 | 91.88 | 1.08 |
I.e. it's also better for 16k context.
```cpp
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Waggressive-loop-optimizations"
#endif
```
I noticed the compile warnings in the CI. Are we confident these are false positives and safe to ignore?
Yes, it was complaining about the iteration overflowing when `ii` is 2^58.
```cpp
// These are in units of GGML_F32_EPR
#if defined(__AVX512F__) || defined (__ARM_NEON__)
static constexpr int GEMM_RM = 4;
static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32
```
You can try:

```cpp
static constexpr int GEMM_RM = 6;
static constexpr int GEMM_RN = 4; // 24 + 4 + 1 = 29
```

... that is what is used on AMD.
It was this before (see #19422 (comment)), but it was changed for ARM_NEON; I will create a separate branch for AVX512.
```cpp
#if defined(__AVX512F__)
static constexpr int GEMM_RM = 4;
static constexpr int GEMM_RN = 6; // 24+4+2 = 30/32 (+2 for pre-load)
#elif defined (__ARM_NEON__)
static constexpr int GEMM_RM = 4;
static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32
```
```cpp
    Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + jj + r * KN);
}
for (int i = 0; i < RM; i++) {
    GGML_F32_VEC p = GGML_F32_VEC_SET1(A[(ii + i) * K + kk]);
```
This is nice for x86, to do it like that. But on ARM (NEON) I remember there is an op for FMA:

`regC += regB * regA[i]` (a lane-broadcast multiply-accumulate), so it is possible to load a full register for A.

For FP16, with 32 registers on NEON, you can have:
- RN = 1
- RM = 16 // => 1 vector load

For FP32 it may be 1x8 / 2x8 / 1x16.

[edit] But it may need some transpose on A for that.
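To make the register accounting in this thread concrete, here is a hedged scalar model of an RM x RN register-blocked microkernel (plain floats stand in for the `GGML_F32_VEC` SIMD registers; this is a sketch under those assumptions, not the PR's actual `simd_gemm_ukernel`):

```cpp
#include <cassert>

// Scalar model of an RM x RN register-blocked GEMM microkernel.
// A is RM x K (row-major), B rows have stride ldb, C rows have stride ldc.
// In the real kernel each Bv[r] and acc[i][r] is a SIMD register and the
// inner update is an FMA with a broadcast element of A.
template <int RM, int RN>
void gemm_ukernel_sketch(float * C, const float * A, const float * B,
                         int K, int ldb, int ldc) {
    float acc[RM][RN] = {}; // RM*RN accumulator "registers"
    for (int kk = 0; kk < K; kk++) {
        float Bv[RN]; // RN "register" loads from one row of B
        for (int r = 0; r < RN; r++) {
            Bv[r] = B[kk * ldb + r];
        }
        for (int i = 0; i < RM; i++) {
            const float a = A[i * K + kk]; // broadcast one element of A
            for (int r = 0; r < RN; r++) {
                acc[i][r] += a * Bv[r]; // FMA in the real kernel
            }
        }
    }
    for (int i = 0; i < RM; i++) {
        for (int r = 0; r < RN; r++) {
            C[i * ldc + r] += acc[i][r];
        }
    }
}
```

The register budget is then RM*RN accumulators + RN loads of B + 1 broadcast of A, which is what comments like `24+4+1 = 29/32` count for RM=6, RN=4.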
```diff
-#define GGML_FA_TILE_Q 32
-#define GGML_FA_TILE_KV 16
+#define GGML_FA_TILE_Q 64
+#define GGML_FA_TILE_KV 64
```
Does this not need to be adjusted with the GEMM_RM/GEMM_RN size? Though I don't know if it is related to GEMM_RM or GEMM_RN (or both?).
It should be adjusted according to the ISA available; I just tuned on Zen 2 (AMD Rome) since that's the only one I have available. I think the scratch space should be close to the L1 cache size, so maybe that's one factor to tune by.
Yes, but I think something else: it is blocks of blocks, so

```cpp
GGML_FA_TILE_Q  = 16*GEMM_RM; // or 16*GEMM_RN
GGML_FA_TILE_KV = 16*GEMM_RM; // or 16*GEMM_RN
```

so we use at most the most efficient uGEMM block size?

Note: I did not take time to look at whether GGML_FA_TILE_Q and GGML_FA_TILE_KV are related to GEMM_RM or GEMM_RN.

And yes on the L1/L2/L3 cache sizes, but to do best we would need to block on K, not only on NxM.
Right, I guess for AVX2 it works out because `GEMM_RM=4` and 4*16 = 64. Let me try to tune it according to this and see the difference.
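Writing out that tiling arithmetic (illustrative AVX2 numbers only: `GGML_F32_EPR` = 8 floats per vector, `GEMM_RM` = `GEMM_RN` = 4, and assuming RM counts rows while RN counts column registers, which matches the register budget comments):

```cpp
#include <cassert>

// Illustrative "blocks of blocks" sizing: one microkernel call covers
// GEMM_RM rows by (GEMM_RN * EPR) columns, so a 64x64 FA tile decomposes
// into a whole number of microkernel invocations.
constexpr int EPR     = 8;            // floats per SIMD register (AVX2)
constexpr int GEMM_RM = 4;            // rows per microkernel call
constexpr int GEMM_RN = 4;            // column registers per call
constexpr int TILE_Q  = 16 * GEMM_RM; // = 64, matches GGML_FA_TILE_Q

// Count microkernel calls needed to cover a rows x cols tile.
int ukernel_calls(int rows, int cols) {
    const int cols_per_call = GEMM_RN * EPR; // 32 columns per call
    assert(rows % GEMM_RM == 0 && cols % cols_per_call == 0);
    return (rows / GEMM_RM) * (cols / cols_per_call);
}
```

Under these assumptions a 64x64 tile is covered by full-size microkernel calls with no edge remainder, which is presumably why the AVX2 numbers "work out".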
What bench command did you use?
`-m llama_8b_q4_0.gguf,gpt-oss-20b.mxfp4 -fa 1 -d 0,8196,16348 -t 16,32,64 -n 0 -p 512`
No `-ctk f32 -ctv f32` is needed?
No, because it converts to f32 from f16.
```cpp
}
for (; jj + KN <= N; jj += KN) {
    simd_gemm_ukernel<1, 1>(C, A, B, K, N, ii, jj);
}
```
For some gain (or not...): tinyblas uses an 0xMN encoding to have a switch over "all" possible MxN, so we could use some `<2, GEMM_RN>` ... (see `llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp`, line 1224 in 184c694). But yes, it needs more dev / time / ...
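The 0xMN trick referenced above can be sketched as follows (illustrative names only, not the actual `sgemm.cpp` code): the leftover edge-tile shape (m, n) is packed into a single byte so one switch selects the right `<RM, RN>` microkernel instantiation.

```cpp
#include <cassert>

// Stand-in for simd_gemm_ukernel<RM, RN>: here it just reports which
// template instantiation was dispatched, encoded the same 0xMN way.
template <int RM, int RN>
int run_ukernel() { return (RM << 4) | RN; }

// tinyblas-style dispatch: pack (m, n) into 0xMN and switch once over all
// instantiated edge-tile shapes.
int dispatch_edge_tile(int m, int n) {
    switch ((m << 4) | n) { // 0xMN encoding
        case 0x11: return run_ukernel<1, 1>();
        case 0x12: return run_ukernel<1, 2>();
        case 0x21: return run_ukernel<2, 1>();
        case 0x22: return run_ukernel<2, 2>();
        default:   return -1; // shape not instantiated
    }
}
```

The benefit is that all remainder shapes go through one dispatch point instead of nested if/else ladders, at the cost of instantiating (and compiling) every `<RM, RN>` combination.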
I did some benches with a Zen5 x16 (an AI MAX+ 395).

For now, on this Zen5 with llama 8B, the fastest I have is: … If I can find some time, I'd like to "hack" your FA/GEMM for BF16 support... Do you think (or @ggerganov) it is possible to store the KV cache packed, so we don't need to do a full repack in FA?
@Djip007 The primary constraint for FA modifications is for the code to remain easy to understand and test.

Not sure, I haven't thought about this. My feeling is that even if it is possible, it wouldn't be worth the extra complexity.
I completely agree with that, and I saw your reservations about the testing part. I haven't thought about it very long, so it might not be possible: we could extend the extra buffers to handle KV in addition to the weight cases. It might not be a good idea to discuss this on this pull request; should I start a new discussion on this topic? And it may not be possible, or very complicated 😎
* ggml-cpu: FA add GEMM microkernel
* add guard for sizeless vector types
* fix case where DV % GGML_F32_EPR != 0
* move memset out of the loop
* move another memset out of the loop
* use RM=4 for arm
* simd_gemm: convert everything to int
* convert everything to size_t to avoid warnings
* fixup
* add pragma for ignoring aggressive loop optimizations
This PR contains the following improvements for the tiled FA kernel:
Future work would be adding an f16 version for hardware that supports it.
Results on 64c EPYC server, similar speed-ups for 16/32 cores.
AI disclosure: I wrote the register-blocked microkernel for AVX2, but I let AI handle the rest of the kernel.