Skip to content

cksac/turboquant-model

Repository files navigation

TurboQuant Model

Near-optimal weight quantization with on-the-fly dequantization for LLM inference.

Based on: TurboQuant: Online Vector Quantization with Near-optimal Distortion Rate (Zandieh et al., 2025)

Site: https://cksac.github.io/turboquant-model/

Features

  • 4-bit weight quantization with near-optimal MSE distortion (within 2.7x of information-theoretic lower bound)
  • Residual quantization for fine-grained bit allocation (e.g., 4+4=8 bits, 3+2=5 bits)
  • On-the-fly dequantization — weights stay packed as 4-bit indices, dequantized during matmul
  • 3.2x GPU memory savings vs bf16 with only 27% latency overhead
  • Drop-in replacement for nn.Linear — no model architecture changes needed
  • Save/load quantized models to disk

Installation

uv pip install -e ".[transformers]"

Quick Start

Python API

from transformers import AutoModelForCausalLM, AutoTokenizer
from turboquant_model import TurboQuantConfig, quantize_model

# Load model
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-0.8B-Base", dtype=torch.bfloat16).cuda()
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-0.8B-Base")

# Quantize (single-pass 4-bit)
config = TurboQuantConfig(bit_width=4, seed=42)
model = quantize_model(model, config)
# → 187 layers quantized, 1434 MB → 361 MB (4.0x compression)

# Residual quantization (4+4 = 8 total bits, near-lossless)
config = TurboQuantConfig(bit_width=4, residual_bit_width=4, seed=42)
model = quantize_model(model, config)

# Save / Load
from turboquant_model import save_quantized, load_quantized
save_quantized(model, config, "./quantized-model")
model = load_quantized("Qwen/Qwen3.5-0.8B-Base", "./quantized-model")

CLI

# Quantize and save
turboquant quantize --model Qwen/Qwen3.5-0.8B-Base --output ./quantized --bit-width 4

# Residual quantization
turboquant quantize --model Qwen/Qwen3.5-0.8B-Base --output ./quantized \
    --bit-width 4 --residual-bit-width 4

# Evaluate PPL
turboquant eval --model Qwen/Qwen3.5-0.8B-Base --quantized ./quantized

# Generate text
turboquant generate --model Qwen/Qwen3.5-0.8B-Base --quantized ./quantized \
    --prompt "The capital of France is"

# Benchmark
turboquant benchmark --model Qwen/Qwen3.5-0.8B-Base --bit-width 4

How It Works

TurboQuant Algorithm

  1. Row-normalize each weight row to unit norm, store the norm separately
  2. Random rotation (QR decomposition or fast Walsh–Hadamard + random signs) — maps coordinates to near-independent N(0, 1/d)
  3. Lloyd-Max scalar quantization per coordinate — optimal for Gaussian distribution
  4. Pack indices into 4-bit format (2 per byte)

On-the-fly Dequantization

Instead of inverse-rotating the weight (expensive: N×K matrix), we pre-rotate the input (cheap: B×K vector):

x_rot = x @ Pi.T                              # rotate input (once per layer)
output = x_rot @ codebook[indices].T           # fused lookup + matmul  
output = output * (norms / sqrt(group_size))   # rescale

Residual Quantization

Apply TurboQuant twice with different rotation seeds:

Pass 1: W_hat = TQ(W, b1 bits, seed1)
Pass 2: R_hat = TQ(W - W_hat, b2 bits, seed2)  
Final:  W_approx = W_hat + R_hat

Total bits = b1 + b2, but quality is much better than single-pass at same total bits.

Why Not QJL?

The original TurboQuant paper defines TurboQuant_prod — a variant that applies QJL (Quantized Johnson-Lindenstrauss) as a 1-bit correction on the residual to produce an unbiased inner product estimator. We do not use QJL. Here's why:

  1. QJL solves a different problem. It's designed for online inner product estimation (e.g., KV cache attention: quantize keys once, query with many different vectors later). Weight quantization is offline — we compress $W$ once and compute $y = xW^T$ repeatedly. We want minimum reconstruction error $|W - \tilde{W}|$, not an unbiased dot-product estimator.

  2. Unbiasedness is unnecessary for weights. A small deterministic bias from MSE-optimal quantization is absorbed by layer norms, residual connections, and softmax. An unbiased but high-variance estimator (QJL at 1 bit) introduces noise that changes every forward pass — worse for stable inference.

  3. Residual quantization strictly dominates. QJL uses 1 bit (random sign projection) for the residual. Our residual TQ uses b₂ bits with full Lloyd-Max codebook + independent rotation — capturing far more residual information. At 4+4 total bits, residual TQ achieves KL divergence of only 0.002 nats (practically lossless). QJL's 1-bit correction cannot compete.

  4. QJL requires the query at runtime. The QJL correction term depends on the input activation $x$, making it incompatible with offline weight compression. You'd need to recompute corrections per forward pass — defeating the purpose.

