NexusQuant

Near-lossless KV cache quantization. Training-free. One line of code.


E8 lattice quantization applied to the KV cache after prefill. No training, no calibration data, no model modifications.

Headline: K3V2 pb=0 (3-bit keys, 2-bit values, protect_boundary=0) adds +0.276% PPL at 6.1x compression on Mistral-7B-v0.1 (wikitext-2, n=161 paired chunks). NIAH recall is preserved to 32K context on A100. Validated on 7 model architectures.

Install

pip install nexusquant-kv
pip install "nexusquant-kv[hf]"  # with HuggingFace transformers

compress_kv_cache (quant-only API) requires transformers >= 4.46 (oldest version tested). nexusquant_evict requires transformers >= 5.0 and torch >= 2.4 (eviction hooks use DynamicLayer/DynamicSlidingWindowLayer, which are 5.0+ only).

Quickstart

Quant-only (near-lossless, recommended)

Prefill all but the last token, compress the cache, then let generate consume the final token. Copy and run it.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
from nexusquant import compress_kv_cache

model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)
model.eval()

text = "The KV cache stores key and value tensors for each attention layer."
inputs = tokenizer(text, return_tensors="pt").to(model.device)
input_ids = inputs["input_ids"]

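cache = DynamicCache()
with torch.no_grad():
    # Prefill every token except the last one.
    out = model(input_ids[:, :-1], use_cache=True, past_key_values=cache)
past_key_values = out.past_key_values

# Pass the model's RoPE settings so keys are de-rotated with the right frequencies.
rope_base = getattr(model.config, "rope_theta", 10000.0)
rope_scaling = getattr(model.config, "rope_scaling", None)  # needed for Llama-3.1 and similar
compressed_kv = compress_kv_cache(
    past_key_values,
    mode="quant_only",  # the default is "simple"; quant_only is the near-lossless path
    rope_base=rope_base,
    rope_scaling=rope_scaling,
)

with torch.no_grad():
    # generate() reuses the cached prefix and runs only the final token through the model.
    gen_out = model.generate(
        input_ids,
        attention_mask=inputs["attention_mask"],
        past_key_values=compressed_kv,
        max_new_tokens=200,
        do_sample=False,
    )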

print(tokenizer.decode(gen_out[0], skip_special_tokens=True))

At 128K context on Mistral-7B (32 layers, 8 KV heads, head_dim=128) this drops the KV cache from ~17.2 GB at FP16 to ~2.8 GB at K3V2 pb=0 (6.1x). The flow above is verified to exit 0 with generated text on CPU using hf-internal-testing/tiny-random-LlamaForCausalLM, nexusquant-kv 0.5.0, torch 2.2.2, transformers 4.46.3, run from /tmp.
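The memory figures above are plain arithmetic: 2 tensors (K and V) × layers × KV heads × head_dim × tokens × bits per element. A minimal sketch that reproduces them; kv_cache_bytes is a hypothetical helper (not part of nexusquant), and the 70B line assumes a Llama-70B-style config (80 layers, 8 KV heads, head_dim=128) that this README does not spell out.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int,
                   seq_len: int, bits_per_element: float) -> float:
    """Bytes for the K and V tensors of every layer, KV head and token."""
    elements = 2 * n_layers * n_kv_heads * head_dim * seq_len  # 2 = keys + values
    return elements * bits_per_element / 8


SEQ = 128 * 1024   # 128K tokens
K3V2_BPE = 2.625   # K3V2 bits per element, including scale overhead

# Mistral-7B: 32 layers, 8 KV heads, head_dim=128 (from the text above)
fp16 = kv_cache_bytes(32, 8, 128, SEQ, 16.0)
k3v2 = kv_cache_bytes(32, 8, 128, SEQ, K3V2_BPE)
print(f"Mistral-7B: {fp16 / 1e9:.1f} GB -> {k3v2 / 1e9:.1f} GB ({fp16 / k3v2:.1f}x)")

# Assumed Llama-70B-style config: 80 layers, 8 KV heads (GQA), head_dim=128
fp16_70b = kv_cache_bytes(80, 8, 128, SEQ, 16.0)
k3v2_70b = kv_cache_bytes(80, 8, 128, SEQ, K3V2_BPE)
print(f"70B:        {fp16_70b / 1e9:.1f} GB -> {k3v2_70b / 1e9:.1f} GB")
```

This prints 17.2 GB -> 2.8 GB (6.1x) for Mistral-7B and 42.9 GB -> 7.0 GB for the assumed 70B config, in line with the ~42 GB and ~7 GB figures in the Why table.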

With eviction (higher compression)

Requires transformers >= 5.0, torch >= 2.4, and a float16 or bfloat16 model (float32 is not supported for eviction; use compress_kv_cache(mode="quant_only") for float32). nexusquant_evict evicts low-attention tokens and E8-quantizes the survivors in-place during prefill. NIAH recall degrades at 35% eviction and above, and the default "balanced" preset evicts 60% (see Limitations).

from nexusquant import nexusquant_evict

with nexusquant_evict(model, quality="balanced") as nq:
    output = model.generate(input_ids, max_new_tokens=512)

After the context exits, nq.last_mask holds the (batch, seq) eviction attention mask; pass it to follow-up generate() calls for multi-turn.

HuggingFace ecosystem

Two on-ramps wire E8 lattice KV quantization (training-free) into the HuggingFace stack.

The standalone cache is a DynamicCache subclass you pass to model.generate. It needs no change to transformers core. from_model reads rope_theta and rope_scaling from the config, so Llama-3.1-family scaling is handled.

from nexusquant.integrations.quantized_cache import E8QuantizedCache

cache = E8QuantizedCache.from_model(model)
output = model.generate(**inputs, past_key_values=cache, max_new_tokens=64)

For the kvpress ecosystem, E8Press is a kvpress.BasePress subclass that runs the E8 quant pass inside kvpress's prefill compress() hook. Install with pip install "nexusquant-kv[kvpress]" (quoted so shells like zsh do not expand the brackets).

from nexusquant.integrations.kvpress import E8Press

press = E8Press(key_bits=3, value_bits=2)
with press(model):
    output = model.generate(**inputs, max_new_tokens=64)

Behavior in both: the prefill KV is E8-quantized (K3V2 default). Recent single-token decode writes are kept as an fp16 residual (KIVI-style), so the long-context prefill carries the recall signal while recent tokens stay full precision. That residual is small at long context but not free; count it in any compression-ratio number.
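To see why the residual belongs in the compression ratio, here is a rough sketch of the average bits per element when a K3V2-quantized prefill (2.625 bpe) is followed by FP16 decode tokens. It assumes every generated token stays in the FP16 residual and ignores any other bookkeeping; effective_bpe is a hypothetical helper, not part of nexusquant.

```python
def effective_bpe(prefill_tokens: int, residual_tokens: int,
                  quant_bpe: float = 2.625, residual_bpe: float = 16.0) -> float:
    """Average bits per KV element over a quantized prefill plus an FP16 residual.

    Every token has the same number of K/V elements, so a token-weighted
    average of the two bit widths is enough.
    """
    total = prefill_tokens + residual_tokens
    return (prefill_tokens * quant_bpe + residual_tokens * residual_bpe) / total


for prefill, decoded in [(32_768, 0), (32_768, 512), (4_096, 512)]:
    bpe = effective_bpe(prefill, decoded)
    print(f"prefill={prefill:>6} decoded={decoded:>4} bpe={bpe:.3f} ratio={16 / bpe:.2f}x")
```

At 32K the ratio slips from 6.10x to about 5.65x after 512 generated tokens; with a 4K prefill it falls to about 3.89x.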

API reference

compress_kv_cache

from nexusquant import compress_kv_cache
compressed_kv = compress_kv_cache(past_key_values, mode="quant_only", **kwargs)

| Parameter | Type | Default | Description |
|---|---|---|---|
| past_key_values | DynamicCache | required | From model(..., use_cache=True).past_key_values |
| mode | str | "simple" | "quant_only" (near-lossless), "simple" (5.3x), "fast" (skips RoPE removal), "max" (PCA, needs calibration) |
| head_dim | int | 128 | Attention head dimension; auto-detected if you use the HF nexusquant context manager |
| rope_base | float | 10000.0 | RoPE theta; pass model.config.rope_theta |
| rope_scaling | dict or None | None | Pass model.config.rope_scaling for Llama-3.1 and similar |
| bits | int | 3 | Quantization bits (2, 3, or 4) |
| merge_pct | float | 0.0 | Token merge fraction (0 = disabled); "simple" mode only |

Returns the same past_key_values object, modified in-place with quantized KV tensors.

nexusquant_evict

from nexusquant import nexusquant_evict

with nexusquant_evict(model, quality="balanced") as nq:
    output = model.generate(input_ids, max_new_tokens=200)

| Parameter | Type | Default | Description |
|---|---|---|---|
| model | HF causal LM | required | Any model using DynamicCache |
| quality | str or None | "balanced" | Preset: "high" (K3V2, 35% evict, ~9x), "balanced" (K2V2, 60% evict, ~17x), "max" (K2V2, 80% evict, ~33x), "asym" (K3V2, 60% evict, ~14x). Set None to use the raw arguments below. |
| eviction_rate | float or "auto" | 0.6 | Fraction of prefix tokens to evict. Ignored when a named quality preset is given. |
| sliding_window | int | 32 | Recent tokens that are always kept, never evicted |
| obs_window | int | 32 | Query positions used to score token importance (key-key scorer) |
| bits | int | 2 | E8 quantization bits for surviving tokens |
| key_bits | int or None | None | Override bits for keys only (e.g. key_bits=3, bits=2 = K3V2) |
| value_bits | int or None | None | Override bits for values only |
| scorer | str | "key-key" | "key-key" (fast, no extra pass) or "real" (softmax weights; needs attn_implementation="eager") |
| input_ids | Tensor or None | None | Required when scorer="real": the (batch, seq) prefix token ids for the importance-scoring forward pass |
| protect_boundary | int or "auto" | 0 | Keep the first and last N layers at FP16. Use 2 for Qwen-family models. |
| protected_layers | set or None | None | Layer indices (0-based) to skip entirely (kept at FP16) |
| min_context_for_compression | int | 0 | Skip compression when the prefill is shorter than this many tokens |
| protected_positions | Tensor or None | None | Token indices that are never evicted (e.g. VLM image spans) |
| truncate | bool | False | Physically remove evicted tokens; saves real GPU memory but requires position_ids in generate() |
| soft_eviction | bool | False | Keep evicted tokens at 1 bit instead of zeroing them |
| adaptive_context | bool | False | Scale the eviction rate down for short prefixes |
| compress_layers | str | "all" | "global_only" skips sliding-window attention layers (Gemma4, Phi) |
| layer_bit_profile | str | "uniform" | "graduated" gives boundary layers more bits |
| verbose | bool | True | Print compression stats on exit |

Yields a compressor object. After the first generate() inside the context, nq.last_mask is a (batch, seq) float attention mask for the evicted cache; pass it as attention_mask to follow-up generate() calls (multi-turn).
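For intuition about what nq.last_mask encodes, here is a single-row sketch of building a keep/evict mask. It is not NexusQuant's implementation: key_key_scores (the last obs_window keys standing in for queries, softmax-weighted) and eviction_mask are hypothetical helpers, and reading eviction_rate as a fraction of the whole prefix is an assumption.

```python
import math
from typing import List, Optional, Sequence, Set


def key_key_scores(keys: Sequence[Sequence[float]], obs_window: int) -> List[float]:
    """Hypothetical key-key scorer: treat the last `obs_window` keys as proxy
    queries and accumulate their softmax attention over every position."""
    d = len(keys[0])
    scores = [0.0] * len(keys)
    for q in keys[-obs_window:]:
        logits = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
        peak = max(logits)  # subtract the max for a numerically stable softmax
        weights = [math.exp(x - peak) for x in logits]
        total = sum(weights)
        for j, w in enumerate(weights):
            scores[j] += w / total
    return scores


def eviction_mask(scores: Sequence[float], eviction_rate: float,
                  sliding_window: int,
                  protected: Optional[Set[int]] = None) -> List[float]:
    """Float keep-mask for one sequence: 1.0 = keep, 0.0 = evicted.

    BOS (index 0), the last `sliding_window` tokens and `protected` indices
    are never evicted; the lowest-scoring remaining tokens are.
    """
    n = len(scores)
    always_keep = {0} | set(range(max(0, n - sliding_window), n)) | (protected or set())
    candidates = [i for i in range(n) if i not in always_keep]
    # Assumption: eviction_rate is a fraction of the whole prefix.
    n_evict = min(int(eviction_rate * n), len(candidates))
    victims = sorted(candidates, key=lambda i: scores[i])[:n_evict]
    mask = [1.0] * n
    for i in victims:
        mask[i] = 0.0
    return mask


scores = [0.9, 0.1, 0.5, 0.05, 0.7, 0.2, 0.3, 0.8]
print(eviction_mask(scores, eviction_rate=0.25, sliding_window=2))
# [1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]
```

Here BOS and the last two tokens are protected, and the two lowest-scoring candidates (indices 3 and 1) are evicted. Passing protected={3} mimics protected_positions.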

Results

PPL (quant-only, no eviction)

Mistral-7B-v0.1, wikitext-2, 1K-token prefix, paired per-chunk delta vs FP16.

| Config | bpe (incl. scale overhead) | PPL delta vs FP16 |
|---|---|---|
| FP16 | 16.000 | 0.000% |
| K4V2 pb=0 | 3.125 | +0.19% |
| K3V2 pb=0 | 2.625 | +0.276% |
| K2V2 pb=0 | 2.125 | +0.95% |

bpe = (key_bits + scale_bpe + value_bits + scale_bpe) / 2, where scale_bpe = 16/128 = 0.125 (one 16-bit scale spread over 128 elements).
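The bpe column, and the 6x and 7.5x ratios quoted for K3V2 and K2V2 elsewhere, follow directly from this formula. A quick check; bpe here is a small illustrative function, not a nexusquant API.

```python
SCALE_BPE = 16 / 128  # one 16-bit scale per 128 elements = 0.125 bits each


def bpe(key_bits: int, value_bits: int, scale_bpe: float = SCALE_BPE) -> float:
    """Bits per KV element, averaged over keys and values, scales included."""
    return (key_bits + scale_bpe + value_bits + scale_bpe) / 2


for k, v in [(4, 2), (3, 2), (2, 2)]:
    b = bpe(k, v)
    print(f"K{k}V{v}: bpe={b:.3f}  compression vs FP16={16 / b:.2f}x")
```

This gives 5.12x, 6.10x and 7.53x versus FP16 for K4V2, K3V2 and K2V2.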

NIAH recall (quant-only, chat-template, Mistral-Inst-v0.3, A100-80GB)

| Context | FP16 | K3V2 pb=0 | K2V2 pb=0 |
|---|---|---|---|
| 12K | 25/25 | 25/25 | 25/25 |
| 32K | 25/25 | 25/25 | 24/25 |

At 32K, K2V2 pb=0 missed once (depth 0.5, trial 0).

AQUA-KV comparison (Llama-3.1-8B-base, AQUA-iso protocol)

Under the AQUA-KV iso-protocol (arXiv:2501.19392, Table 2: wikitext-2 test, six sequences of 8192 tokens, prefix=1024), NexusQuant matches or improves on AQUA-KV's reported PPL deltas at lower or comparable raw bits-per-element. The pipeline is calibration-free; AQUA-KV requires per-architecture calibration.

| Method | Raw bpe | PPL delta vs FP16 | Calibration |
|---|---|---|---|
| NexusQuant K3V2 pb=0 | 2.625 | +0.487% | none |
| AQUA-KV (reported, Table 2) | 3.06 | +0.53% | required |
| NexusQuant K2V2 pb=0 | 2.125 | +1.649% | none |
| AQUA-KV (reported, Table 2) | 2.09 | +1.96% | required |

Two-tier compression

| Tier | What it does | Compression | PPL impact | NIAH recall | Use case |
|---|---|---|---|---|---|
| Quant-only | E8 lattice VQ, no eviction | ~6x | ~0% (near-lossless) | 100% | Quality-critical apps |
| Light eviction | E8 VQ + 25% eviction (real scorer) | ~5.3x | +0.20% | 100% | Balanced quality + compression |
| Aggressive eviction | E8 VQ + 35-80% eviction | 8-33x | +0.3-5% | degrades | Memory-critical ("fits vs doesn't fit") |

The NIAH cliff is sharp: recall is 100% at 25% eviction and drops at 35%+. Light eviction with the real attention scorer is the sweet spot for most deployments.

Supported architectures

Quality correlates strongly with KV head count. Models with 4 or more KV heads are safe at K3V2; models with fewer than 4 KV heads degrade.

| Model family | Status | Caveats |
|---|---|---|
| Mistral-7B-v0.1 | Validated | Main PPL benchmark; K3V2 pb=0 +0.276% PPL (wikitext-2, n=161 chunks) |
| Mistral-Inst-v0.3 | Validated | K3V2 pb=0 +0.33% PPL; NIAH 40/40 at 4K and 8K, 100% to 32K on A100 |
| Llama-3 / Llama-3.1 | Validated | Pass rope_scaling=model.config.rope_scaling; Llama-3.1 needs K4V2 minimum |
| Yi-6B / Yi-6B-Chat | Validated | K3V2 pb=0 +0.35% PPL; chat-template NIAH 5/5 at 4K |
| Qwen2.5-7B | Validated | Requires protect_boundary=2; turning boundary protection off is catastrophic (+539x PPL) |
| Qwen3-8B | Validated | K3V2 pb=0 +0.38% PPL (iso-protocol, NF4 weights); boundary protection recommended |
| Gemma-2-2b | Validated | Best result: K3V2 pb=0 +0.05% PPL |
| Phi-3 / Phi family | Validated | Low PPL delta (~0.4% at K3V2); set compress_layers="global_only" for SWA layers |
| Gemma4 (SWA models) | Supported | Use compress_layers="global_only"; with only 1 KV head, NIAH recall is not reliable |
| Qwen2.5-1.5B | Not recommended | 2 KV heads: +5.04% PPL at K3V2, below the safe threshold |
| GPT-NeoX / GPT-J | Not supported | Interleaved RoPE not yet implemented |
| Encoder-decoder (T5, BART, Whisper) | Not supported | KV cache structure differs |
| Vision-language models | Not tested | May work for the language decoder |

How it works

  1. Importance scoring - rank tokens by attention weight
  2. Token eviction (optional) - drop lowest-scoring tokens; always keep BOS and a recent sliding window
  3. RoPE removal - undo rotary embeddings on keys so they share a common subspace
  4. Hadamard rotation - spread energy uniformly across dimensions
  5. E8 lattice quantization - quantize 8-float groups onto the E8 root lattice. Asymmetric: 3-bit keys + 2-bit values
  6. Boundary protection - optionally keep first/last N layers at FP16 (mandatory for Qwen-family)
  7. Delta coding + zstd - consecutive tokens produce similar lattice indices; storing deltas then compressing with zstd yields another 2-3x
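A minimal pure-Python sketch of steps 3 to 5 for a single 8-element group, assuming the Llama-style rotate_half RoPE layout, an orthonormal Walsh-Hadamard transform, and the standard Conway-Sloane nearest-point decoder for E8 (the union of D8 and D8 shifted by 1/2). The per-group scale is a hypothetical constant: how NexusQuant picks scales and maps bit widths to codebooks is not described here, so treat this as an illustration of the technique, not its implementation. A 128-dim head would split into 16 such groups.

```python
import math
from typing import List, Sequence


def rope_rotate(x: Sequence[float], pos: int, base: float = 10000.0,
                inverse: bool = False) -> List[float]:
    """Apply (or undo) RoPE in the rotate_half layout: dim i pairs with i + d/2."""
    d, half = len(x), len(x) // 2
    out = list(x)
    for i in range(half):
        theta = pos * base ** (-2 * i / d)
        if inverse:
            theta = -theta  # undoing a rotation = rotating by the negative angle
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[i], x[i + half]
        out[i] = a * c - b * s
        out[i + half] = a * s + b * c
    return out


def hadamard(x: Sequence[float]) -> List[float]:
    """Orthonormal fast Walsh-Hadamard transform (length must be a power of 2).
    H / sqrt(n) is symmetric and orthogonal, so applying it twice is the identity."""
    n = len(x)
    assert n > 0 and (n & (n - 1)) == 0, "length must be a power of two"
    y = list(x)
    h = 1
    while h < n:
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                y[j], y[j + h] = y[j] + y[j + h], y[j] - y[j + h]
        h *= 2
    return [v / math.sqrt(n) for v in y]


def nearest_d8(x: Sequence[float]) -> List[int]:
    """Closest point of D8 (integer vectors with an even coordinate sum)."""
    f = [round(v) for v in x]
    if sum(f) % 2:
        # Odd parity: re-round the worst coordinate to its other neighbour.
        k = max(range(len(x)), key=lambda i: abs(x[i] - f[i]))
        f[k] += 1 if x[k] > f[k] else -1
    return f


def nearest_e8(x: Sequence[float]) -> List[float]:
    """Closest E8 point: E8 = D8 union (D8 + 1/2), so decode both cosets."""
    def sq_dist(p: Sequence[float]) -> float:
        return sum((u - w) ** 2 for u, w in zip(x, p))

    a = [float(v) for v in nearest_d8(x)]
    b = [v + 0.5 for v in nearest_d8([v - 0.5 for v in x])]
    return a if sq_dist(a) <= sq_dist(b) else b


# Treat this as one key group that already had RoPE applied at position 5.
key = [0.3, -1.2, 0.8, 0.05, -0.4, 1.1, -0.7, 0.2]
scale = 0.5                                      # hypothetical per-group scale
k_plain = rope_rotate(key, pos=5, inverse=True)  # step 3: remove RoPE
k_rot = hadamard(k_plain)                        # step 4: spread energy
codes = nearest_e8([v / scale for v in k_rot])   # step 5: snap to E8
k_hat = hadamard([c * scale for c in codes])     # dequantize, undo rotation
err = math.dist(k_plain, k_hat)
print("E8 point:", codes, "reconstruction error:", round(err, 3))
```

Because the Hadamard transform is orthogonal, undoing it does not amplify quantization error, and the E8 decoder puts every scaled group within distance 1 (the E8 covering radius) of its lattice point, so the error here is at most scale.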

Why

| Without NexusQuant | With NexusQuant |
|---|---|
| 128K context on 70B = ~42 GB KV cache (GQA) | Same context = ~7 GB KV cache (6x quant-only) |
| KV cache competes with model weights for VRAM | KV cache fits comfortably alongside weights |
| Long context needs multi-GPU or offloading | Single GPU, single machine |
| Deploy a fine-tuned retrieval model | One context manager around generate(); no other code changes |

Quality presets

Measured on Mistral-7B; the GPU and token count for each row are given in the Config column. Compression ratios include all overhead.

| Preset | Compression | PPL degradation | Config (GPU, tokens) |
|---|---|---|---|
| high | ~9x | +0.35% | K3V2 + 35% evict (A100, 3544 tokens) |
| asym | ~14x | <1% (estimated) | K3V2 + 60% evict (not fully validated) |
| balanced | ~17x | +0.82% | K2V2 + 60% evict (A10G, 1664 tokens) |
| max | ~33x | +2.13% | K2V2 + 80% evict (A10G, 1664 tokens) |

PPL alone does not tell the full quality story. Eviction modes show degraded NIAH recall despite small PPL increases. Use quant-only when factual accuracy matters. See Limitations below.
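A back-of-the-envelope sketch relating the preset table to the bpe formula from the PPL section: FP16's 16 bits divided by the quantized bpe times the fraction of tokens kept. It leaves out all the overhead the measured ratios include, so it lands above them (about 9.4x, 15.2x, 18.8x and 37.6x); ideal_ratio is a hypothetical helper, not a nexusquant function.

```python
def ideal_ratio(key_bits: int, value_bits: int, eviction_rate: float,
                scale_bpe: float = 16 / 128) -> float:
    """FP16 bits / (quantized bits per element x fraction of tokens kept)."""
    bpe = (key_bits + scale_bpe + value_bits + scale_bpe) / 2
    return 16.0 / (bpe * (1.0 - eviction_rate))


# (key_bits, value_bits, eviction_rate, measured ratio from the table)
PRESETS = {
    "high": (3, 2, 0.35, 9),
    "asym": (3, 2, 0.60, 14),
    "balanced": (2, 2, 0.60, 17),
    "max": (2, 2, 0.80, 33),
}

for name, (k, v, evict, measured) in PRESETS.items():
    print(f"{name:>8}: ideal {ideal_ratio(k, v, evict):.1f}x vs measured ~{measured}x")
```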

Cross-architecture PPL results

Quant-only K3V2 pb=0 is near-lossless on all models with 4+ KV heads.

| Model | KV heads | K3V2 pb=0 PPL delta | Notes |
|---|---|---|---|
| Gemma-2-2b | 4 | +0.05% | Best result |
| Mistral-7B | 8 | +0.276% | Main benchmark |
| Yi-6B | 4 | +0.35% | |
| Mistral-Inst-v0.3 | 8 | +0.33% | |
| Llama-3.1-Inst | 8 | +0.64% | P1 protocol; needs rope_scaling propagation |
| Qwen3-8B | 8 | +0.38% | iso-protocol, NF4 weights |
| Qwen2.5-1.5B | 2 | +5.04% | 2 KV heads = danger zone |

Advanced options

Graduated layer bit profile - gives boundary layers (first/last 15%) higher precision (3-bit K+V) while middle layers use standard asymmetric (K3V2).

with nexusquant_evict(model, quality="high", layer_bit_profile="graduated"):
    output = model.generate(input_ids, max_new_tokens=200)
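A sketch of how a graduated profile could assign bits per layer. The 15% boundary fraction and the (3, 3) and (3, 2) bit pairs come from the description above; rounding the boundary width up with ceil is an assumption, and graduated_profile is a hypothetical helper.

```python
import math
from typing import List, Tuple


def graduated_profile(n_layers: int, boundary_frac: float = 0.15,
                      boundary_bits: Tuple[int, int] = (3, 3),
                      middle_bits: Tuple[int, int] = (3, 2)) -> List[Tuple[int, int]]:
    """(key_bits, value_bits) per layer; the first and last `boundary_frac`
    of layers get `boundary_bits`. Rounding the width up is an assumption."""
    n_edge = math.ceil(boundary_frac * n_layers)
    return [boundary_bits if i < n_edge or i >= n_layers - n_edge else middle_bits
            for i in range(n_layers)]


profile = graduated_profile(32)  # e.g. Mistral-7B's 32 layers
mean_value_bits = sum(v for _, v in profile) / len(profile)
print(profile.count((3, 3)), "boundary layers; mean value bits =", mean_value_bits)
```

For 32 layers this marks layers 0-4 and 27-31 as (3, 3), so the mean value width rises from 2 to 2.3125 bits.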

Hybrid model compression - for models like Gemma4 with sliding-window + global attention layers, only compress the global layers.

with nexusquant_evict(model, compress_layers="global_only"):
    output = model.generate(input_ids, max_new_tokens=200)

Compared to

| Method | Compression | PPL degradation | Training required | Notes |
|---|---|---|---|---|
| NexusQuant K3V2 pb=0 | 6x | +0.28% | No | Quant-only, no eviction |
| NexusQuant K2V2 pb=0 | 7.5x | +0.95% | No | Quant-only, no eviction |
| TurboQuant+ | 3.8-6.4x | ~0-1% | No | Quant-only, no eviction |
| KVTC (NVIDIA) | up to 20x | <1% | Yes (calibration) | |
| CommVQ (Apple) | ~8x | ~0% | Yes (retraining) | |
| Palu | 11x | ~25% relative | Yes (calibration) | |

Competitor numbers from their papers; not reproduced on our hardware.

Troubleshooting

Import resolves to the repo directory instead of the installed package

If you run python from inside a cloned copy of this repo, Python's import machinery finds the local nexusquant/ directory before the installed package. Run your script from a different directory, or use the full venv path:

cd /tmp
/path/to/your/venv/bin/python your_script.py
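To check which copy Python will actually import, a small stdlib-only sketch. importlib.util.find_spec is standard library; import_origin is a hypothetical helper, and the site-packages test is a heuristic that will also flag editable installs.

```python
import importlib.util
from pathlib import Path
from typing import Optional


def import_origin(name: str) -> Optional[str]:
    """Path Python would load for `name`, or None if it is not importable."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None
    if spec.origin is not None:
        return spec.origin
    # Namespace packages (no __init__.py) have no origin; report their directory.
    return next(iter(spec.submodule_search_locations or []), None)


origin = import_origin("nexusquant")
if origin is None:
    print("nexusquant is not importable from this directory")
elif {"site-packages", "dist-packages"} & set(Path(origin).parts):
    print("Using the installed package:", origin)
else:
    print("Local checkout (or editable install) shadows the package:", origin)
```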

ImportError: cannot import name 'DynamicLayer' from 'transformers.cache_utils'

This means your transformers version is below 5.0. The eviction API (nexusquant_evict) requires transformers >= 5.0 and torch >= 2.4. The quant-only API (compress_kv_cache) works from transformers 4.46.

pip install "transformers>=5.0" "torch>=2.4"

torch version ceiling on Intel Macs (x86_64)

Intel Macs are limited to torch <= 2.2.x; torch 2.3+ requires Apple Silicon or Linux. On Intel Macs the eviction API (nexusquant_evict) is therefore unavailable, because it needs torch >= 2.4. The quant-only API works on torch 2.2.

transformers and torch version matrix

| API | transformers | torch | Notes |
|---|---|---|---|
| compress_kv_cache | >= 4.46 | >= 2.0 | Tested on 4.46.3 |
| nexusquant context manager | >= 5.0 | >= 2.4 | Uses DynamicLayer hooks |
| nexusquant_evict context manager | >= 5.0 | >= 2.4 | Uses DynamicLayer hooks |
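A stdlib-only sketch that checks installed versions against this matrix. REQUIREMENTS transcribes the compress_kv_cache and nexusquant_evict rows; parse keeps only the leading numeric components, so 2.4.0+cu121 reads as 2.4.0 and pre-releases count as the release. For exact PEP 440 comparisons use packaging.version instead. All helper names here are hypothetical.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional, Tuple


def parse(v: str) -> Tuple[int, ...]:
    """Leading numeric release components: '2.4.0+cu121' -> (2, 4, 0)."""
    nums = []
    for part in v.split("."):
        m = re.match(r"\d+", part)
        if m is None:
            break
        nums.append(int(m.group()))
        if m.end() != len(part):  # e.g. '0rc1' or '0+cu121': stop here
            break
    return tuple(nums)


def meets(installed: str, minimum: str) -> bool:
    """Compare release tuples after padding the shorter one with zeros."""
    a, b = parse(installed), parse(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


# Transcribed from the version matrix above.
REQUIREMENTS = {
    "compress_kv_cache": {"transformers": "4.46", "torch": "2.0"},
    "nexusquant_evict": {"transformers": "5.0", "torch": "2.4"},
}


def check(api: str) -> Dict[str, Optional[bool]]:
    """True/False per package, or None when the package is not installed."""
    result: Dict[str, Optional[bool]] = {}
    for pkg, minimum in REQUIREMENTS[api].items():
        try:
            result[pkg] = meets(version(pkg), minimum)
        except PackageNotFoundError:
            result[pkg] = None
    return result


for api in REQUIREMENTS:
    print(api, check(api))
```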

Limitations

  • Quality is text-dependent. Creative/narrative text degrades more than structured/technical text.
  • Short prefixes hurt. Prefixes under 500 tokens see more degradation.
  • Architecture-dependent boundary protection. Qwen-family models catastrophically fail without protect_boundary=2. Always test your specific model.
  • E8 quantization is CPU-bound. A Triton GPU kernel is written but not yet benchmarked for latency.
  • Eviction hurts factual recall. The NIAH benchmark shows degradation at 35%+ eviction, which PPL hides.
  • PPL is not a sufficient quality metric. Always validate with NIAH or downstream accuracy.
  • Results are primarily on 7B-class models. 70B validation is pending.
  • Batch size > 1 is partially broken. NexusQuantSimple only compresses batch index 0.
  • Multi-turn chat is not supported. The hook compresses on every incoming prefill.
  • Speculative decoding is not supported.
  • KV cache offloading is not supported.
  • Encoder-decoder models (T5, BART, Whisper) are not supported.
  • Vision-language models are untested.
  • GGUF models are not supported.

Reproducibility: paper claims

Standalone runners under reproducibility/ reproduce the headline tables in the paper.

| Script | Paper table |
|---|---|
| reproducibility/niah_kv_canonical.py | tab:niah_kv_canonical |
| reproducibility/commvq_head_to_head.py | tab:commvq_head |
| reproducibility/ruler_13task.py | tab:ruler_13task |
| reproducibility/entropy_coding_live.py | tab:entropy_live |
| reproducibility/turboquant_validated.py | TurboQuant H2H |
| reproducibility/niah_long_context.py | tab:niah_long_ctx (4K/12K/32K) |
| reproducibility/kivi_upstream_gate.py | tab:kivi_gate |

See reproducibility/README.md for setup, GPU requirements, and reproduction commands.

Citation

@software{nexusquant2026,
  author  = {Marques, Jo\~{a}o Andr\'{e} Gomes},
  title   = {{NexusQuant}: Training-Free {KV} Cache Compression via {E8} Lattice Quantization and Attention-Aware Token Eviction},
  year    = {2026},
  url     = {https://github.com/jagmarques/nexusquant},
  license = {Apache-2.0},
}

License

Apache 2.0. See LICENSE.
