You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
docs: clarify TurboQuant hybrid architecture in README
- Updates the TurboQuantization section in README to explain the fusion of V2 speed and V3 quality algorithms
- Adds 'docs/turboquant_hybrid_architecture.md' with deep-dive technical analysis of the Lloyd-Max + QJL Metal integration
Copy file name to clipboardExpand all lines: README.md
+17-17Lines changed: 17 additions & 17 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -17,29 +17,29 @@ No Python runtime, no Global Interpreter Lock (GIL), no unnecessary memory copie
17
17
18
18
## ⚡️ TurboQuantization: KV Cache Compression
19
19
20
-
`SwiftLM` implements **TurboQuant**(AISTATS/ICLR 2026) for on-the-fly KV cache compression, enabling long-context inference with drastically reduced memory. At 3 bits/coordinate, the KV cache is compressed ~5.8× vs FP16 with near-zero accuracy loss.
20
+
`SwiftLM` implements a **hybrid V2+V3 TurboQuant architecture** for on-the-fly KV cache compression. At roughly ~3.6 bits per coordinate overall, the KV cache is compressed ~3.5× vs FP16 with near-zero accuracy loss.
21
21
22
-
The algorithm runs in two stages per KV vector:
22
+
### By combining V2 Speed with V3 Quality:
23
+
Recent reproductions of the TurboQuant algorithm (e.g., `turboquant-mlx`) revealed two distinct paths:
24
+
1.**V2 (Hardware-Accelerated)**: Fast, but uses linear affine quantization which degrades quality at 3-bit.
25
+
2.**V3 (Paper-Correct)**: Excellent quality using non-linear Lloyd-Max codebooks, but painfully slow due to software dequantization.
23
26
24
-
**Stage 1 — PolarQuant (2 bits):**
25
-
1. Extract L2 norm: `‖x‖`
26
-
2. Normalize: `x̂ = x / ‖x‖`
27
-
3. Rotate: `y = R @ x̂` (random orthogonal R via Fast Walsh-Hadamard Transform — O(d log d))
28
-
4. Quantize each coordinate to nearest Lloyd-Max centroid (optimal for post-rotation Gaussian distribution)
29
-
- → Store: `(2-bit indices[d], float16 norm)`
27
+
**We built the "Holy Grail" hybrid:** We ported the V3 non-linear Lloyd-Max codebooks directly into the native C++ encoding path, and process the dequantization natively in fused Metal (`bggml-metal`) shaders. This achieves **V3 quality at V2 speeds**, completely detached from Python overhead.
30
28
31
-
**Stage 2 — QJL residual (1 bit):**
32
-
1. Dequantize Stage 1 → `x̂_mse`
33
-
2. Compute residual: `r = x - x̂_mse`
34
-
3. Project: `z = S @ r` (S ~ N(0,1) random matrix)
2. Apply Fast Walsh-Hadamard Transform (WHT) rotation to distribute outliers evenly.
34
+
3. Quantize each coordinate using **3-bit non-linear Lloyd-Max centroids**.
35
+
4. Compute the residual error between the original vector and the quantized approximation.
36
+
5. Project the residual via a random Johnson-Lindenstrauss (QJL) matrix and store the 1-bit signs.
37
+
*(Why QJL? QJL acts as an additional regularizer that prevents centroid resolution loss from degrading the attention dot-product.)*
39
38
40
-
> *K cache uses full TurboQuant (Stage 1 + Stage 2) to preserve attention dot-product accuracy. V cache uses Stage 1 only (PolarQuant MSE) since MSE-optimal reconstruction doesn't need the QJL residual stage.*
39
+
**V-Cache (3-bit PolarQuant) = 3.125 bits/dim**
40
+
Because the V-cache matrix is not used for inner-product attention scoring, the QJL error correction provides no benefit. We cleanly disable QJL for the V-cache, extracting an additional 25% memory savings without sacrificing quality.
# TurboQuant Hybrid: Achieving V3 Quality at V2 Speeds in Apple Metal
2
+
> *An architectural analysis for SwiftLM's KV Cache pipeline*
3
+
4
+
KV Cache quantization is fundamentally constrained by a tradeoff between **per-bit representation quality** and **hardware execution speed**. Following the publication of *TurboQuant (Google, 2025)*, reference implementations across the MLX community generally diverged into two disparate paths: **V2 (speed-oriented)** and **V3 (quality-oriented)**.
5
+
6
+
In `SwiftLM`, we discard this dichotomy by fusing the mathematical precision of V3 directly into the hardware-accelerated pathways of V2 natively in C++ and Metal.
7
+
8
+
## The Problem: The V2 / V3 Divergence
9
+
10
+
Recent implementations (such as `turboquant-mlx`) categorized their quantization strategies into two tiers:
11
+
12
+
-**V2 (Affine / Hardware-Accelerated):**
13
+
This approach leverages native `mx.quantize` and `mx.quantized_matmul` ops. It is blazingly fast (~105% of fp16 throughput for simple quantization, ~78% when doing random rotations). However, it relies on linear/affine scaling. Because WHT-rotated vectors naturally form a Gaussian probability distribution `N(0, 1/sqrt(d))`, linear uniform bins are sub-optimal for the long tails of the distribution. At 3-bits or 2-bits, V2 affine scaling aggressively deteriorates perplexity (+9% to +23% PPL).
14
+
-**V3 (Lloyd-Max Codebook / Paper-Correct):**
15
+
This route uses paper-correct non-linear quantization. By using pre-computed Lloyd-Max centroids designed for a Gaussian distribution, the quantization tightly clusters near the dense center and sparsely tracks the tails. This provides near-lossless compression (e.g., +0.3% PPL at 3.5-bit). However, this method requires software dequantization (centroid payload lookups), destroying throughput. On MLX without custom Metal kernels, V3 runs 5-6x slower than V2.
16
+
17
+
## The Solution: A Fused C++/Metal Hybrid Approach
18
+
19
+
Rather than choosing between Python orchestration speed penalties or affine centroid quality loss, `SwiftLM` bypasses the Python boundary entirely. We ported the non-linear Lloyd-Max logic down to the bare metal.
20
+
21
+
### 1. Vector Quantization (C++ Encoding)
22
+
When tokens enter the KV cache during the pre-fill/generation phases, the C++ encoding logic (in `fast_turbo.cpp`) performs the pre-processing natively:
23
+
1.**L2 Normalization**: The vector is scaled to the unit sphere.
24
+
2.**WHT Rotation**: An in-place Fast Walsh-Hadamard Transform `O(d log d)` evenly distributes outlier channels across the dimension array, forcing the payload into an identical Gaussian distribution.
25
+
3.**Lloyd-Max Lookup**: Instead of mathematically calculating linear boundaries, the code uses a binary search across hardcoded probability boundaries (`BOUNDARIES_3BIT`) to assign each item to one of 8 non-linear centroids, packing the result cleanly into `uint8_t` blocks.
26
+
27
+
### 2. Inner-Product Error Correction (QJL)
28
+
The original paper’s "TurboQuant_prod" algorithm attempted to replace 1 bit of MSE payload with 1 bit of Quantized Johnson-Lindenstrauss (QJL) residual estimation. Reference tests overwhelmingly demonstrated that this was a failure on Apple Silicon (softmax exponentially amplified the centroid resolution drop of dropping from 3-bit to 2-bit).
29
+
30
+
Instead, we use QJL strictly as an **additive correction layer**, and **only on the K-Cache**.
31
+
* The **K-Cache** (used for dot-product attention scores) gets 3-bit PolarQuant + 1-bit QJL (`TurboQuantK`). Storage: 4.25 bits/dim.
32
+
* The **V-Cache** (used purely for matrix reconstruction, not attention weighting) is spared the QJL overhead and gets just 3-bit PolarQuant (`TurboQuantV`). Storage: 3.125 bits/dim.
33
+
34
+
### 3. Native Metal Dequantization
35
+
With the heavy lifting done exactly matched to the mathematical shapes of V3, we pass the 16-byte packed structs to the SDPA (Scaled Dot-Product Attention) Metal kernels (`bggml-metal`). The kernel unpacks the 3-bit indices, substitutes them directly from a constant buffer containing `CENTROIDS_3BIT`, and independently executes the 1-bit QJL sign accumulation into the SDPA hot-loop.
36
+
37
+
## Conclusion
38
+
Our hybrid approach guarantees:
39
+
1.**No Python Global Interpreter Lock (GIL) or orchestration overhead**.
40
+
2.**No arbitrary affine quality loss** on Gaussian tails at 3-bit depth.
41
+
3.**Targeted regularization** by isolating QJL to the K-Cache only.
42
+
43
+
The result is a highly efficient unified KV Cache running at an average of **~3.6 bits/dim (~3.5x compression vs fp16)**, recovering the performance characteristics of V2 with the perplexity retention of V3.
0 commit comments