
Implement PolarQuant quantizer (polar_quant.py) — Algorithm 1 #4

Description

@TheTom

Summary

Implement PolarQuant (Algorithm 1 from the paper) — MSE-optimized vector quantization via random rotation + optimal scalar quantization per coordinate.

Depends on: #2 (rotation.py), #3 (codebook.py)

Paper Reference

  • Paper: arXiv 2504.19874, Algorithm 1 (page 5)
  • This is the MSE-optimized quantizer. It is NOT designed to preserve inner products; that requires the QJL stage applied on top.

Algorithm (from paper)

Setup:
  1. Generate random rotation matrix Π ∈ R^(d×d)
  2. Construct codebook: centroids c_1...c_(2^b) minimizing MSE

Quantize(x):
  1. y ← Π @ x                              # Random rotation
  2. idx_j ← argmin_k |y_j - c_k| for j∈[d]  # Nearest centroid per coordinate
  3. Return idx                               # b-bit integers

Dequantize(idx):
  1. ỹ_j ← c_(idx_j) for j∈[d]              # Look up centroids
  2. x̃ ← Π^T @ ỹ                            # Inverse rotation
  3. Return x̃
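
The steps above can be sketched in NumPy roughly as follows. This is only an illustration under stated assumptions, not the implementation from #2 or #3: `random_rotation` (QR of a Gaussian matrix) and `lloyd_codebook` (1-D Lloyd iterations on samples from N(0, 1/d), which approximates the coordinates of a rotated unit vector) are hypothetical stand-ins for whatever rotation.py and codebook.py end up providing.

```python
import numpy as np


def random_rotation(d, seed=0):
    """Hypothetical stand-in for rotation.py: orthogonal matrix from QR of a Gaussian."""
    rng = np.random.default_rng(seed)
    q, r = np.linalg.qr(rng.standard_normal((d, d)))
    # Flip column signs so the result does not depend on the QR sign convention.
    return q * np.sign(np.diag(r))


def lloyd_codebook(d, bit_width, seed=0, n_samples=20_000, n_iter=30):
    """Hypothetical stand-in for codebook.py: fit 2^b centroids to N(0, 1/d) samples."""
    rng = np.random.default_rng(seed)
    samples = rng.standard_normal(n_samples) / np.sqrt(d)
    k = 2 ** bit_width
    # Initialise at evenly spaced quantiles so every cell starts non-empty.
    centroids = np.quantile(samples, (np.arange(k) + 0.5) / k)
    for _ in range(n_iter):
        idx = np.abs(samples[:, None] - centroids[None, :]).argmin(axis=1)
        centroids = np.array([samples[idx == j].mean() for j in range(k)])
    return np.sort(centroids)


def quantize(x, rotation, centroids):
    """Quantize(x): rotate, then pick the nearest centroid per coordinate."""
    y = rotation @ x
    return np.abs(y[:, None] - centroids[None, :]).argmin(axis=1)


def dequantize(idx, rotation, centroids):
    """Dequantize(idx): look up centroids, then apply the inverse rotation."""
    y_hat = centroids[idx]
    return rotation.T @ y_hat


# Usage on a single unit vector (d=64, b=2).
d, b = 64, 2
rotation = random_rotation(d, seed=0)
centroids = lloyd_codebook(d, b, seed=0)
x = np.random.default_rng(1).standard_normal(d)
x /= np.linalg.norm(x)
idx = quantize(x, rotation, centroids)
x_hat = dequantize(idx, rotation, centroids)
print("squared error:", float(np.sum((x - x_hat) ** 2)))
```

The Gaussian approximation of the coordinate distribution is an assumption of this sketch; the codebook from #3 may be built differently.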

Requirements

PolarQuant class

  • __init__(d, bit_width, seed) — creates rotation matrix and codebook
  • quantize(x) — single vector (d,) or batch (batch, d) → integer indices
  • dequantize(indices) — indices → reconstructed vectors
  • quantize_and_residual(x) — returns (indices, residual) where residual = x - dequantize(indices)
    • This method is needed by TurboQuant's QJL stage
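
A minimal sketch of how `quantize_and_residual` could handle both the single-vector and batch shapes. It is written as a free function taking `rotation` and `centroids` as arguments purely for illustration (the class in this issue would hold them as attributes), and the identity rotation and two-level codebook in the usage lines are toy placeholders, not the real outputs of #2 and #3.

```python
import numpy as np


def quantize_and_residual(x, rotation, centroids):
    """Return (indices, residual) with residual = x - dequantize(indices)."""
    x = np.asarray(x, dtype=np.float64)
    batch = np.atleast_2d(x)                        # (n, d), also for a single vector
    y = batch @ rotation.T                          # same as (rotation @ batch.T).T
    idx = np.abs(y[..., None] - centroids).argmin(axis=-1)
    x_hat = centroids[idx] @ rotation               # same as (rotation.T @ y_hat.T).T
    residual = batch - x_hat
    if x.ndim == 1:
        # Single vector in, single vector out.
        return idx[0], residual[0]
    return idx, residual


# Toy placeholders: identity rotation and a 1-bit codebook.
rotation = np.eye(2)
centroids = np.array([-0.5, 0.5])

idx, residual = quantize_and_residual([0.4, -0.7], rotation, centroids)
print(idx, residual)
```

Routing the single-vector case through the batch path is one way to keep the "batch matches single" test trivially true.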

Implementation details

  • Rotation applied as: y = (Π @ x.T).T for batch
  • Inverse rotation: x̃ = (Π.T @ ỹ.T).T
  • Indices must be in range [0, 2^bit_width)
  • Support bit_width = 1, 2, 3, 4
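
As a quick sanity check on the batch formulas above, the following sketch verifies that `(Π @ x.T).T` equals the shorter `x @ Π.T`, that the inverse rotation recovers the input, and that each batch row matches the single-vector product. The QR-based `rotation` here is an assumed example of an orthogonal matrix, not necessarily what rotation.py produces.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 8, 5

# Any orthogonal matrix works for this check; QR of a Gaussian is one way to get one.
rotation, _ = np.linalg.qr(rng.standard_normal((d, d)))
x = rng.standard_normal((n, d))

# Forward rotation: the form from the issue and the equivalent row-vector form.
y_issue = (rotation @ x.T).T
y_short = x @ rotation.T
assert np.allclose(y_issue, y_short)

# Inverse rotation recovers the input because rotation.T @ rotation = I.
x_back = (rotation.T @ y_issue.T).T
assert np.allclose(x_back, x)

# Each batch row equals the single-vector product.
assert np.allclose(y_issue[0], rotation @ x[0])
```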

Tests Required (write FIRST)

  1. Round-trip MSE within paper bounds (parametrized over bit_width × dimension):
    • b=1: avg MSE < 0.36 × 2 (2× slack for finite d)
    • b=2: avg MSE < 0.117 × 2
    • b=3: avg MSE < 0.03 × 2
    • Test at d = {64, 128, 256}
    • Use 500+ random unit vectors per test
  2. MSE decreases with bit-width: MSE(b=1) > MSE(b=2) > MSE(b=3)
  3. Zero vector behavior: quantize(zeros) should not crash; reconstruction norm should be small
  4. Deterministic: same seed + same input → same output
  5. Batch matches single: quantize(batch)[i] == quantize(batch[i])
  6. Indices in valid range: all indices ∈ [0, 2^b)
  7. Residual identity: residual == x - dequantize(quantize(x))
  8. Large vectors: vectors with norm >> 1 should still work (not just unit vectors)
  9. Reconstruction preserves direction: cosine similarity between x and dequantize(quantize(x)) should be high for unit vectors
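
A sketch of helpers the tests above could share. Two assumptions are baked in and should be checked against the paper: "MSE" is taken to mean the squared L2 error per vector averaged over vectors (not divided by d), and `sign_roundtrip` is a hypothetical 1-bit stand-in (identity rotation, centroids ±sqrt(2/(πd))) used only so the example runs without polar_quant.py.

```python
import numpy as np

# Bounds as listed in this issue; the 2x slack is applied at the call site.
PAPER_MSE = {1: 0.36, 2: 0.117, 3: 0.03}


def random_unit_vectors(n, d, seed=0):
    """Sample n vectors uniformly from the unit sphere in R^d."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n, d))
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def avg_mse(x, x_hat):
    """Average squared L2 reconstruction error per vector (assumed MSE definition)."""
    return float(np.mean(np.sum((x - x_hat) ** 2, axis=1)))


def mean_cosine(x, x_hat):
    """Average cosine similarity between inputs and reconstructions."""
    num = np.sum(x * x_hat, axis=1)
    den = np.linalg.norm(x, axis=1) * np.linalg.norm(x_hat, axis=1)
    return float(np.mean(num / den))


def sign_roundtrip(x):
    """Hypothetical 1-bit stand-in for dequantize(quantize(x)); not PolarQuant itself."""
    d = x.shape[1]
    c = np.sqrt(2.0 / (np.pi * d))
    return np.where(x >= 0, c, -c)


x = random_unit_vectors(500, 64, seed=0)
x_hat = sign_roundtrip(x)
assert avg_mse(x, x_hat) < 2 * PAPER_MSE[1]
print("avg MSE:", avg_mse(x, x_hat), "mean cosine:", mean_cosine(x, x_hat))
```

In the real test file the stand-in would be replaced by the PolarQuant round trip, and the (bit_width, d) grid could be driven by pytest parametrization.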

Acceptance Criteria

  • Tests written and reviewed with codex BEFORE implementation
  • Avg MSE within 2× of paper bounds for all tested (d, b) combinations
  • MSE monotonically decreases with bit-width
  • Batch and single-vector results are identical
  • codex-review on implementation
  • Roast review on implementation
  • Coverage >95% for polar_quant.py

Metadata

Labels: P1 (Core product work), type:algorithm (Core algorithm implementation)
