
Commit 172fc85

TheTom and claude committed
Merge signalnine/feature/turboquant-kv-cache (PR #3) — CUDA port

Resolved conflict in ggml-turbo-quant.c (kept both 4-bit centroids and the CPU WHT). Updated ISWA build_attn to use the new ggml_turbo_wht 5-arg signature. Removed the redundant V inverse WHT from the ISWA overload (now handled in build_attn_mha).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: tturney@psyguard.ai
2 parents 4cf7145 + 4c4511c commit 172fc85

34 files changed: 2267 additions & 76 deletions

common/arg.cpp

Lines changed: 1 addition & 0 deletions
@@ -387,6 +387,7 @@ const std::vector<ggml_type> kv_cache_types = {
     GGML_TYPE_IQ4_NL,
     GGML_TYPE_Q5_0,
     GGML_TYPE_Q5_1,
+    GGML_TYPE_TURBO2_0,
     GGML_TYPE_TURBO3_0,
     GGML_TYPE_TURBO4_0,
 };
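Registering the type here makes it selectable for the KV cache from the command line (the -ctk/--cache-type-k and -ctv/--cache-type-v flags draw from this list), presumably under the name turbo2_0, following the naming of the existing turbo3_0/turbo4_0 entries.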

ggml/include/ggml.h

Lines changed: 5 additions & 2 deletions
@@ -430,7 +430,8 @@ extern "C" {
     GGML_TYPE_NVFP4    = 40, // NVFP4 (4 blocks, E4M3 scale)
     GGML_TYPE_TURBO3_0 = 41, // TurboQuant 3-bit KV cache: 2-bit PolarQuant + 1-bit QJL
     GGML_TYPE_TURBO4_0 = 42, // TurboQuant 4-bit KV cache: 3-bit PolarQuant + 1-bit QJL
-    GGML_TYPE_COUNT    = 43,
+    GGML_TYPE_TURBO2_0 = 43, // TurboQuant 2-bit KV cache: 2-bit PolarQuant (no QJL)
+    GGML_TYPE_COUNT    = 44,
 };

@@ -2490,7 +2491,9 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_turbo_wht(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   direction);
+            int                   direction,
+            int                   group_size, // 0 = auto (64 or 128 from ne[0])
+            struct ggml_tensor  * scale);     // NULL = no InnerQ scaling

     // custom operators
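Every existing call site has to pick up the two new arguments. A minimal sketch of the migration, where ctx, q, and scale_inv are illustrative stand-ins rather than names from the diff:

    // Before (3-arg): forward WHT over fixed 128-wide groups.
    //   cur = ggml_turbo_wht(ctx, q, 0);

    // After (5-arg): group_size 0 lets the op derive 64 or 128 from ne[0],
    // and a NULL scale preserves the old behavior.
    struct ggml_tensor * cur = ggml_turbo_wht(ctx, q, /*direction=*/0,
                                              /*group_size=*/0, /*scale=*/NULL);

    // With InnerQ scaling: scale_inv is applied before the rotation on the
    // forward pass (direction 0) and after it on the inverse (direction 1),
    // per the CPU kernel in ops.cpp below.
    struct ggml_tensor * q_rot = ggml_turbo_wht(ctx, q, 0, 128, scale_inv);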

ggml/src/ggml-common.h

Lines changed: 12 additions & 0 deletions
@@ -319,6 +319,18 @@ static_assert(sizeof(block_turbo4_0) == 2*sizeof(ggml_half) + QK_TURBO4*3/8 + QK
 
 static_assert(QK_TURBO4 == 128, "turbo4 kernels assume QK_TURBO4 == 128");
 
+// TurboQuant 2-bit: 2-bit PolarQuant indices only (no QJL)
+// Per block: norm(fp16) + 2-bit indices (8 bytes) = 10 bytes per 32 values
+// = 2.5 bits/value → 6.4× compression vs fp16
+// 4 centroids (Lloyd-Max for N(0, 1/128)): {-0.133462, -0.039994, 0.039994, 0.133462}
+#define QK_TURBO2       32  // block size
+#define QK_TURBO2_GROUP 128 // rotation group size = head_dim
+typedef struct {
+    ggml_half norm;              // 2 bytes: corrected L2 norm
+    uint8_t   qs[QK_TURBO2 / 4]; // 8 bytes: 2-bit indices (4 per byte)
+} block_turbo2_0; // 10 bytes total
+static_assert(sizeof(block_turbo2_0) == sizeof(ggml_half) + QK_TURBO2/4, "wrong turbo2_0 block size/padding");
+
 //
 // Super-block quantization structures
 //
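For intuition, here is a hypothetical sketch of the encode side for one block: nearest-centroid search over the four Lloyd-Max codewords, packing four 2-bit indices per byte. The real quantize_row_turbo2_0_ref (referenced from ggml-cpu.c below) also handles the WHT rotation and the norm correction, which this sketch omits; turbo2_encode_block_sketch is an illustrative name, not from the commit.

    #include <math.h>
    #include <stdint.h>

    #define QK_TURBO2 32

    // Lloyd-Max centroids for N(0, 1/128), from the comment in ggml-common.h.
    static const float turbo2_centroids[4] = { -0.133462f, -0.039994f, 0.039994f, 0.133462f };

    // Encode one block of QK_TURBO2 values that have already been rotated and
    // normalized: pick the nearest centroid for each value and pack the 2-bit
    // indices four to a byte, low bits first.
    static void turbo2_encode_block_sketch(const float * x, uint8_t qs[QK_TURBO2 / 4]) {
        for (int i = 0; i < QK_TURBO2; i += 4) {
            uint8_t packed = 0;
            for (int j = 0; j < 4; j++) {
                int   best   = 0;
                float best_d = fabsf(x[i + j] - turbo2_centroids[0]);
                for (int k = 1; k < 4; k++) {
                    const float d = fabsf(x[i + j] - turbo2_centroids[k]);
                    if (d < best_d) { best_d = d; best = k; }
                }
                packed |= (uint8_t)(best << (2 * j));
            }
            qs[i / 4] = packed;
        }
    }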

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 61 additions & 0 deletions
@@ -7,6 +7,7 @@
 #include "ggml-cpu-impl.h"
 #include "ggml-impl.h"
 #include "quants.h"
+#include "ggml-quants.h"
 #include "ggml-threading.h"
 #include "unary-ops.h"
 #include "binary-ops.h"

@@ -204,6 +205,14 @@ typedef pthread_t ggml_thread_t;
 #include <TargetConditionals.h>
 #endif
 
+// Forward declarations — defined below, after utility functions
+static void ggml_vec_dot_turbo3_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
+                                      const void * GGML_RESTRICT vx, size_t bx,
+                                      const void * GGML_RESTRICT vy, size_t by, int nrc);
+static void ggml_vec_dot_turbo2_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
+                                      const void * GGML_RESTRICT vx, size_t bx,
+                                      const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32] = {
         .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_fp32,

@@ -393,6 +402,18 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
     [GGML_TYPE_I32] = {
         .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32,
     },
+    [GGML_TYPE_TURBO3_0] = {
+        .from_float   = (ggml_from_float_t) quantize_row_turbo3_0_ref,
+        .vec_dot      = (ggml_vec_dot_t) ggml_vec_dot_turbo3_0_f32,
+        .vec_dot_type = GGML_TYPE_F32,
+        .nrows        = 1,
+    },
+    [GGML_TYPE_TURBO2_0] = {
+        .from_float   = (ggml_from_float_t) quantize_row_turbo2_0_ref,
+        .vec_dot      = (ggml_vec_dot_t) ggml_vec_dot_turbo2_0_f32,
+        .vec_dot_type = GGML_TYPE_F32,
+        .nrows        = 1,
+    },
 };
 
 const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {

@@ -3318,6 +3339,46 @@ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct g
     return ggml_graph_compute(cgraph, &cplan);
 }
 
+// TurboQuant3 vec_dot: dequantize turbo3 block to f32, then dot with f32 operand.
+// Used by CPU flash attention for models with D not supported by CUDA FA (e.g. D=192).
+static void ggml_vec_dot_turbo3_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
+                                      const void * GGML_RESTRICT vx, size_t bx,
+                                      const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    GGML_ASSERT(nrc == 1);
+    GGML_UNUSED(bs); GGML_UNUSED(bx); GGML_UNUSED(by); GGML_UNUSED(nrc);
+
+    // Dequantize turbo3 to f32 temp buffer, then dot
+    float tmp[4096]; // max head_dim
+    GGML_ASSERT(n <= 4096);
+    ggml_get_type_traits(GGML_TYPE_TURBO3_0)->to_float(vx, tmp, n);
+
+    const float * y = (const float *)vy;
+    float sum = 0.0f;
+    for (int i = 0; i < n; i++) {
+        sum += tmp[i] * y[i];
+    }
+    *s = sum;
+}
+
+// TurboQuant2 vec_dot: dequantize turbo2 block to f32, then dot with f32 operand.
+static void ggml_vec_dot_turbo2_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
+                                      const void * GGML_RESTRICT vx, size_t bx,
+                                      const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    GGML_ASSERT(nrc == 1);
+    GGML_UNUSED(bs); GGML_UNUSED(bx); GGML_UNUSED(by); GGML_UNUSED(nrc);
+
+    float tmp[4096];
+    GGML_ASSERT(n <= 4096);
+    ggml_get_type_traits(GGML_TYPE_TURBO2_0)->to_float(vx, tmp, n);
+
+    const float * y = (const float *)vy;
+    float sum = 0.0f;
+    for (int i = 0; i < n; i++) {
+        sum += tmp[i] * y[i];
+    }
+    *s = sum;
+}
+
 void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) {
     memcpy(y, x, n * sizeof(float));
 }
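Note the .vec_dot_type = GGML_TYPE_F32 in the trait entries above: unlike the Q8_0-family types, which first quantize the activation operand to an intermediate format, these vec_dots keep the second operand in f32 and pay for a full dequantize per call. That keeps the CPU fallback simple at the cost of speed, which fits its stated role as the path for head sizes the CUDA flash-attention kernels don't cover.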

ggml/src/ggml-cpu/ops.cpp

Lines changed: 57 additions & 15 deletions
@@ -4926,6 +4926,14 @@ static void ggml_compute_forward_set_rows_f32(
 
     ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
 
+    // For turbo types: communicate WHT group size to the quantize function via global
+    if (dst->type == GGML_TYPE_TURBO3_0 || dst->type == GGML_TYPE_TURBO4_0 || dst->type == GGML_TYPE_TURBO2_0) {
+        extern int turbo3_cpu_wht_group_size;
+        int gs = 0;
+        memcpy(&gs, dst->op_params, sizeof(int));
+        turbo3_cpu_wht_group_size = (gs == 64 || gs == 128) ? gs : 0;
+    }
+
     for (int64_t i03 = 0; i03 < ne03; ++i03) {
         for (int64_t i02 = 0; i02 < ne02; ++i02) {
             for (int64_t i = ir0; i < ir1; ++i) {

@@ -10626,34 +10634,55 @@ static void ggml_compute_forward_turbo_wht_f32(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
     const ggml_tensor * src = dst->src[0];
+    const ggml_tensor * scale_tensor = dst->src[1]; // InnerQ scale_inv (may be NULL)
     const float * src_data = (const float *) src->data;
     float * dst_data = (float *) dst->data;
+    const float * scale_inv = scale_tensor ? (const float *) scale_tensor->data : NULL;
 
     int direction;
-    memcpy(&direction, dst->op_params, sizeof(int));
+    int group_size;
+    memcpy(&direction,  dst->op_params + 0,           sizeof(int));
+    memcpy(&group_size, dst->op_params + sizeof(int), sizeof(int));
 
-    const float * s_first  = (direction == 0) ? turbo_wht_s1 : turbo_wht_s2;
-    const float * s_second = (direction == 0) ? turbo_wht_s2 : turbo_wht_s1;
+    const int64_t head_dim        = src->ne[0];
+    const int64_t n_heads         = ggml_nelements(src) / head_dim;
+    const int64_t groups_per_head = head_dim / group_size;
+    const int     tail_size       = (int)(head_dim % group_size);
+    const int64_t n_groups        = groups_per_head * n_heads;
 
-    const int64_t n_total  = ggml_nelements(src);
-    const int64_t n_groups = n_total / 128;
+    const float inv_sqrt = 1.0f / sqrtf((float)group_size);
 
     // Parallel over groups
     const int64_t ith = params->ith;
     const int64_t nth = params->nth;
     const int64_t grp_start = (n_groups * ith) / nth;
     const int64_t grp_end   = (n_groups * (ith + 1)) / nth;
 
+    // Select sign arrays: for 64-group, use first 64 elements of the 128-element arrays
+    const float * s_first  = (direction == 0) ? turbo_wht_s1 : turbo_wht_s2;
+    const float * s_second = (direction == 0) ? turbo_wht_s2 : turbo_wht_s1;
+
     for (int64_t g = grp_start; g < grp_end; g++) {
-        float x[128];
-        const float * in = src_data + g * 128;
+        const int64_t head_idx    = g / groups_per_head;
+        const int64_t grp_in_head = g % groups_per_head;
+        const int64_t base        = head_idx * head_dim + grp_in_head * group_size;
+
+        float x[128]; // max group_size
+        const float * in = src_data + base;
+
+        // InnerQ forward: apply scale_inv BEFORE signs+WHT (for Q pre-rotation)
+        if (direction == 0 && scale_inv != NULL) {
+            for (int i = 0; i < group_size; i++) x[i] = in[i] * scale_inv[i % group_size];
+        } else {
+            for (int i = 0; i < group_size; i++) x[i] = in[i];
+        }
 
         // Apply first signs
-        for (int i = 0; i < 128; i++) x[i] = in[i] * s_first[i];
+        for (int i = 0; i < group_size; i++) x[i] *= s_first[i];
 
-        // WHT butterfly (7 stages)
-        for (int h = 1; h < 128; h *= 2) {
-            for (int i = 0; i < 128; i += h * 2) {
+        // WHT butterfly (log2(group_size) stages)
+        for (int h = 1; h < group_size; h *= 2) {
+            for (int i = 0; i < group_size; i += h * 2) {
                 for (int j = i; j < i + h; j++) {
                     float a = x[j], b = x[j + h];
                     x[j] = a + b;

@@ -10663,10 +10692,23 @@ static void ggml_compute_forward_turbo_wht_f32(
         }
 
         // Normalize + second signs
-        const float inv_sqrt_128 = 0.08838834764831845f;
-        float * out = dst_data + g * 128;
-        for (int i = 0; i < 128; i++) {
-            out[i] = x[i] * inv_sqrt_128 * s_second[i];
+        float * out = dst_data + base;
+        for (int i = 0; i < group_size; i++) {
+            float val = x[i] * inv_sqrt * s_second[i];
+            // InnerQ inverse: apply scale_inv AFTER WHT+signs (for V un-rotation)
+            if (direction == 1 && scale_inv != NULL) {
+                val *= scale_inv[i % group_size];
+            }
+            out[i] = val;
+        }
+    }
+
+    // Copy tail elements unchanged (identity pass-through)
+    if (tail_size > 0 && ith == 0) {
+        const int64_t tail_offset = groups_per_head * group_size;
+        for (int64_t h = 0; h < n_heads; h++) {
+            const int64_t base = h * head_dim + tail_offset;
+            memcpy(dst_data + base, src_data + base, tail_size * sizeof(float));
         }
     }
 }

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
@@ -120,7 +120,13 @@ if (CUDAToolkit_FOUND)
         template-instances/fattn-vec-instance-f16-f16.cu
         template-instances/fattn-vec-instance-q4_0-q4_0.cu
         template-instances/fattn-vec-instance-q8_0-q8_0.cu
-        template-instances/fattn-vec-instance-bf16-bf16.cu)
+        template-instances/fattn-vec-instance-bf16-bf16.cu
+        template-instances/fattn-vec-instance-turbo3_0-turbo3_0.cu
+        template-instances/fattn-vec-instance-turbo3_0-q8_0.cu
+        template-instances/fattn-vec-instance-q8_0-turbo3_0.cu
+        template-instances/fattn-vec-instance-turbo2_0-turbo2_0.cu
+        template-instances/fattn-vec-instance-turbo2_0-q8_0.cu
+        template-instances/fattn-vec-instance-q8_0-turbo2_0.cu)
 endif()
 
 ggml_add_backend_library(ggml-cuda

ggml/src/ggml-cuda/convert.cu

Lines changed: 17 additions & 0 deletions
@@ -1,5 +1,6 @@
 #include "convert.cuh"
 #include "dequantize.cuh"
+#include "turbo-quant.cuh"
 
 #include <cstdint>
 
@@ -756,6 +757,10 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_mxfp4_cuda;
         case GGML_TYPE_NVFP4:
             return dequantize_row_nvfp4_cuda;
+        case GGML_TYPE_TURBO3_0:
+            return dequantize_block_cont_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
+        case GGML_TYPE_TURBO2_0:
+            return dequantize_block_cont_cuda<QK_TURBO2, QR_TURBO2, dequantize_turbo2_0>;
         case GGML_TYPE_F32:
             return convert_unary_cont_cuda<float>;
         case GGML_TYPE_BF16:

@@ -809,6 +814,10 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_mxfp4_cuda;
         case GGML_TYPE_NVFP4:
             return dequantize_row_nvfp4_cuda;
+        case GGML_TYPE_TURBO3_0:
+            return dequantize_block_cont_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
+        case GGML_TYPE_TURBO2_0:
+            return dequantize_block_cont_cuda<QK_TURBO2, QR_TURBO2, dequantize_turbo2_0>;
         case GGML_TYPE_F16:
            return convert_unary_cont_cuda<half>;
         case GGML_TYPE_BF16:

@@ -832,6 +841,10 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
             return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
         case GGML_TYPE_Q8_0:
             return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+        case GGML_TYPE_TURBO3_0:
+            return dequantize_block_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
+        case GGML_TYPE_TURBO2_0:
+            return dequantize_block_cuda<QK_TURBO2, QR_TURBO2, dequantize_turbo2_0>;
         case GGML_TYPE_BF16:
             return convert_unary_cuda<nv_bfloat16>;
         default:

@@ -874,6 +887,10 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
             return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
         case GGML_TYPE_Q8_0:
             return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+        case GGML_TYPE_TURBO3_0:
+            return dequantize_block_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
+        case GGML_TYPE_TURBO2_0:
+            return dequantize_block_cuda<QK_TURBO2, QR_TURBO2, dequantize_turbo2_0>;
         case GGML_TYPE_BF16:
             return convert_unary_cuda<nv_bfloat16, float>;
         default:

ggml/src/ggml-cuda/dequantize.cuh

Lines changed: 18 additions & 0 deletions
@@ -1,4 +1,5 @@
 #include "common.cuh"
+#include "turbo-quant.cuh"
 
 static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
     const block_q4_0 * x = (const block_q4_0 *) vx;

@@ -75,3 +76,20 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
     v.x *= d;
     v.y *= d;
 }
+
+// Turbo3: 3-bit PolarQuant (2-bit qs + 1-bit sign), block size 32
+// iqs is the element index within the block (even), produces elements iqs and iqs+1
+static __device__ __forceinline__ void dequantize_turbo3_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
+    const block_turbo3_0 * x = (const block_turbo3_0 *) vx;
+    const float norm = __half2float(x[ib].norm);
+    v.x = turbo3_dequant_element(&x[ib], iqs + 0, norm);
+    v.y = turbo3_dequant_element(&x[ib], iqs + 1, norm);
+}
+
+// Turbo2: 2-bit PolarQuant (2-bit qs only, no sign), block size 32
+static __device__ __forceinline__ void dequantize_turbo2_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
+    const block_turbo2_0 * x = (const block_turbo2_0 *) vx;
+    const float norm = __half2float(x[ib].norm);
+    v.x = turbo2_dequant_element(&x[ib], iqs + 0, norm);
+    v.y = turbo2_dequant_element(&x[ib], iqs + 1, norm);
+}
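The turbo3_dequant_element/turbo2_dequant_element helpers live in turbo-quant.cuh, which this excerpt does not show. Assuming the decode is a centroid lookup scaled by the block norm (consistent with the centroid table in ggml-common.h above), the 2-bit variant would look roughly like this (a sketch with an illustrative name and a guessed packing order, not the committed code):

    // Sketch of the 2-bit element decode: extract the 2-bit index for element i,
    // look up its Lloyd-Max centroid, and scale by the block norm passed in by
    // the caller (already converted from half).
    static __device__ __forceinline__ float turbo2_dequant_element_sketch(
            const block_turbo2_0 * b, const int i, const float norm) {
        const uint8_t byte = b->qs[i / 4];            // four 2-bit indices per byte
        const int     idx  = (byte >> (2 * (i % 4))) & 0x3;
        const float centroids[4] = { -0.133462f, -0.039994f, 0.039994f, 0.133462f };
        return centroids[idx] * norm;
    }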
