perf : study batched decoding bottleneck #3726
Conversation
Check the GPU utilization, especially over time with Nsight Systems. Multiple CUDA streams can help improve utilization, but if utilization is not the problem, they won't do much.
For FP16 KV cache I don't think this will help. Writing efficient GEMM kernels is very hard and I very much do not expect that custom FP16 GEMM kernels will be able to outperform cuBLAS. I think cuBLAS has functionality for batched GEMM, so maybe using that would make more sense?
I've implemented a batched cuBLAS GEMM version and I observe significant improvements across the board. Will try to post updated numbers for the tests above a bit later. Edit: see #3749 for A100 numbers and proposed PR
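For context, the batched route that cuBLAS offers here is `cublasGemmStridedBatchedEx`, which computes one GEMM per attention head in a single call instead of looping over `n_head` separate launches. Below is a minimal sketch of the idea for the `KQ` multiplication, assuming FP16 tensors laid out head-after-head in memory; the function name and shapes are illustrative, not the code from #3749:

```cpp
// Illustrative sketch (not the #3749 implementation): compute
// KQ = K^T * Q for all attention heads with one strided-batched call.
// Assumed per-head shapes (column-major): K is head_dim x n_kv,
// Q is head_dim x n_batch, KQ is n_kv x n_batch.
#include <cublas_v2.h>
#include <cuda_fp16.h>

void kq_batched(cublasHandle_t handle,
                const half * K, const half * Q, half * KQ,
                int head_dim, int n_kv, int n_batch, int n_head) {
    const half alpha = __float2half(1.0f);
    const half beta  = __float2half(0.0f);

    cublasGemmStridedBatchedEx(handle,
        CUBLAS_OP_T, CUBLAS_OP_N,            // KQ = K^T * Q
        n_kv, n_batch, head_dim,             // m, n, k
        &alpha,
        K,  CUDA_R_16F, head_dim, (long long) head_dim * n_kv,    // A, stride to next head's K
        Q,  CUDA_R_16F, head_dim, (long long) head_dim * n_batch, // B, stride to next head's Q
        &beta,
        KQ, CUDA_R_16F, n_kv,     (long long) n_kv * n_batch,     // C, stride to next head's KQ
        n_head,                              // batch count = number of heads
        CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT);
}
```

A single batched call lets cuBLAS schedule the per-head GEMMs internally, which generally utilizes the GPU much better than many small sequential launches.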
ref #3479
Description
I wanted to find out what is currently llama.cpp's main limitation when running batched decoding, so I ran a few tests on different hardware to profile mainly the self-attention overhead when using the existing unified KV cache implementation (#3228). Below are the results on 3 different hardware configurations.
I'm using the batched-bench tool to run PP + TG for different numbers of batches:
This PR adds a hack to allow for conveniently turning on and off some of the attention ops via environment variables:
- `SKIP_KQ_KQV=1` to skip the 2 matrix multiplications `KQ` and `KQV`
- `SKIP_KQ_ALL=1` to skip all attention ops (`KQ`, `KQ_scaled`, `KQ_masked`, `KQ_soft_max`, `KQV`)

A sketch of how this gating works is shown after the diff links below.

I've also performed 2 custom diffs for Metal and CUDA to run the full computation but force only 1 KV head to be computed during matrix multiplication:
CUDA diff to force 1 KV head
Metal diff to force 1 KV head
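To make the environment-variable hack concrete, here is a minimal sketch of how such gating can be wired into graph construction. `env_on`, `build_attn` and the tensor parameters are hypothetical stand-ins, not the actual patch:

```cpp
// Hypothetical sketch of the env-var gating idea; build_attn and the
// tensor types stand in for the real graph-building code in llama.cpp.
#include <cstdlib>

struct ggml_context;
struct ggml_tensor;

// true when the given environment variable is set to "1"
static bool env_on(const char * name) {
    const char * v = std::getenv(name);
    return v != nullptr && v[0] == '1';
}

struct ggml_tensor * build_attn(struct ggml_context * ctx,
                                struct ggml_tensor * Q,
                                struct ggml_tensor * K,
                                struct ggml_tensor * V,
                                struct ggml_tensor * inp) {
    (void) ctx; (void) Q; (void) K; (void) V;

    if (env_on("SKIP_KQ_ALL")) {
        // Do not add KQ, KQ_scaled, KQ_masked, KQ_soft_max or KQV to the
        // graph; return the input so the rest of the graph still runs.
        return inp;
    }

    if (env_on("SKIP_KQ_KQV")) {
        // Skip only the two big matrix multiplications; the remaining ops
        // (scale, mask, soft_max) stay in place so their cost remains
        // measurable on its own.
        return inp;
    }

    // ... normal path: KQ = mul_mat(K, Q); scale; mask; soft_max;
    //     KQV = mul_mat(V, KQ) ...
    return inp;
}
```

Returning the input unchanged keeps the graph valid, so the timing difference between runs isolates the cost of the skipped ops.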
All these options allow us to measure the overhead from the following computations individually:
- `KQ` and `KQV` matrix multiplications for all heads
- `KQ` and `KQV` matrix multiplications per attention head
- `KQ_scale` + `KQ_masked` + `KQ_soft_max`

Results
These are the raw numbers that I measured. In each file, first are the 7B runs, followed by the 1B runs:
Here I'll inline part of the A100 results for convenience. For the rest of the results, check out the text files above:
normal

```
LLAMA_CUBLAS=1 make -j && ./batched-bench /workspace/openllama-7b/ggml-model-f16.gguf 4608 1 99 0 512 128 1,2,3,4,5,6,7,8,16,32
```

SKIP_KQ_ALL=1

```
LLAMA_CUBLAS=1 make -j && SKIP_KQ_ALL=1 ./batched-bench /workspace/openllama-7b/ggml-model-f16.gguf 4608 1 99 0 512 128 1,2,3,4,5,6,7,8,16,32
```

normal + force 1 KV head

```
LLAMA_CUBLAS=1 make -j && ./batched-bench /workspace/openllama-7b/ggml-model-f16.gguf 4608 1 99 0 512 128 1,2,3,4,5,6,7,8,16,32
```

Observations
- On Metal, the overhead from the attention ops is relatively small: 1.13x for 1B and 1.09x for 7B
- On CUDA, the overhead is significantly larger: (2.6x 1B, 1.8x 7B) for RTX 4080 and (3.7x 1B, 4.3x 7B) for A100
- The `KQ` and `KQV` matrix multiplications on CUDA take a much larger toll compared to Metal, both for prompt processing and for more than 1 batch
- On CUDA, the `KQ` and `KQV` processing time scales linearly with the number of KV heads for more than 1 batch, while on Metal, where we have a custom matrix-matrix multiplication kernel, the computation scales much better

If my analysis above is correct, there is a significant speedup to be gained for CUDA - both for batched decoding and for prompt processing. I'm not familiar with the best practices for CUDA, but I think we should either:

- utilize multiple CUDA streams (currently we run the `n_head` cuBLAS GEMMs in a single CUDA stream; see the sketch below). If I remember correctly, we have tried utilizing CUDA streams, but only for single-batch decoding. Probably we have to revisit?
- implement a custom kernel for the `KQ` and `KQV` ops where `ne02 > 1` and `ne12 > 1`

These observations could also explain the poor performance observed for speculative decoding on A100 reported here: #3649
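Regarding the first option, here is a hedged sketch of what round-robining the per-head GEMMs over a small pool of CUDA streams could look like. The pool size, names and shapes are illustrative assumptions, not the actual ggml-cuda code:

```cpp
// Sketch: distribute the n_head GEMMs over several CUDA streams instead
// of queueing them all on one stream. Names and pool size are illustrative.
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#define N_STREAMS 8

void kq_multi_stream(cublasHandle_t handle,
                     const half * K, const half * Q, half * KQ,
                     int head_dim, int n_kv, int n_batch, int n_head) {
    static cudaStream_t streams[N_STREAMS];
    static bool init = false;
    if (!init) {
        for (int i = 0; i < N_STREAMS; ++i) {
            cudaStreamCreateWithFlags(&streams[i], cudaStreamNonBlocking);
        }
        init = true;
    }

    const half alpha = __float2half(1.0f);
    const half beta  = __float2half(0.0f);

    for (int h = 0; h < n_head; ++h) {
        // Subsequent cuBLAS calls are issued on the selected stream.
        cublasSetStream(handle, streams[h % N_STREAMS]);
        cublasGemmEx(handle,
            CUBLAS_OP_T, CUBLAS_OP_N,
            n_kv, n_batch, head_dim,
            &alpha,
            K  + (size_t) h * head_dim * n_kv,    CUDA_R_16F, head_dim,
            Q  + (size_t) h * head_dim * n_batch, CUDA_R_16F, head_dim,
            &beta,
            KQ + (size_t) h * n_kv * n_batch,     CUDA_R_16F, n_kv,
            CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT);
    }

    // Re-join: downstream ops must not consume KQ before all heads finish.
    for (int i = 0; i < N_STREAMS; ++i) {
        cudaStreamSynchronize(streams[i]);
    }
}
```

As the conversation above notes, a single batched cuBLAS GEMM (#3749) achieves a similar effect without managing streams by hand, which is likely the simpler route.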
Reproducing
If anyone is interested in re-running the CUDA tests above, I used the following script:
Bash script for getting data and running llama.cpp
On RunPod, the above RTX 4080 and A100 tests cost me a total of ~$1.11 to perform. You would need ~40 GB of storage.
Alternatively, you can run them locally - the tests require 16 GB of VRAM.
cc @slaren @JohannesGaessler for any comments and insights