
ggml-cpu: FA split across kv for faster TG#19209

Merged
am17an merged 3 commits into ggml-org:master from am17an:opt-fa-decode
Feb 2, 2026

Conversation

@am17an
Copy link
Contributor

@am17an am17an commented Jan 30, 2026

Continuing on #19012: in the FA CPU implementation we don't parallelize across the context size, so each thread reads the entire sequence to calculate its values. This PR introduces chunking across the context size, where each thread maintains partial accumulators for the sum, maximum, and running VKQ sum of the soft-max calculation, and computes these for all query heads. An extra reduction step at the end combines the partials. The original idea is from https://pytorch.org/blog/flash-decoding/, but @JohannesGaessler pointed out that the CUDA FA kernel already does this.

Tested on three models and the results are good, with larger (>2x) speed-ups as the context size grows. Note that at lower contexts the results are a bit noisy. Also note that going from 32 to 64 cores on master makes no difference in some cases, because we only parallelize over the query heads, which can be < n_threads.

benchmark_comparison_log

@am17an am17an requested a review from ggerganov as a code owner January 30, 2026 15:35
@github-actions github-actions bot added the ggml changes relating to the ggml tensor library for machine learning label Jan 30, 2026
@ggerganov
Copy link
Member

I am a bit worried that the reference FA implementation will get increasingly complex, and it's important to keep it simple since this is one of the most frequently evolving ops. If we make an error during these optimizations, we currently don't have a way to detect it.

Think we should implement a way to validate that the CPU computation is correct first - for example running against the basic vec-kernel implementation. Have a few different ideas, but still not sure which one is the best.

@am17an
Copy link
Contributor Author

am17an commented Jan 30, 2026

I ran this through the CUDA tests to catch bugs, since the CUDA implementation can be considered correct in this case.

@am17an am17an changed the title ggml-cpu: split across kv for faster TG ggml-cpu: FA split across kv for faster TG Jan 31, 2026
@am17an
Copy link
Contributor Author

am17an commented Feb 2, 2026

@ggerganov if the CI passes, does that not mean the reference can be updated to the new implementation? Basically we are saying all backends agree on this new reference (which matches the old reference). Or do you want to test things that are not covered in the CI?

@JohannesGaessler
Copy link
Contributor

Long-term I think it would be useful to add something like test-flash-attn.cpp that has a bare-bones single-threaded FA implementation that just loops over tensor->data in CPU user code and compares the result against the ggml CPU backend. I agree that in the short term, if test-backend-ops is passing, that would indicate that the implementation in this PR is correct. But historically speaking FA has been the most error-prone ggml op, and we need to be able to determine whether the CPU implementation or the other implementation is correct if and when issues arise.

@am17an
Copy link
Contributor Author

am17an commented Feb 2, 2026

I can add that in this PR if it helps. What I think would be the most useful is comparing the non flash attn output to the flash attn output on CPU. That way we are free to add optimizations to the flash attention path as long as it matches the non flash attn path.

@JohannesGaessler
Copy link
Contributor

There are some features though, like KV cache quantization, that as of right now strictly require FA, but if that is the only thing we're missing I think it would be fine. Can you check whether there is anything else that we could not cover like that?

@ggerganov
Copy link
Member

What I think would be the most useful is comparing the non flash attn output to the flash attn output on CPU.

I think that implementing the non-FA path will become very complicated.

The FA reference should be the FA vec implementation as it is very simple and all changes to the FA operator in the future should start with modifying that implementation. We just have to figure out some refactoring that would allow to execute this reference code in the test program - I am not sure what is the best way to do it.

@am17an
Copy link
Contributor Author

am17an commented Feb 2, 2026

What about an env variable which forces the vec kernel and then a logprobs test between this and the normal path?

EDIT: I think log probs test is not comprehensive enough, ideally test-backend-ops should be used

@ggerganov
Copy link
Member

With test-backend-ops we can already run CPU vs CPU comparison like this:

test-backend-ops -o FLASH_ATTN_EXT -b CPU

It runs two CPU backends against each other. We just have to figure out a way to make one of the runs use the vec kernel.

For example, we can:

  • optionally, add an op param to ggml_flash_attn_ext() to signal this
  • alternatively, add a new ggml_tensor_flag

Either of those would work, though neither seems perfect.

@am17an
Copy link
Contributor Author

am17an commented Feb 2, 2026

The op_params approach I think is non-ideal but works. Another idea I have is adding a reference flag to the cpu-backend, which can be used to dispatch either the reference implementation or the optimized ones. Does that seem like a better idea? We could reuse that flag for other operations we modify as well.

EDIT: I don't think that's possible at the moment since we only get ggml_compute_params while computing

