
fix: turbo4 on non-128 heads + KV state serialization with padded tensors (issue #28)#30

Merged
TheTom merged 1 commit into TheTom:feature/turboquant-kv-cache from
signalnine:feature/turboquant-kv-cache
Mar 29, 2026

Conversation

@signalnine

Summary

Two fixes for the turbo KV cache on models whose head dimension is not a
multiple of 128 (e.g. GLM-4.7 Flash, head_dim=576):

  1. Context init check: llama-context.cpp rejected turbo4
    (QK=128) at head_dim=576 before the KV cache zero-padding code
    ran. Now computes padded head_dim first (576→640).

  2. KV state serialization: state_write_data / state_read_data
    used hparams.n_embd_k_gqa (576) for ggml_row_size, but turbo
    types pad to 640. Assertion failure on prompt cache save during
    llama-server slot reuse. Now uses k->ne[0] / v->ne[0]
    (actual padded tensor width).

Bug 2 only triggers in llama-server with the prompt cache enabled across
multi-request slot reuse, which is why llama-cli --single-turn tests missed it.

Test plan

  • GLM-4.7 Flash turbo4 loads (was crashing)
  • Server smoke test: 3 request types × 3 turbo types (new test script)
  • Build clean

🤖 Generated with Claude Code

…follow-up)

state_write_data and state_read_data used hparams.n_embd_k_gqa (576)
for ggml_row_size, but turbo types zero-pad to 640. For turbo4
(QK=128), 576 % 128 != 0 → ggml_row_size assertion failure during
prompt cache save on llama-server slot reuse.

Fix: use k->ne[0] / v->ne[0] (actual padded tensor width) instead of
hparams values in all four serialization paths (K write, K read,
V write, V read).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@TheTom
Owner

TheTom commented Mar 29, 2026

Tested on both M5 Max and M2 Pro.

M5 Max (128GB):

  • PPL: 6.1756 (matches known turbo3 baseline exactly, no regression)
  • Decode: pp512 2705.5 t/s, tg128 76.18 t/s (normal)
  • Build clean, no warnings

M2 Pro (32GB):

  • Session save/restore: PASS on turbo2, turbo3, turbo4 (bit-identical output across save/restore/single-seq)
  • Decode: pp512 327-329 t/s, tg128 24-26 t/s (normal, no regression)
  • Build clean

The serialization fix is confirmed correct. Merging.

@TheTom TheTom merged commit 1b7165f into TheTom:feature/turboquant-kv-cache Mar 29, 2026
8 of 44 checks passed
@TheTom
Owner

TheTom commented Mar 29, 2026

Full regression suite on M5 Max post-merge. All configs match known baselines exactly:

| Config | PPL | vs q8_0 | pp512 t/s | tg128 t/s | pp32768 t/s |
|---|---|---|---|---|---|
| q8_0 | 6.1109 | baseline | 2792 | 83.60 | 1182 |
| turbo4 | 6.1250 | +0.23% | 2738 | 77.64 | 1109 |
| turbo3 | 6.1756 | +1.06% | 2747 | 77.26 | |
| turbo2 | 6.5066 | +6.48% | 2735 | 77.05 | |
| q8_0/turbo4 | 6.1088 | -0.03% | 2777 | 78.12 | |

Zero regressions across PPL, prefill, decode, and long-context (32K). CL is clean.

shtaylor pushed a commit to shtaylor/llama-cpp-turboquant that referenced this pull request Mar 30, 2026
…Tom#30

Perplexity benchmarking reveals catastrophic quality failure:
- f16: 6.121, q8_0: 6.111, q4_0: 6.142
- turbo3: 165.6 (27× worse)

Speed benchmarks were meaningless — fast garbage.
Root cause investigation needed before any quality claims.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
shtaylor pushed a commit to shtaylor/llama-cpp-turboquant that referenced this pull request Mar 30, 2026
1. V cache returns rotated-space values (cosine=0.02 vs correct 0.987)
2. dynamic_cast to llama_kv_cache_context fails for MoE models
   (uses llama_memory_hybrid_context, not kv_cache_context)
   → Q rotation and V inverse rotation NEVER executed

Fix: store rotation tensors in llm_graph_context, not KV cache.
Or access through hybrid memory interface.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
shtaylor pushed a commit to shtaylor/llama-cpp-turboquant that referenced this pull request Mar 30, 2026
…eTom#31 TheTom#30

ROOT CAUSE: pre-rotate-queries never executed because:
1. Q ne[0]=256 (GQA concatenated heads), rotation matrix ne[0]=128
2. mctx dynamic_cast failed for MoE hybrid memory

FIX: put inverse WHT rotation back in dequantize_full_block.
This is slower (10.7 tok/s vs 77.7) but produces CORRECT results.

PERPLEXITY RESULTS:
- f16:     6.121
- q8_0:    6.111
- q4_0:    6.142
- turbo3:  6.194 (+1.2% vs q8_0) ✅

