feat: zero-pad non-128 heads + turbo4 CUDA port + cross-type FA (issues #13, #25) #24
Merged
TheTom merged 5 commits into TheTom:feature/turboquant-kv-cache on Mar 29, 2026
Conversation
…back)
Instead of falling back to q8_0 for models with non-128-aligned heads,
pad each head to the next multiple of 128 before WHT rotation. The
padded elements are zero and don't affect dot products since WHT
preserves inner products: <WHT(Q_pad), WHT(K_pad)> = <Q, K>.
This keeps turbo compression active on DeepSeek2, GLM-4.7 Flash, and
other non-128 head_dim models. Padding overhead is wasted bits on the
zero-padded quantized elements:
head_dim=192 → 256 (33% overhead, still 3.5x compression)
head_dim=576 → 640 (11% overhead, still 4.1x compression)
Changes:
- llama-kv-cache.cpp: allocate K/V cache at padded dimension, pad
k_cur/v_cur per-head via ggml_pad before set_rows, return padded
views from get_k/get_v
- llama-graph.cpp: pad Q per-head before turbo_wht, extract original
V head_dim from attention output after inverse WHT. All three
build_attn variants (KV, K-only/MLA, ISWA) updated.
PPL (wikitext-2, 512 context, 8 chunks):
DeepSeek (192→256): 11.61 vs 9.90 baseline (+17%)
Previously: 344,304 with 64-group WHT (catastrophic)
GLM-4.7 (576→640): 9.11 vs 14.97 baseline
Qwen3.5 (128): 6.31 unchanged
Tested: 16/16 coherence, NIAH 9-10/11 at 32K, CPU-only works.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
GLM-4.7 Flash with turbo3 KV cache pads head_dim 576→640 for WHT.
D=640 had no CUDA FA kernel, so attention fell back to CPU FA at 37 t/s.
Added MMA flash attention configs for D=640 (DKQ=640, DV=512), using the
same tile sizes as D=576 (nbatch_K2=288, nbatch_V2=256).
ncols1 capped at 2 for D=640: at ncols1≥4 the Q shared memory
(ncols × (DKQ/2+4) × 4 = 83KB at ncols=64) exceeds the per-block
limit. ncols1≤2 keeps total shared memory under 80KB.
Changes:
- fattn-mma-f16.cuh: D=640 config entries + extern declarations
(ncols2=16 only, ncols1∈{1,2,4})
- fattn.cu: case 640 in kernel selection + simplified MMA dispatch
(always ncols2=16, ncols1 capped at 2)
- fattn-tile.cuh/cu: D=640 tile configs + dispatch
- 3 MMA template instances + 1 tile instance
Tested:
GLM-4.7 Flash: 192 t/s (was 37), PPL 17.05 (vs 14.97 f16)
Qwen3.5: 222 t/s, PPL 6.31 (unchanged)
DeepSeek: 143 t/s, PPL 11.61 (unchanged)
16/16 coherence, NIAH 9-10/11 at 32K
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ug 2)
Mixed turbo3-K/turbo2-V and turbo2-K/turbo3-V had no CUDA FA kernel
instances, causing an ~11x prefill regression (falling back to CPU FA).
Added VEC template instances for both cross-type pairs at D=64/128/256.
Updated the mixed-type guard in get_best_fattn_kernel to allow any
combination of turbo2, turbo3, and q8_0.
Tested: turbo3/turbo2 and turbo2/turbo3 both run at full CUDA VEC speed
(~170 t/s prefill, ~221 t/s decode on Qwen3.5 35B).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…Tom#25 bug 1
Ports GGML_TYPE_TURBO4_0 to CUDA using the 4-bit PolarQuant format
(16 centroids, nibble-packed, no QJL). Previously turbo4 crashed on CUDA
with "cannot run the operation (SET_ROWS)". Changes the TURBO4_USE_4BIT
default from Metal-only to all backends. The 4-bit format (16 centroids)
has better quality than the legacy 3-bit+QJL format and is simpler to
implement (no residual projection).
Full CUDA stack:
- turbo-quant.cuh: 4-bit centroids, midpoints, nearest-centroid search,
  dequant element, per-block quantize
- set-rows.cu: k_set_rows_turbo4 kernel (128 threads, WHT rotation,
  4-bit quantize, nibble pack via warp shuffle, corrected norm)
- dequantize.cuh + convert.cu: turbo4 to f16/f32
- fattn-common.cuh: vec_dot_KQ_turbo4 + dequantize_V_turbo4
- fattn-vec.cuh + fattn.cu: VEC dispatch + all cross-type instances
  (turbo4×turbo4, turbo4×q8_0, turbo4×turbo3, turbo4×turbo2)
- ggml-cpu.c: CPU FA vec_dot for turbo4
PPL (Qwen3.5, wikitext-2): 6.23 (+0.8% vs q8_0) at 3.8× compression
Speed: 217 t/s decode (comparable to turbo3's 222 t/s)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… (issue TheTom#28)
The block-size divisibility check in llama-context.cpp rejected turbo4
on GLM-4.7 Flash (head_dim=576, QK_TURBO4=128, 576 % 128 ≠ 0) before the
KV cache zero-padding code could run.
Fix: for turbo types, compute the padded head_dim (ceil to a multiple of
128) before the divisibility check, matching what llama-kv-cache.cpp
actually does.
Tested: GLM-4.7 Flash turbo4 loads and runs at 193 t/s.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Owner: Verified no regression on existing Metal paths on M5 Max.
Summary
Three features in this PR:
Fixes: SET_ROWS turbo3: GGML_ASSERT(ne00 % QK_TURBO3_GROUP == 0) fails when row width is 576 (e.g. GLM-4.7 Flash / deepseek2 K heads) #13
1. Zero-pad non-128 heads
Enables turbo KV compression on models with non-128-aligned head dims
(DeepSeek2, GLM-4.7 Flash) by zero-padding each head to the next
multiple of 128 before WHT rotation.
Math: <WHT(Q_padded), WHT(K_padded)> = <Q, K>, since padded elements are zero.
Includes D=640 MMA flash attention configs for GLM-4.7 Flash
(37 → 192 t/s, ncols1 capped at 2 due to shared memory limits).
2. turbo4 CUDA port
Full CUDA stack for GGML_TYPE_TURBO4_0 using 4-bit PolarQuant (16
centroids, nibble-packed). Changes the TURBO4_USE_4BIT default to all
backends; better quality than the legacy 3-bit+QJL format.
SET_ROWS kernel, VEC FA vec_dot/dequantize_V, convert path, CPU FA
fallback, all cross-type instances (turbo4 × q8_0/turbo3/turbo2).
3. Cross-type FA instances
Added VEC FA instances for turbo3/turbo2 and turbo2/turbo3 pairs.
Updated mixed-type guard to allow any combination of turbo2, turbo3,
turbo4, and q8_0. Fixes ~11x prefill regression reported in issue #25.
Benchmarks (Qwen3.5 35B, RTX 5090)
PPL (wikitext-2, ctx=512, 8 chunks):
Non-128 head models (zero-pad):
Cross-type speed (all combos at full CUDA VEC, ~220 t/s decode)
Test plan
🤖 Generated with Claude Code