@@ -4926,6 +4926,14 @@ static void ggml_compute_forward_set_rows_f32(
49264926
49274927 ggml_from_float_t const from_float = ggml_get_type_traits_cpu (dst->type )->from_float ;
49284928
4929+ // For turbo types: communicate WHT group size to the quantize function via global
4930+ if (dst->type == GGML_TYPE_TURBO3_0 || dst->type == GGML_TYPE_TURBO4_0 || dst->type == GGML_TYPE_TURBO2_0) {
4931+ extern int turbo3_cpu_wht_group_size;
4932+ int gs = 0 ;
4933+ memcpy (&gs, dst->op_params , sizeof (int ));
4934+ turbo3_cpu_wht_group_size = (gs == 64 || gs == 128 ) ? gs : 0 ;
4935+ }
4936+
49294937 for (int64_t i03 = 0 ; i03 < ne03; ++i03) {
49304938 for (int64_t i02 = 0 ; i02 < ne02; ++i02) {
49314939 for (int64_t i = ir0; i < ir1; ++i) {
@@ -10626,34 +10634,55 @@ static void ggml_compute_forward_turbo_wht_f32(
1062610634 const ggml_compute_params * params,
1062710635 ggml_tensor * dst) {
1062810636 const ggml_tensor * src = dst->src [0 ];
10637+ const ggml_tensor * scale_tensor = dst->src [1 ]; // InnerQ scale_inv (may be NULL)
1062910638 const float * src_data = (const float *) src->data ;
1063010639 float * dst_data = (float *) dst->data ;
10640+ const float * scale_inv = scale_tensor ? (const float *) scale_tensor->data : NULL ;
1063110641
1063210642 int direction;
10633- memcpy (&direction, dst->op_params , sizeof (int ));
10643+ int group_size;
10644+ memcpy (&direction, dst->op_params + 0 , sizeof (int ));
10645+ memcpy (&group_size, dst->op_params + sizeof (int ), sizeof (int ));
1063410646
10635- const float * s_first = (direction == 0 ) ? turbo_wht_s1 : turbo_wht_s2;
10636- const float * s_second = (direction == 0 ) ? turbo_wht_s2 : turbo_wht_s1;
10647+ const int64_t head_dim = src->ne [0 ];
10648+ const int64_t n_heads = ggml_nelements (src) / head_dim;
10649+ const int64_t groups_per_head = head_dim / group_size;
10650+ const int tail_size = (int )(head_dim % group_size);
10651+ const int64_t n_groups = groups_per_head * n_heads;
1063710652
10638- const int64_t n_total = ggml_nelements (src);
10639- const int64_t n_groups = n_total / 128 ;
10653+ const float inv_sqrt = 1 .0f / sqrtf ((float )group_size);
1064010654
1064110655 // Parallel over groups
1064210656 const int64_t ith = params->ith ;
1064310657 const int64_t nth = params->nth ;
1064410658 const int64_t grp_start = (n_groups * ith) / nth;
1064510659 const int64_t grp_end = (n_groups * (ith + 1 )) / nth;
1064610660
10661+ // Select sign arrays: for 64-group, use first 64 elements of the 128-element arrays
10662+ const float * s_first = (direction == 0 ) ? turbo_wht_s1 : turbo_wht_s2;
10663+ const float * s_second = (direction == 0 ) ? turbo_wht_s2 : turbo_wht_s1;
10664+
1064710665 for (int64_t g = grp_start; g < grp_end; g++) {
10648- float x[128 ];
10649- const float * in = src_data + g * 128 ;
10666+ const int64_t head_idx = g / groups_per_head;
10667+ const int64_t grp_in_head = g % groups_per_head;
10668+ const int64_t base = head_idx * head_dim + grp_in_head * group_size;
10669+
10670+ float x[128 ]; // max group_size
10671+ const float * in = src_data + base;
10672+
10673+ // InnerQ forward: apply scale_inv BEFORE signs+WHT (for Q pre-rotation)
10674+ if (direction == 0 && scale_inv != NULL ) {
10675+ for (int i = 0 ; i < group_size; i++) x[i] = in[i] * scale_inv[i % group_size];
10676+ } else {
10677+ for (int i = 0 ; i < group_size; i++) x[i] = in[i];
10678+ }
1065010679
1065110680 // Apply first signs
10652- for (int i = 0 ; i < 128 ; i++) x[i] = in[i] * s_first[i];
10681+ for (int i = 0 ; i < group_size ; i++) x[i] *= s_first[i];
1065310682
10654- // WHT butterfly (7 stages)
10655- for (int h = 1 ; h < 128 ; h *= 2 ) {
10656- for (int i = 0 ; i < 128 ; i += h * 2 ) {
10683+ // WHT butterfly (log2(group_size) stages)
10684+ for (int h = 1 ; h < group_size ; h *= 2 ) {
10685+ for (int i = 0 ; i < group_size ; i += h * 2 ) {
1065710686 for (int j = i; j < i + h; j++) {
1065810687 float a = x[j], b = x[j + h];
1065910688 x[j] = a + b;
@@ -10663,10 +10692,23 @@ static void ggml_compute_forward_turbo_wht_f32(
1066310692 }
1066410693
1066510694 // Normalize + second signs
10666- const float inv_sqrt_128 = 0 .08838834764831845f ;
10667- float * out = dst_data + g * 128 ;
10668- for (int i = 0 ; i < 128 ; i++) {
10669- out[i] = x[i] * inv_sqrt_128 * s_second[i];
10695+ float * out = dst_data + base;
10696+ for (int i = 0 ; i < group_size; i++) {
10697+ float val = x[i] * inv_sqrt * s_second[i];
10698+ // InnerQ inverse: apply scale_inv AFTER WHT+signs (for V un-rotation)
10699+ if (direction == 1 && scale_inv != NULL ) {
10700+ val *= scale_inv[i % group_size];
10701+ }
10702+ out[i] = val;
10703+ }
10704+ }
10705+
10706+ // Copy tail elements unchanged (identity pass-through)
10707+ if (tail_size > 0 && ith == 0 ) {
10708+ const int64_t tail_offset = groups_per_head * group_size;
10709+ for (int64_t h = 0 ; h < n_heads; h++) {
10710+ const int64_t base = h * head_dim + tail_offset;
10711+ memcpy (dst_data + base, src_data + base, tail_size * sizeof (float ));
1067010712 }
1067110713 }
1067210714}
0 commit comments