@@ -486,6 +486,50 @@ static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_
486486 }
487487}
488488
// TurboQuant TQ3_0: 2-bit codebook dequantization + inverse WHT.
// Dequantizes one block to the rotated (WHT) space, then applies the
// inverse 32-point Walsh-Hadamard transform cooperatively.
//
// Launch layout: one CUDA block per quantized block, blockDim.x == 32
// (one thread per output element). Uses 32 floats of static shared memory.
// NOTE(review): the tid guard below early-returns before the barriers, so
// this kernel is only safe when launched with exactly 32 threads per block
// (as dequantize_row_tq3_0_cuda does) — a wider launch would leave the
// __syncthreads() calls executed by a subset of the block.
template <typename dst_t>
static __global__ void dequantize_block_tq3_0(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    // 2-bit codebook: the 4 centroid values in the rotated space.
    const float centroids[4] = { -1.510f, -0.4528f, 0.4528f, 1.510f };
    // Fixed per-coordinate sign flips applied at quantization time before the
    // forward WHT; undone after the inverse transform below.
    const int8_t signs[32] = {
        +1, -1, +1, +1, -1, -1, +1, -1, +1, +1, -1, +1, -1, +1, -1, -1,
        +1, -1, -1, +1, +1, -1, +1, -1, -1, +1, +1, +1, -1, -1, +1, -1
    };

    const int64_t i = blockIdx.x;  // one quantized block per CUDA block
    const block_tq3_0 * x = (const block_tq3_0 *) vx;
    const int tid = threadIdx.x;
    if (tid >= 32) return;  // see launch-layout note above

    const float d = __half2float(x[i].gamma);  // per-block scale

    // Step 1: each thread dequantizes its value (in rotated space).
    // qs packs 32 2-bit codebook indices, 4 per byte, LSB-first.
    const int byte_idx  = tid / 4;
    const int bit_shift = 2 * (tid % 4);
    const int idx = (x[i].qs[byte_idx] >> bit_shift) & 3;

    __shared__ float shmem[32];
    shmem[tid] = d * centroids[idx];
    __syncthreads();

    // Step 2: cooperative inverse WHT (5 butterfly stages).
    // Per stage: every thread reads both pair values, a barrier separates all
    // reads from all writes, then the lower-index thread of each pair writes
    // both results — so no read/write overlap within a stage.
    for (int step = 1; step < 32; step <<= 1) {
        const int partner = tid ^ step;  // butterfly partner
        const float a = shmem[tid];
        const float b = shmem[partner];
        __syncthreads();
        if (tid < partner) {
            shmem[tid]     = a + b;
            shmem[partner] = a - b;
        }
        __syncthreads();
    }

    // Step 3: normalize by 1/sqrt(32) and undo the sign flips.
    const float inv_sqrt32 = 0.17677669529663688f;
    yy[i * QK_TQ3_0 + tid] = shmem[tid] * inv_sqrt32 * signs[tid];
}
532+
489533template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t >
490534static void dequantize_block_cuda (const void * vx, dst_t * y,
491535 const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
@@ -617,6 +661,12 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t
617661 dequantize_block_mxfp4<<<nb, 32 , 0 , stream>>> (vx, y);
618662}
619663
// Dequantize k TQ3_0-quantized elements from vx into y on the given stream.
// Assumes k is a multiple of QK_TQ3_0 (ggml rows always are); launches one
// 32-thread block per quantized block, matching dequantize_block_tq3_0's
// required layout.
template <typename dst_t>
static void dequantize_row_tq3_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int num_blocks = k / QK_TQ3_0;
    dequantize_block_tq3_0<<<num_blocks, 32, 0, stream>>>(vx, y);
}
669+
620670template <typename src_t , typename dst_t >
621671static __global__ void convert_unary (
622672 const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
@@ -715,6 +765,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
715765 return dequantize_row_iq3_s_cuda;
716766 case GGML_TYPE_MXFP4:
717767 return dequantize_row_mxfp4_cuda;
768+ case GGML_TYPE_TQ3_0:
769+ return dequantize_row_tq3_0_cuda;
718770 case GGML_TYPE_F32:
719771 return convert_unary_cont_cuda<float >;
720772 case GGML_TYPE_BF16:
@@ -766,6 +818,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
766818 return dequantize_row_iq3_s_cuda;
767819 case GGML_TYPE_MXFP4:
768820 return dequantize_row_mxfp4_cuda;
821+ case GGML_TYPE_TQ3_0:
822+ return dequantize_row_tq3_0_cuda;
769823 case GGML_TYPE_F16:
770824 return convert_unary_cont_cuda<half>;
771825 case GGML_TYPE_BF16:
0 commit comments