Summary: QJL is elegant for streaming/KV-cache inner product preservation. For weight compression, multi-pass residual quantization with optimal scalar codebooks is the natural and superior choice.

Benchmark Results (Qwen3.5-0.8B-Base, WikiText-103 val, 50 chunks)

Config Total Bits Codebook PPL Δ PPL KLD Compressed Size Peak GPU
Baseline bf16 16 14.29 1,504 MB
4+4 residual g=128 8 16+16 (128 B) 14.28 −0.01 0.0020 762 MB 9.7 GB
4+2 residual g=128 6 16+4 (80 B) 14.46 +0.17 0.0159 762 MB 9.7 GB
3+2 residual g=128 5 8+4 (48 B) 15.15 +0.86 0.0545 762 MB 9.7 GB
4-bit g=full 4 16 (64 B) 16.22 +1.93 0.1363 361 MB 9.6 GB
4-bit g=128 4 16 (64 B) 16.58 +2.29 0.1403 381 MB 5.8 GB

Codebook = Lloyd-Max optimal centroids for N(0,1). Size = 2^b × 4 bytes (float32). Shared globally across all layers — negligible overhead.

Key findings:

  • 4+4 residual is near-lossless — PPL 14.28 ≈ baseline 14.29, KLD only 0.002 nats
  • 4-bit g=128 fits on 8 GB GPUs (5.8 GB peak) with only 2.3 PPL degradation (KLD 0.14)
  • Smaller group sizes (g=128) use much less GPU memory due to smaller rotation matrices

Qwen3.5-4B

Config Total Bits PPL Δ PPL KLD
Baseline bf16 16 10.67
4+4 residual g=128 8 10.70 +0.03 0.0028
4+2 residual g=128 6 10.65 −0.02 0.0133
4-bit g=128 4 11.28 +0.61 0.0852

Fused Kernel Benchmarks

Fused kernels (CuTile, Triton) combine 4-bit unpack + codebook lookup + matmul + norm rescale in a single kernel launch, avoiding intermediate tensor materialization. Auto-enabled when available (priority: CuTile > Triton > PyTorch fallback).

Qwen3.5-0.8B-Base (4-bit g=128)

Path Latency (ms/fwd) Peak GPU (MB) Speedup Memory Reduction
CuTile (fused) 340 1,086 1.10x 4.5x
Triton (fused) 386 1,334 0.97x 3.7x
PyTorch (fallback) 373 4,883 1.0x

Qwen3.5-4B (4-bit g=128)

Path Latency (ms/fwd) Peak GPU (MB) Speedup Memory Reduction
CuTile (fused) 968 3,954 3.98x 5.7x
Triton (fused) 1,098 4,119 3.51x 5.4x
PyTorch (fallback) 3,855 22,377 1.0x

Both fused kernels provide massive memory savings by never materializing the (N, K) float32 codebook[indices] tensor. CuTile edges out Triton in both latency and memory. Speedup scales dramatically with model size (1.1x → 4.0x). Disable per-module with m.use_cutile = False or m.use_triton = False.

Rotation Method Comparison (Qwen3.5-0.8B-Base, g=128)

Two rotation methods: QR (Haar-distributed random orthogonal, O(d²) storage/compute) and Hadamard (fast Walsh–Hadamard + random signs, O(d) storage, O(d log d) compute).

Config QR PPL QR KLD Hadamard PPL Hadamard KLD
4+4 residual 14.28 0.0020 14.30 0.0020
4+2 residual 14.46 0.0159 14.49 0.0148
4-bit 16.58 0.1403 16.35 0.1394

Hadamard matches QR quality across all configs while using O(d) vs O(d²) storage. Use --rotation hadamard to enable.

Architecture

turboquant_model/
├── codebook.py          # Lloyd-Max optimal codebook (precomputed)
├── rotation.py          # Random rotation: QR (Haar) or fast Walsh-Hadamard + signs
├── quantize.py          # Single-pass quantize + pack/unpack
├── residual.py          # Residual (two-pass) quantization
├── cutile_kernels.py    # Fused CuTile matmul kernel (optional)
├── triton_kernels.py    # Fused Triton matmul kernel (optional)
├── module.py            # TurboQuantLinear (nn.Module)
├── model.py             # quantize_model, save/load, config
└── cli.py               # Command-line interface

License

MIT

Releases

No releases published

Packages

 
 
 

Contributors