@ggerganov
Copy link
Member

Yes, a flag to the cpu backend seems good. Though it would require a bit of plumbing. Roughly, something like this:

  • Register a new CPU-specific function similar to ggml_backend_cpu_set_n_threads for setting the reference flag
  • In ggml_backend_cpu_graph_compute() extend the ggml_graph_plan() to accept this new flag and keep it in struct ggml_cplan
  • Extend struct ggml_compute_params with a flag and populate it in ggml_graph_compute_thread
  • Check the flag in ggml_compute_forward_flash_attn_ext() to decide which kernel to run

In test-backend-ops we will then look for the new backend function and call it for the backend_cpu to make it run the reference path:

ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
if (backend_cpu == NULL) {
    test_operation_info info("", "", "CPU");
    info.set_error("backend", "Failed to initialize CPU backend");
    output_printer->print_operation(info);
    return false;
}

@am17an
Copy link
Contributor Author

am17an commented Feb 2, 2026

I added something like this in 7adc22d, let me know if it looks ok and I will push it to this branch.

@ggerganov
Copy link
Member

Looks good - added a few comments to the commit

@github-actions github-actions bot added the testing Everything test related label Feb 2, 2026
@am17an am17an force-pushed the opt-fa-decode branch 2 times, most recently from efe83e1 to 88c5fa6 Compare February 2, 2026 11:48
@ggerganov
Copy link
Member

We just need to update the cpu-high-perf workflows to explicitly run test-backend-ops -b CPU in order to exercise this. I can do it in a follow-up PR if you are unsure how to modify the CI (need to define a new env var GG_BUILD_HIGH_PERF=1 and use it in ci/run.sh).

@am17an am17an merged commit 9f682fb into ggml-org:master Feb 2, 2026
68 of 75 checks passed
@am17an am17an deleted the opt-fa-decode branch February 2, 2026 17:20
@Djip007
Copy link
Contributor

Djip007 commented Feb 2, 2026

What is the perf of this op vs no FA with bf16 KV on Zen 4/5?
If I read it right, it uses FP32 compute, so 1/2 of what we can have with the KV in BF16?

Also, for a long time now with the BF16/llamafile matmul kernel, FA has been a lot slower than without.

On an AI MAX+ 395 (i.e. 16-core Zen 5) I get:

| model | size | params | backend | threads | type_k | type_v | test | t/s |
| ----- | ---- | ------ | ------- | ------- | ------ | ------ | ---- | --- |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | bf16 | bf16 | pp512 | 289.71 ± 0.08 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | bf16 | bf16 | tg16 | 7.67 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | bf16 | bf16 | pp512 @ d512 | 265.57 ± 0.67 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | bf16 | bf16 | tg16 @ d512 | 7.65 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | bf16 | bf16 | pp512 @ d1024 | 253.52 ± 0.27 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | bf16 | bf16 | tg16 @ d1024 | 7.56 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | bf16 | bf16 | pp512 @ d2048 | 229.87 ± 0.17 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | bf16 | bf16 | tg16 @ d2048 | 7.43 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | bf16 | bf16 | pp512 @ d4096 | 192.56 ± 0.37 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | bf16 | bf16 | tg16 @ d4096 | 7.13 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | bf16 | bf16 | pp512 @ d8192 | 145.27 ± 0.23 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | bf16 | bf16 | tg16 @ d8192 | 6.67 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | bf16 | bf16 | pp512 @ d16384 | 98.36 ± 0.44 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | bf16 | bf16 | tg16 @ d16384 | 5.87 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | bf16 | bf16 | pp512 @ d32768 | 53.98 ± 0.02 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | bf16 | bf16 | tg16 @ d32768 | 4.70 ± 0.00 |
| model | size | params | backend | threads | fa | test | t/s |
| ----- | ---- | ------ | ------- | ------- | -- | ---- | --- |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | pp512 | 284.90 ± 0.30 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | tg16 | 7.77 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | pp512 @ d512 | 254.18 ± 0.79 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | tg16 @ d512 | 7.64 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | pp512 @ d1024 | 203.29 ± 0.23 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | tg16 @ d1024 | 7.57 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | pp512 @ d2048 | 162.42 ± 0.14 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | tg16 @ d2048 | 7.41 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | pp512 @ d4096 | 115.75 ± 0.17 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | tg16 @ d4096 | 7.01 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | pp512 @ d8192 | 82.85 ± 0.08 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | tg16 @ d8192 | 6.32 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | pp512 @ d16384 | 46.11 ± 0.02 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | tg16 @ d16384 | 5.25 ± 0.03 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | pp512 @ d32768 | 24.55 ± 0.07 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | tg16 @ d32768 | 2.60 ± 0.01 |

