
Commit 16e93d5

Marcel committed
ggml: add TQ3_0 (TurboQuant 3-bit) KV cache quantization type
Implements a 3.5 bits/value KV cache quantization type based on the
TurboQuant/PolarQuant/QJL papers from Google Research.

Algorithm:
- Per-block Walsh-Hadamard Transform (WHT32) with fixed sign flips makes
  any input distribution approximately Gaussian (by the CLT)
- 2-bit Max-Lloyd optimal codebook {-1.510, -0.453, +0.453, +1.510},
  tuned for a Gaussian source, achieves near-optimal MSE
- 1-bit QJL residual signs for error correction
- Per-block FP16 scale factor

Block format: 14 bytes / 32 values = 3.5 bits/value (4.6x smaller than F16)
- qs[8]: 2-bit codebook indices (4 per byte)
- qr[4]: QJL residual signs (8 per byte)
- gamma: FP16 per-block scale

Fused MMVQ kernel: since the WHT is orthogonal, dot(q,k) = dot(WHT(q), WHT(k)).
Apply the WHT to Q8_1 query values inside the fused vec_dot kernel (int32
butterfly) and compute the dot product in rotated space. No dequantize+MUL_MAT
fallback needed — speed matches Q4_0.

Results (Qwen3.5-0.8B-Q5_K_M, wikitext-2, Radeon 8060S):
  F16:   PPL = 20.05,          tg128 = 181.8 t/s
  Q4_0:  PPL = 20.14 (+0.4%),  tg128 = 179.1 t/s
  TQ3_0: PPL = 21.21 (+5.8%),  tg128 = 177.9 t/s

References:
- TurboQuant (arXiv:2504.19874) — ICLR 2026
- PolarQuant (arXiv:2502.02617) — AISTATS 2026
- QJL (arXiv:2406.03482)
1 parent 914eb5f commit 16e93d5
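
For orientation before the diffs: a minimal host-side sketch of the encode path the message describes (fixed sign flips, WHT32 rotation, nearest-centroid 2-bit quantization, 1-bit residual sign). The helper names, the sign-pattern argument, and the RMS-based definition of the per-block scale are illustrative assumptions; the committed quantize_row_tq3_0_ref may differ in those details.

#include <cmath>
#include <cstdint>
#include <cstring>

// 2-bit Max-Lloyd codebook for a unit Gaussian (values from the commit message).
static const float kCentroids[4] = { -1.510f, -0.453f, 0.453f, 1.510f };

// In-place orthonormal WHT32: butterfly stages plus 1/sqrt(32) normalization.
static void wht32(float v[32]) {
    for (int step = 1; step < 32; step <<= 1) {
        for (int i = 0; i < 32; i += 2 * step) {
            for (int j = i; j < i + step; ++j) {
                const float a = v[j], b = v[j + step];
                v[j]        = a + b;
                v[j + step] = a - b;
            }
        }
    }
    for (int i = 0; i < 32; ++i) v[i] *= 0.17677669529663688f; // 1/sqrt(32)
}

// Encode one 32-value block. `signs` is the fixed +/-1 flip pattern (the real
// pattern lives in the kernel). The scale definition here is an assumption;
// the struct comment in ggml-common.h describes gamma as a residual norm.
static void tq3_0_encode_block_sketch(const float x[32], const int8_t signs[32],
                                      uint8_t qs[8], uint8_t qr[4], float * gamma) {
    float v[32];
    for (int i = 0; i < 32; ++i) v[i] = x[i] * signs[i]; // fixed sign flips
    wht32(v);                                            // now ~Gaussian by CLT

    float ss = 0.0f;
    for (int i = 0; i < 32; ++i) ss += v[i] * v[i];
    const float scale = sqrtf(ss / 32.0f);               // per-block scale (assumed RMS)
    *gamma = scale;

    memset(qs, 0, 8);
    memset(qr, 0, 4);
    for (int i = 0; i < 32; ++i) {
        const float u = scale > 0.0f ? v[i] / scale : 0.0f;
        int best = 0;                                    // nearest of the 4 centroids
        for (int c = 1; c < 4; ++c) {
            if (fabsf(u - kCentroids[c]) < fabsf(u - kCentroids[best])) best = c;
        }
        qs[i / 4] |= (uint8_t) (best << (2 * (i % 4)));  // 4 indices per byte
        if (u < kCentroids[best]) {                      // 1-bit residual sign
            qr[i / 8] |= (uint8_t) (1 << (i % 8));
        }
    }
}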

20 files changed

Lines changed: 445 additions & 3 deletions

common/arg.cpp

Lines changed: 1 addition & 0 deletions
@@ -387,6 +387,7 @@ const std::vector<ggml_type> kv_cache_types = {
     GGML_TYPE_IQ4_NL,
     GGML_TYPE_Q5_0,
     GGML_TYPE_Q5_1,
+    GGML_TYPE_TQ3_0,
 };
 
 static ggml_type kv_cache_type_from_str(const std::string & s) {

ggml/include/ggml.h

Lines changed: 2 additions & 1 deletion
@@ -428,7 +428,8 @@ extern "C" {
     // GGML_TYPE_IQ4_NL_8_8 = 38,
     GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
     GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
-    GGML_TYPE_COUNT = 41,
+    GGML_TYPE_TQ3_0 = 41, // TurboQuant 3-bit polar + QJL (FP16 per-block scale)
+    GGML_TYPE_COUNT = 42,
 };
 
 // precision

ggml/src/ggml-common.h

Lines changed: 15 additions & 0 deletions
@@ -266,6 +266,21 @@ typedef struct {
 } block_tq2_0;
 static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding");
 
+// TurboQuant 3-bit quantization (3.5 bpw)
+// Per the TurboQuant paper (Algorithm 2: TurboQuant_prod), ICLR 2026
+// Each block of 32 values is quantized as:
+//  - 2-bit MSE codebook indices (after random rotation Π·x)
+//  - 1-bit QJL residual signs (sign(S·r) where r = x - dequant_mse(quant_mse(x)))
+//  - FP16 residual norm ||r||₂ for QJL scaling
+// Requires per-model rotation matrices Π and S (stored externally)
+#define QK_TQ3_0 32
+typedef struct {
+    uint8_t qs[QK_TQ3_0 / 4]; // 2-bit codebook indices, 32 × 2 bits = 8 bytes
+    uint8_t qr[QK_TQ3_0 / 8]; // QJL residual signs, 32 × 1 bit = 4 bytes
+    ggml_half gamma;          // ||residual||₂ for QJL correction scaling
+} block_tq3_0;
+static_assert(sizeof(block_tq3_0) == QK_TQ3_0/4 + QK_TQ3_0/8 + sizeof(ggml_half), "wrong tq3_0 block size/padding");
+
 //
 // Super-block quantization structures
 //
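
To sanity-check the 14-byte layout above, here is a standalone mirror of the struct with the per-value addressing spelled out. ggml_half is stood in by uint16_t, the index bit ordering matches the CUDA dequantizer further down, and the sign-bit polarity is an assumption.

#include <cstdint>

#define QK_TQ3_0_SKETCH 32

// Stand-in for ggml_half storage (2 bytes holding an FP16 bit pattern).
typedef uint16_t half_bits;

typedef struct {
    uint8_t   qs[QK_TQ3_0_SKETCH / 4]; // 8 bytes: 2-bit indices, 4 per byte
    uint8_t   qr[QK_TQ3_0_SKETCH / 8]; // 4 bytes: 1-bit signs,   8 per byte
    half_bits gamma;                   // 2 bytes: per-block FP16 scale/norm
} block_tq3_0_sketch;

// 8 + 4 + 2 = 14 bytes for 32 values -> 112 bits / 32 = 3.5 bpw.
static_assert(sizeof(block_tq3_0_sketch) == 14, "expected 3.5 bpw layout");

// Value i's codebook index (0..3), using the same shifts as the CUDA kernel.
static inline int tq3_0_index(const block_tq3_0_sketch * b, int i) {
    return (b->qs[i / 4] >> (2 * (i % 4))) & 3;
}

// Value i's QJL residual sign; bit=1 meaning negative is an assumption here.
static inline int tq3_0_sign(const block_tq3_0_sketch * b, int i) {
    return ((b->qr[i / 8] >> (i % 8)) & 1) ? -1 : +1;
}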

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 4 additions & 0 deletions
@@ -390,6 +390,10 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
     },
+    [GGML_TYPE_TQ3_0] = {
+        .from_float = quantize_row_tq3_0,
+        .nrows = 1,
+    },
     [GGML_TYPE_I32] = {
         .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32,
     },

ggml/src/ggml-cpu/ggml-cpu.cpp

Lines changed: 8 additions & 1 deletion
@@ -448,7 +448,11 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
                    op->type != GGML_TYPE_IQ1_S &&
                    op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
-            return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
+            {
+                const auto * traits = ggml_get_type_traits_cpu(src0->type);
+                return traits->vec_dot != NULL &&
+                       (src1->type == GGML_TYPE_F32 || src1->type == traits->vec_dot_type);
+            }
         case GGML_OP_SOFT_MAX_BACK: {
             if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) {
                 return false;
@@ -466,6 +470,9 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
         case GGML_OP_OUT_PROD:
             return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
                    src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
+        case GGML_OP_FLASH_ATTN_EXT:
+            // K type must have vec_dot for CPU flash attention
+            return ggml_get_type_traits_cpu(src1->type)->vec_dot != NULL;
         default:
             return true;
     }
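
These two checks are what keep TQ3_0 off the slow CPU paths: its type traits entry above registers from_float but no vec_dot, so both predicates reject TQ3_0 operands and the scheduler routes those ops to a backend with the fused kernel. A minimal standalone check of that interplay, with placeholder types (not the real ggml structs):

#include <cassert>

typedef void (*vec_dot_fn)(const void *, const void *, float *);

struct traits_sketch {
    bool       has_from_float;
    vec_dot_fn vec_dot;      // NULL for TQ3_0 on CPU
    int        vec_dot_type; // unused when vec_dot is NULL
};

// MUL_MAT: src0 needs a vec_dot, and src1 must be F32 or vec_dot_type.
static bool supports_mul_mat(const traits_sketch & t, int src1_type, int f32) {
    return t.vec_dot != nullptr && (src1_type == f32 || src1_type == t.vec_dot_type);
}

// FLASH_ATTN_EXT: the K tensor's type must have a vec_dot.
static bool supports_flash_attn(const traits_sketch & k_traits) {
    return k_traits.vec_dot != nullptr;
}

int main() {
    const traits_sketch tq3_0 = { /*has_from_float=*/true, /*vec_dot=*/nullptr, 0 };
    assert(!supports_mul_mat(tq3_0, /*src1_type=*/1, /*f32=*/1)); // rejected even for F32 src1
    assert(!supports_flash_attn(tq3_0));
    return 0;
}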

ggml/src/ggml-cpu/ops.cpp

Lines changed: 7 additions & 0 deletions
@@ -678,6 +678,7 @@ void ggml_compute_forward_add(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_TQ1_0:
         case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_TQ3_0:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
@@ -1128,6 +1129,7 @@ void ggml_compute_forward_add1(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_TQ1_0:
         case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_TQ3_0:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
@@ -1257,6 +1259,7 @@ void ggml_compute_forward_acc(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_TQ1_0:
         case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_TQ3_0:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
@@ -4345,6 +4348,7 @@ void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_TQ1_0:
         case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_TQ3_0:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
@@ -4621,6 +4625,7 @@ void ggml_compute_forward_set(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_TQ1_0:
         case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_TQ3_0:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
@@ -4844,6 +4849,7 @@ void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_TQ1_0:
         case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_TQ3_0:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:
@@ -5569,6 +5575,7 @@ void ggml_compute_forward_clamp(
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_TQ1_0:
         case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_TQ3_0:
         case GGML_TYPE_IQ2_XXS:
         case GGML_TYPE_IQ2_XS:
         case GGML_TYPE_IQ3_XXS:

ggml/src/ggml-cpu/quants.c

Lines changed: 6 additions & 0 deletions
@@ -108,6 +108,12 @@ void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy,
     quantize_row_tq2_0_ref(x, y, k);
 }
 
+void quantize_row_tq3_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+    assert(k % QK_TQ3_0 == 0);
+    block_tq3_0 * GGML_RESTRICT y = vy;
+    quantize_row_tq3_0_ref(x, y, k);
+}
+
 //===================================== Q8_K ==============================================
 
 void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {

ggml/src/ggml-cpu/quants.h

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
 
 void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_tq3_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 
 void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);

ggml/src/ggml-cuda/common.cuh

Lines changed: 7 additions & 0 deletions
@@ -1029,6 +1029,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
     static constexpr int qi = QI3_S;
 };
 
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_TQ3_0> {
+    static constexpr int qk = QK_TQ3_0;     // 32
+    static constexpr int qr = 1;
+    static constexpr int qi = QK_TQ3_0 / 4; // 8
+};
+
 //////////////////////
 
 struct ggml_cuda_device_info {

ggml/src/ggml-cuda/convert.cu

Lines changed: 54 additions & 0 deletions
@@ -486,6 +486,50 @@ static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_
     }
 }
 
+// TurboQuant TQ3_0: 2-bit codebook dequantization + inverse WHT
+// Dequantize to rotated space, then apply inverse WHT32 cooperatively
+template<typename dst_t>
+static __global__ void dequantize_block_tq3_0(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+    const float centroids[4] = { -1.510f, -0.4528f, 0.4528f, 1.510f };
+    const int8_t signs[32] = {
+        +1, -1, +1, +1, -1, -1, +1, -1, +1, +1, -1, +1, -1, +1, -1, -1,
+        +1, -1, -1, +1, +1, -1, +1, -1, -1, +1, +1, +1, -1, -1, +1, -1
+    };
+
+    const int64_t i = blockIdx.x;
+    const block_tq3_0 * x = (const block_tq3_0 *) vx;
+    const int tid = threadIdx.x;
+    if (tid >= 32) return;
+
+    const float d = __half2float(x[i].gamma);
+
+    // Step 1: Each thread dequantizes its value (in rotated space)
+    const int byte_idx = tid / 4;
+    const int bit_shift = 2 * (tid % 4);
+    const int idx = (x[i].qs[byte_idx] >> bit_shift) & 3;
+
+    __shared__ float shmem[32];
+    shmem[tid] = d * centroids[idx];
+    __syncthreads();
+
+    // Step 2: Cooperative inverse WHT (5 butterfly stages)
+    for (int step = 1; step < 32; step <<= 1) {
+        int partner = tid ^ step; // butterfly partner
+        float a = shmem[tid];
+        float b = shmem[partner];
+        __syncthreads();
+        if (tid < partner) {
+            shmem[tid]     = a + b;
+            shmem[partner] = a - b;
+        }
+        __syncthreads();
+    }
+
+    // Step 3: Normalize and undo sign flips
+    const float inv_sqrt32 = 0.17677669529663688f;
+    yy[i * QK_TQ3_0 + tid] = shmem[tid] * inv_sqrt32 * signs[tid];
+}
+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static void dequantize_block_cuda(const void * vx, dst_t * y,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
@@ -617,6 +661,12 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t
     dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_tq3_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_TQ3_0;
+    dequantize_block_tq3_0<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static __global__ void convert_unary(
     const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
@@ -715,6 +765,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_MXFP4:
             return dequantize_row_mxfp4_cuda;
+        case GGML_TYPE_TQ3_0:
+            return dequantize_row_tq3_0_cuda;
         case GGML_TYPE_F32:
            return convert_unary_cont_cuda<float>;
         case GGML_TYPE_BF16:
@@ -766,6 +818,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_MXFP4:
             return dequantize_row_mxfp4_cuda;
+        case GGML_TYPE_TQ3_0:
+            return dequantize_row_tq3_0_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cont_cuda<half>;
         case GGML_TYPE_BF16:
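
The identity the fused MMVQ path leans on, that orthogonal transforms preserve inner products, is easy to verify on the host. The sketch below applies an orthonormal WHT32 (same butterfly ordering as the kernel above, in float rather than int32) to two random vectors and compares dot products before and after. Everything here is illustrative, not part of the commit.

#include <cstdio>
#include <cstdlib>

// Orthonormal WHT32: butterfly stages as in dequantize_block_tq3_0, followed
// by a 1/sqrt(32) normalization so that W^T W = I.
static void wht32(float v[32]) {
    for (int step = 1; step < 32; step <<= 1) {
        for (int i = 0; i < 32; i += 2 * step) {
            for (int j = i; j < i + step; ++j) {
                const float a = v[j], b = v[j + step];
                v[j]        = a + b;
                v[j + step] = a - b;
            }
        }
    }
    for (int i = 0; i < 32; ++i) v[i] *= 0.17677669529663688f; // 1/sqrt(32)
}

int main() {
    float q[32], k[32], qw[32], kw[32];
    for (int i = 0; i < 32; ++i) {
        q[i] = qw[i] = (float) rand() / RAND_MAX - 0.5f;
        k[i] = kw[i] = (float) rand() / RAND_MAX - 0.5f;
    }
    wht32(qw);
    wht32(kw);

    float dot = 0.0f, dot_wht = 0.0f;
    for (int i = 0; i < 32; ++i) {
        dot     += q[i] * k[i];
        dot_wht += qw[i] * kw[i];
    }
    // The two agree to float rounding error, which is why the fused kernel can
    // rotate Q and compute attention scores directly in rotated space.
    printf("dot = %.6f  dot(WHT) = %.6f\n", dot, dot_wht);
    return 0;
}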