The speed optimization (pre-rotate-queries) needs to be reimplemented
to work with GQA head layout and hybrid memory types.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
shtaylor pushed a commit to shtaylor/llama-cpp-turboquant that referenced this pull request Mar 30, 2026
Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
shtaylor pushed a commit to shtaylor/llama-cpp-turboquant that referenced this pull request Mar 30, 2026
Quality confirmed: PPL 6.194 (+1.4% of q8_0)
Speed: 10.7 tok/s (inverse rotation in dequant, no pre-rotate-queries)
Previous speed claims (51-77 tok/s) were invalid — measured garbage output speed.

Key lessons documented for future reference.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mihai-chiorean pushed a commit to mihai-chiorean/turbo3-cuda that referenced this pull request Mar 31, 2026
…Tom#30

Perplexity benchmarking reveals catastrophic quality failure:
- f16: 6.121, q8_0: 6.111, q4_0: 6.142
- turbo3: 165.6 (27× worse)

Speed benchmarks were meaningless — fast garbage.
Root cause investigation needed before any quality claims.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mihai-chiorean pushed a commit to mihai-chiorean/turbo3-cuda that referenced this pull request Mar 31, 2026
1. V cache returns rotated-space values (cosine=0.02 vs correct 0.987)
2. dynamic_cast to llama_kv_cache_context fails for MoE models
   (uses llama_memory_hybrid_context, not kv_cache_context)
   → Q rotation and V inverse rotation NEVER executed

Fix: store rotation tensors in llm_graph_context, not KV cache.
Or access through hybrid memory interface.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mihai-chiorean pushed a commit to mihai-chiorean/turbo3-cuda that referenced this pull request Mar 31, 2026
…eTom#31 TheTom#30

ROOT CAUSE: pre-rotate-queries never executed because:
1. Q ne[0]=256 (GQA concatenated heads), rotation matrix ne[0]=128
2. mctx dynamic_cast failed for MoE hybrid memory

FIX: put inverse WHT rotation back in dequantize_full_block.
This is slower (10.7 tok/s vs 77.7) but produces CORRECT results.

PERPLEXITY RESULTS:
- f16:     6.121
- q8_0:    6.111
- q4_0:    6.142
- turbo3:  6.194 (+1.2% vs q8_0) ✅

The speed optimization (pre-rotate-queries) needs to be reimplemented
to work with GQA head layout and hybrid memory types.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mihai-chiorean pushed a commit to mihai-chiorean/turbo3-cuda that referenced this pull request Mar 31, 2026
Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mihai-chiorean pushed a commit to mihai-chiorean/turbo3-cuda that referenced this pull request Mar 31, 2026
Quality confirmed: PPL 6.194 (+1.4% of q8_0)
Speed: 10.7 tok/s (inverse rotation in dequant, no pre-rotate-queries)
Previous speed claims (51-77 tok/s) were invalid — measured garbage output speed.

Key lessons documented for future reference.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
TheTom added a commit that referenced this pull request Apr 2, 2026
Perplexity benchmarking reveals catastrophic quality failure:
- f16: 6.121, q8_0: 6.111, q4_0: 6.142
- turbo3: 165.6 (27× worse)

Speed benchmarks were meaningless — fast garbage.
Root cause investigation needed before any quality claims.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
TheTom added a commit that referenced this pull request Apr 2, 2026
1. V cache returns rotated-space values (cosine=0.02 vs correct 0.987)
2. dynamic_cast to llama_kv_cache_context fails for MoE models
   (uses llama_memory_hybrid_context, not kv_cache_context)
   → Q rotation and V inverse rotation NEVER executed

Fix: store rotation tensors in llm_graph_context, not KV cache.
Or access through hybrid memory interface.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
TheTom added a commit that referenced this pull request Apr 2, 2026
…#30

ROOT CAUSE: pre-rotate-queries never executed because:
1. Q ne[0]=256 (GQA concatenated heads), rotation matrix ne[0]=128
2. mctx dynamic_cast failed for MoE hybrid memory

FIX: put inverse WHT rotation back in dequantize_full_block.
This is slower (10.7 tok/s vs 77.7) but produces CORRECT results.

PERPLEXITY RESULTS:
- f16:     6.121
- q8_0:    6.111
- q4_0:    6.142
- turbo3:  6.194 (+1.2% vs q8_0) ✅

The speed optimization (pre-rotate-queries) needs to be reimplemented
to work with GQA head layout and hybrid memory types.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
TheTom added a commit that referenced this pull request Apr 2, 2026
Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
TheTom added a commit that referenced this pull request Apr 2, 2026
Quality confirmed: PPL 6.194 (+1.4% of q8_0)
Speed: 10.7 tok/s (inverse rotation in dequant, no pre-rotate-queries)
Previous speed claims (51-77 tok/s) were invalid — measured garbage output speed.

Key lessons documented for future reference.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>