So with BF16 KV, FA is slower than no FA.

And it can be better (possibly with ZenDNN, but I did not find time to test with it). This is what I get with an AOCL backend I created for a POC:

| model | size | params | backend | threads | type_k | type_v | test | t/s |
| ----- | ---- | ------ | ------- | ------- | ------ | ------ | ---- | --- |
| llama 8B BF16 | 14.96 GiB | 8.03 B | AOCL | 16 | bf16 | bf16 | pp512 | 435.10 ± 0.19 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | AOCL | 16 | bf16 | bf16 | tg16 | 7.65 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | AOCL | 16 | bf16 | bf16 | pp512 @ d512 | 412.74 ± 0.05 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | AOCL | 16 | bf16 | bf16 | tg16 @ d512 | 7.58 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | AOCL | 16 | bf16 | bf16 | pp512 @ d1024 | 390.77 ± 6.90 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | AOCL | 16 | bf16 | bf16 | tg16 @ d1024 | 7.51 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | AOCL | 16 | bf16 | bf16 | pp512 @ d2048 | 350.18 ± 3.76 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | AOCL | 16 | bf16 | bf16 | tg16 @ d2048 | 7.37 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | AOCL | 16 | bf16 | bf16 | pp512 @ d4096 | 296.25 ± 1.32 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | AOCL | 16 | bf16 | bf16 | tg16 @ d4096 | 7.10 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | AOCL | 16 | bf16 | bf16 | pp512 @ d8192 | 236.90 ± 1.19 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | AOCL | 16 | bf16 | bf16 | tg16 @ d8192 | 6.66 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | AOCL | 16 | bf16 | bf16 | pp512 @ d16384 | 164.02 ± 1.02 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | AOCL | 16 | bf16 | bf16 | tg16 @ d16384 | 5.94 ± 0.00 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | AOCL | 16 | bf16 | bf16 | pp512 @ d32768 | 100.90 ± 0.24 |
| llama 8B BF16 | 14.96 GiB | 8.03 B | AOCL | 16 | bf16 | bf16 | tg16 @ d32768 | 4.86 ± 0.00 |

@am17an
Copy link
Contributor Author

am17an commented Feb 3, 2026

@Djip007 at this point the optimised FA path is only used for f16 or f32 KV. Theoretically it can be extended to native bf16, but I don't have access to hardware for that.

@Djip007
Copy link
Contributor

Djip007 commented Feb 3, 2026

This one is from a 16-core ZEN3, full FP16.

~/LLM$ OMP_NUM_THREADS=16 GOMP_CPU_AFFINITY="0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15" \
time ./build_ref/cpu/bin/llama-bench -ngl 0 -ub 512 \
    -fa "0,1" \
    -r 2 -n 16 -p 512 \
    -d "0,512,1024,2048,4096,8192,16384,32768" \
    -m Meta-Llama-3.1-8B-Instruct/F16.gguf
| model | size | params | backend | threads | fa | test | t/s |
| ----- | ---- | ------ | ------- | ------- | -- | ---- | --- |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 0 | pp512 | 112.23 ± 0.08 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 0 | tg16 | 3.40 ± 0.00 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 0 | pp512 @ d512 | 106.96 ± 0.08 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 0 | tg16 @ d512 | 3.37 ± 0.00 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 0 | pp512 @ d1024 | 101.99 ± 0.18 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 0 | tg16 @ d1024 | 3.35 ± 0.00 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 0 | pp512 @ d2048 | 94.02 ± 0.31 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 0 | tg16 @ d2048 | 3.31 ± 0.00 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 0 | pp512 @ d4096 | 81.65 ± 0.18 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 0 | tg16 @ d4096 | 3.21 ± 0.01 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 0 | pp512 @ d8192 | 64.40 ± 0.13 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 0 | tg16 @ d8192 | 3.04 ± 0.00 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 0 | pp512 @ d16384 | 44.39 ± 0.15 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 0 | tg16 @ d16384 | 2.72 ± 0.03 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 0 | pp512 @ d32768 | 27.45 ± 0.09 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 0 | tg16 @ d32768 | 2.24 ± 0.00 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | pp512 | 112.03 ± 0.11 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | tg16 | 3.42 ± 0.00 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | pp512 @ d512 | 102.40 ± 0.01 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | tg16 @ d512 | 3.38 ± 0.00 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | pp512 @ d1024 | 94.67 ± 0.07 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | tg16 @ d1024 | 3.36 ± 0.00 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | pp512 @ d2048 | 80.21 ± 0.14 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | tg16 @ d2048 | 3.17 ± 0.15 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | pp512 @ d4096 | 65.21 ± 0.00 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | tg16 @ d4096 | 3.21 ± 0.00 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | pp512 @ d8192 | 48.01 ± 0.00 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | tg16 @ d8192 | 2.98 ± 0.01 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | pp512 @ d16384 | 28.64 ± 0.01 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | tg16 @ d16384 | 2.65 ± 0.00 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | pp512 @ d32768 | 15.50 ± 0.00 |
| llama 8B F16 | 14.96 GiB | 8.03 B | CPU | 16 | 1 | tg16 @ d32768 | 2.03 ± 0.00 |
  • with an fp32 model and KV, using the ZenDNN backend:
