ggml-cpu: FA split across kv for faster TG #19209
Conversation
|
I am a bit worried that the reference FA implementation will get increasingly complex, and it's important to keep it simple since this is one of the most actively evolving ops. If we make an error during these optimizations, we currently don't have a way to catch it. I think we should first implement a way to validate that the CPU computation is correct - for example, running it against the basic vec-kernel implementation. I have a few different ideas, but am still not sure which one is best. |
|
I ran this through the CUDA tests to catch bugs, since the CUDA implementation can be considered correct in this case. |
|
@ggerganov if the CI passes, does that not mean the reference can be updated to the new implementation? Basically we are saying that all backends agree on this new reference (which matches the old reference). Or do you want to test things that are not covered in the CI? |
|
Long-term I think it would be useful to add something like |
|
I can add that in this PR if it helps. What I think would be most useful is comparing the non-flash-attn output to the flash-attn output on CPU. That way we are free to add optimizations to the flash attention path as long as it matches the non-flash-attn path. |
|
There are some features though, like KV cache quantization, that as of right now strictly require FA - but if that is the only thing we're missing I think it would be fine. Can you check whether there is anything else that we could not cover that way? |
I think that implementing the non-FA path will become very complicated. The FA reference should be the FA vec implementation, as it is very simple, and all changes to the FA operator in the future should start with modifying that implementation. We just have to figure out some refactoring that would allow executing this reference code in the test program - I am not sure what the best way to do that is.
|
What about an env variable that forces the vec kernel, and then a logprobs test between that and the normal path? EDIT: I think a logprobs test is not comprehensive enough, ideally |
|
With test-backend-ops -o FLASH_ATTN_EXT -b CPU it runs two CPU backends against each other. We just have to figure out a way to make one of the runs use the vec kernel. For example, we can:
Either of those would work, though neither seems perfect. |
|
Using op_params I think is non-ideal, but it works. Another idea I have is adding a ... EDIT: I don't think that's possible at the moment since we only get |
|
Yes, a flag to the cpu backend seems good. Though it would require a bit of plumbing. Roughly, something like this:
In llama.cpp/tests/test-backend-ops.cpp, lines 8587 to 8593 at commit 1239267 |
|
I added something like this in 7adc22d, let me know if it looks ok and I will push it to this branch |
|
Looks good - added a few comments to the commit |
Force-pushed efe83e1 to 88c5fa6
|
We just need to update the |
|
What is the perf of this op vs no FA with bf16 KV on Zen4/Zen5? For a long time now, with the BF16/llamafile matmul kernels, FA has been a lot slower than no FA on an AI MAX+ 395 (i.e. 16-core Zen5). I get:
FA is slower than BF16 KV with no FA, and it can be better (possibly with ZenDNN, but I did not find time to test with it). This is what I get with an AOCL backend I created for a POC:
|
|
@Djip007 at this point the optimised FA path is only used for f16 or f32 KV. Theoretically it can be extended to native bf16, but I don't have access to hardware to test that. |
|
This one is from a 16-core ZEN3, full FP16.
~/LLM$ OMP_NUM_THREADS=16 GOMP_CPU_AFFINITY="0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15" \
time ./build_ref/cpu/bin/llama-bench -ngl 0 -ub 512 \
-fa "0,1" \
-r 2 -n 16 -p 512 \
-d "0,512,1024,2048,4096,8192,16384,32768" \
-m Meta-Llama-3.1-8B-Instruct/F16.gguf
|
The FA path is still memory bound from what you posted, so maybe 32 or 64 threads (or faster memory bandwidth) will show a benefit over the normal path - apart from the reduced memory requirements. The main problem is that the inner GEMM used to calculate the tiles is not optimised (yet). I'm not able to find something off the shelf for a single-threaded tiny GEMM - tinyBLAS seems to work in threadgroups. Maybe something exists already but I'm not sure (cc @ggerganov). Something like Q_TILE x head_size x KV_TILE. I plan to add this in the future. |
* ggml-cpu: split across kv for faster TG
* simplify sinks application
* add ref impl
|
Nice. I added some comments on the PR. The best part I see is that with this GEMM implementation we can use bf16_dot2! Good job. 🤞 [edit]: for bf16_dot2 we may need some "repacking"... so it may be more complicated; this is not the tinyBLAS dot strategy... |
Continuing from #19012: in the FA CPU implementation we don't parallelize across the context size, so each thread reads the entire sequence to calculate its values. This PR introduces chunking across the context size, where each thread maintains partial accumulators for the sum, maximum, and running VKQ sum used in the soft-max calculation, and computes these for all query heads. There is an extra reduction step at the end to reduce the partials. The original idea is from https://pytorch.org/blog/flash-decoding/, but then @JohannesGaessler pointed out that the CUDA FA kernel already does this.
Tested on three models and the results are good, with larger (>2x) speed-ups as context size grows. Note that at lower contexts the results are a bit noisy. Also note that going from 32 to 64 cores on master doesn't make a difference in some cases, because we only parallelize across the query heads, which can be fewer than n_threads.