
TurboQuant KV Cache Compression — Working Implementation Ready for Review #1509

@veritatisquaesitoressumus

Description

Summary
Working implementation of TurboQuant (Zandieh et al., "TurboQuant: Online Vector Quantization for Quantized KV Cache in Large Language Models", ICLR 2026) for KV cache compression in ik_llama.cpp.
What it does: Compresses KV cache from FP16 to 3 bits per value with 4.9x compression and near-zero accuracy loss (MSE 0.034, matching the paper within 1%).
Why it matters: On a 3× RTX 3090 setup (72GB VRAM), a 70B model at Q4_K_M uses ~38-48GB for weights, leaving 24-34GB for KV cache. With FP16 KV, that caps context at ~70-109K tokens. With TQ3, the same VRAM holds ~347-536K tokens — full 262K native context windows fit entirely in VRAM with room to spare.
Status: CPU implementation complete, 18/18 tests passing. CUDA kernels written, awaiting GPU validation. 6-phase integration spec complete. No existing implementations in any llama.cpp fork as of this date.

Validated Test Results

[Test 3] Quantize/Dequantize Round-trip MSE
  b=3: Avg MSE = 0.034144  (paper: ~0.034)  ratio = 1.00
  b=4: Avg MSE = 0.009253  (paper: ~0.009)  ratio = 0.99

[Test 5] Norm Preservation
  Original norm: 3.7000  Reconstructed norm: 3.7118

[Test 6] Compression Ratios
  TQ3: 52 bytes per 128-value vector vs 256 bytes FP16 = 4.9x
  TQ4: 68 bytes per 128-value vector vs 256 bytes FP16 = 3.8x

[Test 8] Speed Benchmark (10000 vectors)
  Quantize:   180.0 ms  (55,556 vectors/sec)  [CPU only]
  Dequantize: 160.0 ms  (62,500 vectors/sec)  [CPU only]

Results: 18 passed, 0 failed

Algorithm Overview
TurboQuant quantizes each head-dimension vector (d=128) independently:
Quantize (KV cache write):
1. Store ||x|| as a float32 norm (4 bytes)
2. Normalize: x_unit = x / ||x||
3. Rotate: y = Π · x_unit (Π is a fixed random orthogonal matrix, generated once at cache init with a deterministic seed)
4. For each y[j], find the nearest Lloyd-Max codebook centroid index (3 bits for TQ3, 4 bits for TQ4)
5. Bit-pack the 128 indices into 48 bytes (TQ3) or 64 bytes (TQ4)

TQ3 total: 52 bytes per vector (4 norm + 48 packed indices)
TQ4 total: 68 bytes per vector (4 norm + 64 packed indices)

Dequantize (KV cache read / attention):
1. Unpack the indices
2. Map indices → codebook centroids
3. Rotate back: x_hat = Π^T · y_hat
4. Scale by the stored norm
Why it works:
Random rotation makes all coordinates of unit vectors follow the same Beta distribution (Lemma 1 in paper). The Lloyd-Max codebook is MSE-optimal for that distribution. Since the rotation is orthogonal, MSE is preserved in the original space. The paper proves this achieves near-optimal distortion rate (Theorem 1).

Memory Layout

```c
// TQ3: 52 bytes per 128-value block
typedef struct {
    float norm;                    // 4 bytes: vector norm
    uint8_t packed_indices[48];    // 48 bytes: 128 × 3-bit indices, bit-packed
} block_tq3;                       // Total: 52 bytes (vs 256 bytes FP16)

// TQ4: 68 bytes per 128-value block
typedef struct {
    float norm;                    // 4 bytes: vector norm
    uint8_t packed_indices[64];    // 64 bytes: 128 × 4-bit indices, bit-packed
} block_tq4;                       // Total: 68 bytes (vs 256 bytes FP16)
```

Block size is 128 values — maps directly to attention head dimension (n_embd_head_k = 128 for most 70B models including Llama 3.x and Qwen2.5).
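As a sanity check on this layout, the struct sizes can be asserted at compile time and the 3-bit packing verified with a round trip. The packing order shown (little-endian within bytes) is an illustrative assumption, not necessarily the bit order the gist uses.

```c
#include <stdint.h>
#include <string.h>

typedef struct {
    float   norm;                  // 4 bytes
    uint8_t packed_indices[48];    // 128 x 3-bit indices
} block_tq3;

typedef struct {
    float   norm;                  // 4 bytes
    uint8_t packed_indices[64];    // 128 x 4-bit indices
} block_tq4;

// No padding: a 4-byte float followed by byte arrays whose lengths keep
// the total size a multiple of 4.
_Static_assert(sizeof(block_tq3) == 52, "block_tq3 must be 52 bytes");
_Static_assert(sizeof(block_tq4) == 68, "block_tq4 must be 68 bytes");

// Illustrative 3-bit packing: index i occupies absolute bits [3i, 3i+3),
// little-endian within each byte.
static void tq3_pack(const uint8_t idx[128], uint8_t out[48]) {
    memset(out, 0, 48);
    for (int i = 0; i < 128; i++) {
        int bit = 3 * i;
        out[bit >> 3] |= (uint8_t)(idx[i] << (bit & 7));
        if ((bit & 7) > 5)  // value straddles a byte boundary
            out[(bit >> 3) + 1] |= (uint8_t)(idx[i] >> (8 - (bit & 7)));
    }
}

static void tq3_unpack(const uint8_t in[48], uint8_t idx[128]) {
    for (int i = 0; i < 128; i++) {
        int bit = 3 * i;
        uint16_t w = in[bit >> 3];
        if ((bit >> 3) + 1 < 48)
            w |= (uint16_t)in[(bit >> 3) + 1] << 8;
        idx[i] = (w >> (bit & 7)) & 0x7;
    }
}
```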

Projected Impact
70B model (Q4_K_M, ~38GB weights) on 72GB VRAM (34GB free for KV):

| KV Type | Bytes/token | Max context in 34GB |
|---------|-------------|---------------------|
| FP16    | 327,680     | ~109K tokens        |
| Q8_0    | 163,840     | ~218K tokens        |
| TQ3     | 66,560      | ~536K tokens        |
| TQ4     | 87,040      | ~410K tokens        |

72B VL model (Q4_K_M, ~48GB weights + 1.35GB mmproj) on 72GB VRAM (~22GB free for KV):

| KV Type | Bytes/token | Max context in 22GB |
|---------|-------------|---------------------|
| FP16    | 327,680     | ~70K tokens         |
| Q8_0    | 163,840     | ~141K tokens        |
| TQ3     | 66,560      | ~347K tokens        |
| TQ4     | 87,040      | ~265K tokens        |
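The bytes/token figures follow directly from the per-block sizes. A quick check, assuming a Llama-3-70B-class shape (80 layers, 8 KV heads, head_dim 128; the shape numbers are my assumption, not from the issue):

```c
// Per-token KV bytes for a given model shape and per-128-value-block cost.
// FP16 is 256 bytes per 128-value block; TQ3 is 52; TQ4 is 68.
static long kv_bytes_per_token(int n_layers, int n_kv_heads, int head_dim,
                               long bytes_per_128_block) {
    long values = 2L * n_layers * n_kv_heads * head_dim;  // K and V
    return values / 128 * bytes_per_128_block;
}
```

With (80, 8, 128) this gives 327,680 for FP16, 66,560 for TQ3, and 87,040 for TQ4, matching the tables above.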

Integration Plan (6 Phases)
Phase 1: New GGML Type Registration
Where: ggml/include/ggml.h
Add GGML_TYPE_TQ3 and GGML_TYPE_TQ4 to the type enum. Register in ggml_type_size, ggml_type_name, ggml_blck_size tables. Block size = 128 values.
Phase 2: KV Cache Type Support
Where: src/llama-kv-cache.cpp
Add TQ3/TQ4 to allowed cache-type-k and cache-type-v values.
Usage: --cache-type-k tq3 --cache-type-v tq3
Phase 3: Quantize on KV Write
Where: KV cache write path in src/llama-kv-cache.cpp
When cache type is TQ3/TQ4, intercept the FP16/FP32 K or V vector before cache write, call tq_quantize(), store the block_tq3 struct in the cache buffer.
Phase 4: Dequantize on KV Read (Flash Attention)
Where: ggml/src/ggml-cuda/fattn-*.cu
Two approaches:
4a. Non-fused (initial, zero-risk): Dequantize TQ3 blocks back to FP16 before flash attention runs. FA kernels unchanged.
4b. Fused (optimization): Compute Q · dequant(K) without materializing the dequantized vector. Uses pre-rotated queries: Q_rot = Q · Π^T, then dots against codebook values directly. Requires modifying FA kernels but avoids the dequant memory allocation.
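The identity behind 4b: dequant(K) = norm · Π^T · c, where c is the vector of centroid values, so Q · dequant(K) = norm · (Q · Π^T) · c. Rotating Q once per query then reduces each score to a dot against codebook lookups. A scalar C sketch (the real kernel is CUDA; the codebook values here are illustrative):

```c
#include <stdint.h>

// Fused attention dot: score = Q . dequant(K) without materializing the
// dequantized K vector.
static float tq3_fused_dot(const float q_rot[128],   // Q . Pi^T, precomputed
                           const uint8_t idx[128],   // unpacked 3-bit indices
                           float norm,               // stored K norm
                           const float codebook[8]) {
    float acc = 0.0f;
    for (int j = 0; j < 128; j++)
        acc += q_rot[j] * codebook[idx[j]];  // centroid lookup, no dequant buffer
    return norm * acc;
}
```

This avoids both the dequantized-K allocation of 4a and the back-rotation, at the cost of one extra 128-dim rotation per query row.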
Phase 5: Rotation Matrix Lifecycle
One tq_context per KV cache (K and V each). Initialized once at model load with deterministic seed. The rotation matrix is 64 KB in GPU global memory — negligible.
Phase 6: CLI Flags
Where: common/arg.cpp
Add tq3 and tq4 to allowed values for --cache-type-k and --cache-type-v.

Files
The implementation consists of:
| File | Purpose | Status |
|------|---------|--------|
| ggml_turboquant.h | Header: types, codebooks, API | Complete |
| ggml_turboquant.c | CPU: quantize, dequant, rotation, bit-pack | Complete, 18/18 tests |
| ggml_turboquant.cu | CUDA: GPU quantize/dequant + fused attention dot | Complete, needs GPU test |
| tq_test.c | Test harness | Complete |
| turboquant_codebooks.json | Pre-computed Lloyd-Max codebooks for d=128 | Complete |
Full source: https://gist.github.com/veritatisquaesitoressumus/6aa5973955007ffd858889c76aa60408

What's NOT included
QJL error correction (Algorithm 2 from paper): Paper shows TQ_mse alone (Algorithm 1) is sufficient for KV cache compression. QJL adds inner-product unbiasedness but costs an extra bit per coordinate. Not needed for the KV cache use case.
Actual ik_llama.cpp source modifications: This is the standalone implementation + integration spec. Intended for review before applying to the codebase.

Build (test harness only)

```sh
gcc -O2 -o tq_test ggml_turboquant.c tq_test.c -lm
./tq_test
# Expected: 18 passed, 0 failed
```

References
Zandieh, Amir, et al. "TurboQuant: Online Vector Quantization for Quantized KV Cache in Large Language Models." ICLR 2026. arXiv:2504.19874
Related mainline discussion: ggml-org/llama.cpp#20969