| model | size | params | backend | threads | type_k | type_v | fa | test | t/s |
| ----- | ---- | ------ | ------- | ------- | ------ | ------ | -- | ---- | --- |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 0 | pp512 | 98.63 ± 0.05 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 0 | tg16 | 1.61 ± 0.00 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 0 | pp512 @ d512 | 93.18 ± 0.04 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 0 | tg16 @ d512 | 1.60 ± 0.00 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 0 | pp512 @ d1024 | 87.97 ± 0.06 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 0 | tg16 @ d1024 | 1.59 ± 0.00 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 0 | pp512 @ d2048 | 79.93 ± 0.04 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 0 | tg16 @ d2048 | 1.57 ± 0.00 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 0 | pp512 @ d4096 | 67.59 ± 0.06 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 0 | tg16 @ d4096 | 1.54 ± 0.00 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 0 | pp512 @ d8192 | 51.85 ± 0.01 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 0 | tg16 @ d8192 | 1.46 ± 0.00 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 0 | pp512 @ d16384 | 34.69 ± 0.01 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 0 | tg16 @ d16384 | 1.33 ± 0.00 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 0 | pp512 @ d32768 | 19.74 ± 0.06 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 0 | tg16 @ d32768 | 1.11 ± 0.00 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 1 | pp512 | 99.05 ± 0.03 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 1 | tg16 | 1.61 ± 0.00 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 1 | pp512 @ d512 | 91.58 ± 0.02 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 1 | tg16 @ d512 | 1.58 ± 0.00 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 1 | pp512 @ d1024 | 85.32 ± 0.16 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 1 | tg16 @ d1024 | 1.56 ± 0.00 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 1 | pp512 @ d2048 | 75.05 ± 0.01 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 1 | tg16 @ d2048 | 1.52 ± 0.00 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 1 | pp512 @ d4096 | 62.45 ± 0.01 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 1 | tg16 @ d4096 | 1.44 ± 0.00 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 1 | pp512 @ d8192 | 45.68 ± 0.02 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 1 | tg16 @ d8192 | 1.30 ± 0.00 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 1 | pp512 @ d16384 | 28.04 ± 0.44 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 1 | tg16 @ d16384 | 1.07 ± 0.00 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 1 | pp512 @ d32768 | 16.76 ± 0.01 |
| llama 8B all F32 | 29.92 GiB | 8.03 B | ZenDNN | 16 | f32 | f32 | 1 | tg16 @ d32768 | 0.77 ± 0.00 |

@am17an
Copy link
Contributor Author

am17an commented Feb 4, 2026

The FA path is still memory bound from what you posted, so maybe 32 or 64 threads (or faster memory bandwidth) would show a benefit over the normal path, apart from the reduced memory requirements.

The main problem is that the inner GEMM that calculates the tiles is not optimised (yet). I'm not able to find something I can use off the shelf for a single-threaded tiny GEMM; tinyBLAS seems to work in threadgroups. Maybe something exists already but I'm not sure (cc @ggerganov). Something like Q_TILE x head_size x KV_TILE. I plan to add this in the future.

shaofeiqi pushed a commit to qualcomm/llama.cpp that referenced this pull request Feb 6, 2026
* ggml-cpu: split across kv for faster TG

* simplify sinks application

* add ref impl
@am17an
Copy link
Contributor Author

am17an commented Feb 15, 2026

@Djip007 I just merged #19422 which should speed up PP for FA

@Djip007
Copy link
Contributor

Djip007 commented Feb 15, 2026

Nice. I added some comments on the PR.

The best part I see is that with this GEMM implementation we can use bf16_dot2! Good job. 🤞
(I did not test it yet, I need more time ;) )

[edit]: for bf16_dot2 we may need some "repacking"... so it may be more complicated; this is not the tinyBLAS dot strategy...

liparetejas pushed a commit to liparetejas/llama.cpp that referenced this pull request Feb 23, 2026
* ggml-cpu: split across kv for faster TG

* simplify sinks application

* add ref impl