
feat: zero-pad non-128 heads + turbo4 CUDA port + cross-type FA (issues #13, #25)#24

Merged
TheTom merged 5 commits into TheTom:feature/turboquant-kv-cache from signalnine:feature/turboquant-kv-cache
Mar 29, 2026

Conversation

@signalnine commented Mar 29, 2026

Summary

Three features in this PR:

  1. Zero-pad non-128-aligned heads for the full 7-stage WHT (issue #13: `GGML_ASSERT(ne00 % QK_TURBO3_GROUP == 0)` fails in CUDA SET_ROWS turbo3 when row width is 576, e.g. GLM-4.7 Flash / deepseek2 K heads)
  2. CUDA port of the turbo4 4-bit KV cache (issue #25, bug 1: turbo4 crashes on CUDA, SET_ROWS unported)
  3. turbo2/turbo3 cross-type VEC FA instances (issue #25, bug 2: mixed K/V types cause ~11x prefill regression)

1. Zero-pad non-128 heads

Enables turbo KV compression on models with non-128-aligned head dims
(DeepSeek2, GLM-4.7 Flash) by zero-padding each head to the next
multiple of 128 before WHT rotation.

  • head_dim=192 → 256 (33% overhead, 3.5× compression)
  • head_dim=576 → 640 (11% overhead, 4.1× compression)

Math: <WHT(Q_padded), WHT(K_padded)> = <Q, K>, since WHT is orthogonal (up to normalization) and so preserves inner products, and the zero-padded elements contribute nothing to the dot product.

Includes D=640 MMA flash attention configs for GLM-4.7 Flash
(37 → 192 t/s, ncols1 capped at 2 due to shared memory limits).

2. turbo4 CUDA port

Full CUDA stack for GGML_TYPE_TURBO4_0 using 4-bit PolarQuant (16
centroids, nibble-packed). Changes TURBO4_USE_4BIT default to all
backends — better quality than legacy 3-bit+QJL.

The port covers the SET_ROWS kernel, VEC FA vec_dot/dequantize_V, the convert path, a CPU FA fallback, and all cross-type instances (turbo4 × q8_0/turbo3/turbo2).

3. Cross-type FA instances

Added VEC FA instances for turbo3/turbo2 and turbo2/turbo3 pairs.
Updated mixed-type guard to allow any combination of turbo2, turbo3,
turbo4, and q8_0. Fixes ~11x prefill regression reported in issue #25.

Benchmarks (Qwen3.5 35B, RTX 5090)

PPL (wikitext-2, ctx=512, 8 chunks):

| Type   | PPL  | vs q8_0  | Compression |
|--------|------|----------|-------------|
| q8_0   | 6.18 | baseline | 2.0x        |
| turbo4 | 6.23 | +0.8%    | 3.8x        |
| turbo3 | 6.31 | +2.1%    | 4.6x        |
| turbo2 | 6.69 | +8.3%    | 6.4x        |

Non-128 head models (zero-pad):

| Model         | head_dim | f16 PPL | turbo3 PPL | Speed   |
|---------------|----------|---------|------------|---------|
| DeepSeek-V2   | 192→256  | 9.90    | 11.61      | 143 t/s |
| GLM-4.7 Flash | 576→640  | 14.97   | 17.05      | 192 t/s |

Cross-type speed: all combos run at full CUDA VEC speed (~220 t/s decode).

Test plan

  • 16/16 coherence (4 models × 4 KV combos)
  • NIAH 9-10/11 at 32K
  • PPL regression check on Qwen3.5
  • GLM-4.7 Flash D=640 MMA FA
  • turbo4 coherence + PPL
  • All cross-type combos at CUDA speed
  • CPU-only path

🤖 Generated with Claude Code

signalnine and others added 2 commits March 29, 2026 07:26
…back)

Instead of falling back to q8_0 for models with non-128-aligned heads,
pad each head to the next multiple of 128 before WHT rotation. The
padded elements are zero and don't affect dot products since WHT
preserves inner products: <WHT(Q_pad), WHT(K_pad)> = <Q, K>.

This keeps turbo compression active on DeepSeek2, GLM-4.7 Flash, and
other non-128 head_dim models. Padding overhead is wasted bits on the
zero-padded quantized elements:
  head_dim=192 → 256 (33% overhead, still 3.5x compression)
  head_dim=576 → 640 (11% overhead, still 4.1x compression)

Changes:
- llama-kv-cache.cpp: allocate K/V cache at padded dimension, pad
  k_cur/v_cur per-head via ggml_pad before set_rows, return padded
  views from get_k/get_v
- llama-graph.cpp: pad Q per-head before turbo_wht, extract original
  V head_dim from attention output after inverse WHT. All three
  build_attn variants (KV, K-only/MLA, ISWA) updated.

PPL (wikitext-2, 512 context, 8 chunks):
  DeepSeek (192→256): 11.61 vs 9.90 baseline (+17%)
    Previously: 344,304 with 64-group WHT (catastrophic)
  GLM-4.7 (576→640):  9.11 vs 14.97 baseline
  Qwen3.5 (128):      6.31 unchanged

Tested: 16/16 coherence, NIAH 9-10/11 at 32K, CPU-only works.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
GLM-4.7 Flash with turbo3 KV cache pads head_dim 576→640 for WHT.
D=640 had no CUDA FA kernel, falling back to CPU FA at 37 t/s.

Added MMA flash attention configs for D=640 (DKQ=640, DV=512),
using identical tile sizes as D=576 (nbatch_K2=288, nbatch_V2=256).

ncols1 capped at 2 for D=640: at ncols1≥4 the Q shared memory
(ncols × (DKQ/2+4) × 4 = 83KB at ncols=64) exceeds the per-block
limit. ncols1≤2 keeps total shared memory under 80KB.

Changes:
- fattn-mma-f16.cuh: D=640 config entries + extern declarations
  (ncols2=16 only, ncols1∈{1,2,4})
- fattn.cu: case 640 in kernel selection + simplified MMA dispatch
  (always ncols2=16, ncols1 capped at 2)
- fattn-tile.cuh/cu: D=640 tile configs + dispatch
- 3 MMA template instances + 1 tile instance

Tested:
  GLM-4.7 Flash: 192 t/s (was 37), PPL 17.05 (vs 14.97 f16)
  Qwen3.5: 222 t/s, PPL 6.31 (unchanged)
  DeepSeek: 143 t/s, PPL 11.61 (unchanged)
  16/16 coherence, NIAH 9-10/11 at 32K

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
signalnine and others added 2 commits March 29, 2026 08:34
…ug 2)

Mixed turbo3-K/turbo2-V and turbo2-K/turbo3-V had no CUDA FA kernel
instances, causing ~11x prefill regression (falling back to CPU FA).

Added VEC template instances for both cross-type pairs at D=64/128/256.
Updated the mixed-type guard in get_best_fattn_kernel to allow any
combination of turbo2, turbo3, and q8_0.

Tested: turbo3/turbo2 and turbo2/turbo3 both run at full CUDA VEC
speed (~170 t/s prefill, ~221 t/s decode on Qwen3.5 35B).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…Tom#25 bug 1

Ports GGML_TYPE_TURBO4_0 to CUDA using the 4-bit PolarQuant format
(16 centroids, nibble-packed, no QJL). Previously turbo4 crashed on
CUDA with "cannot run the operation (SET_ROWS)".

Changes TURBO4_USE_4BIT default from Metal-only to all backends.
The 4-bit format (16 centroids) has better quality than the legacy
3-bit+QJL format and is simpler to implement (no residual projection).

Full CUDA stack:
- turbo-quant.cuh: 4-bit centroids, midpoints, nearest-centroid,
  dequant element, per-block quantize
- set-rows.cu: k_set_rows_turbo4 kernel (128 threads, WHT rotation,
  4-bit quantize, nibble pack via warp shuffle, corrected norm)
- dequantize.cuh + convert.cu: turbo4 to f16/f32
- fattn-common.cuh: vec_dot_KQ_turbo4 + dequantize_V_turbo4
- fattn-vec.cuh + fattn.cu: VEC dispatch + all cross-type instances
  (turbo4×turbo4, turbo4×q8_0, turbo4×turbo3, turbo4×turbo2)
- ggml-cpu.c: CPU FA vec_dot for turbo4

PPL (Qwen3.5, wikitext-2): 6.23 (+0.8% vs q8_0) at 3.8× compression
Speed: 217 t/s decode (comparable to turbo3 222 t/s)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@signalnine signalnine changed the title feat: zero-pad non-128 heads + D=640 MMA FA for turbo KV on MLA models feat: zero-pad non-128 heads + turbo4 CUDA port + cross-type FA (issues #13, #25) Mar 29, 2026
… (issue TheTom#28)

The block-size divisibility check in llama-context.cpp rejected turbo4
on GLM-4.7 Flash (head_dim=576, QK_TURBO4=128, 576%128≠0) before the
KV cache zero-padding code could run.

Fix: for turbo types, compute the padded head_dim (ceil to 128) before
the divisibility check, matching what llama-kv-cache.cpp actually does.

Tested: GLM-4.7 Flash turbo4 loads and runs at 193 t/s.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@TheTom (Owner) commented Mar 29, 2026

Verified no regression on existing Metal paths on M5 Max:

  • phi-4-Q8_0 q8_0/q8_0 PPL: 4.6901 (exact match to baseline)
  • phi-4-Q8_0 turbo3/turbo3 PPL: 4.8855 (exact match)
  • turbo3 decode: 29.45 t/s, prefill: 620 t/s (healthy)

The q8_0/turbo4 failure on this branch is expected — mixed q8_0 × turbo Metal FA kernels are from a separate commit on my side (965a6ca). Next step is to combine this PR with the existing Metal asymmetric support on a temporary integration branch, run a short Metal sanity suite, and then merge the combined result.

@TheTom TheTom merged commit 2dd602a into TheTom:feature/turboquant-kv-cache Mar 29, 2026
1 check passed
@Dubascudes Dubascudes mentioned this pull request Apr 1, 2026