
feat: turbo3/turbo2 mixed KV cache type support (CUDA) #29

Closed

seanrasch wants to merge 1 commit into TheTom:feature/turboquant-kv-cache from seanrasch:feat/mixed-turbo3-turbo2

Conversation

@seanrasch

Summary

Enable asymmetric KV cache precision: turbo3 K + turbo2 V.

  • K cache at 3-bit (turbo3) — more sensitive, affects shared softmax denominator
  • V cache at 2-bit (turbo2) — less sensitive, contributes linearly weighted by attention
  • Compression: 5.33x vs f16 (up from 4.57x with turbo3/turbo3)
  • Usage: --cache-type-k turbo3 --cache-type-v turbo2

K/V asymmetric precision is well-established — KIVI, QAQ, and KVTuner all demonstrate
that V tolerates lower precision than K. Each V vector is consumed in proportion to
its attention weight; at long context most V positions receive near-zero attention
(this is why sparse V dequant works). Compressing them further costs almost nothing.
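The damping effect is easy to see numerically. Below is an illustrative NumPy sketch, not the turbo3/turbo2 codecs: single-query attention with one strongly attended key, where only the low-attention V rows are coarsely quantized. All sizes and the crude rounding grid are hypothetical stand-ins for low-bit quantization.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 128, 1024                     # head dim, context length
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))
q = 2.0 * K[17]                      # query strongly aligned with one key -> peaked attention

logits = K @ q / np.sqrt(d)
w = np.exp(logits - logits.max())
w /= w.sum()                         # attention weights
out = w @ V

# coarsely quantize only the V rows that receive <1% attention
mask = w < 0.01
Vq = V.copy()
Vq[mask] = np.round(Vq[mask] * 2) / 2   # crude 0.5-step grid as a stand-in for low-bit quant
rel_err = np.linalg.norm(w @ Vq - out) / np.linalg.norm(out)

print(f"{mask.mean():.1%} of positions get <1% attention; output error {rel_err:.2%}")
```

Nearly all positions fall under the mask, yet the output error stays far below the quantization step, because each V error is scaled by a near-zero attention weight. The same perturbation applied to K instead would shift the logits feeding the shared softmax.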

Benchmarks (RTX 3080 Ti, SM 86)

Quality (wikitext-2 PPL)

Validated across two architectures (Qwen, Llama) and two weight quants (Q4_K_M, Q8_0):

| Model | f16/f16 | turbo3/turbo3 | turbo3/turbo2 | Δ (t3t3 → t3t2) |
|---|---|---|---|---|
| Qwen3 8B (Q4_K_M) | 10.44 | 11.57 | 11.88 | +0.31 |
| Qwen3.5 9B (Q4_K_M) | 8.30 | 8.42 | 8.69 | +0.27 |
| NeuralDaredevil 8B (Q8_0) | 8.08 | 8.32 | 8.58 | +0.26 |

turbo3/turbo2 adds +0.26 to +0.31 PPL over turbo3/turbo3 — consistent across
models and architectures. On Qwen3 8B this is 22-27% of the cost already
accepted for turbo3 compression.
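The 22-27% range can be reproduced from the Qwen3 8B column; the two endpoints come from dividing the extra +0.31 by the total cost vs f16 and by the turbo3/turbo3 cost respectively:

```python
ppl_f16, ppl_t3t3, ppl_t3t2 = 10.44, 11.57, 11.88   # Qwen3 8B, wikitext-2

extra = ppl_t3t2 - ppl_t3t3   # added by dropping V to turbo2: +0.31
base  = ppl_t3t3 - ppl_f16    # already accepted for turbo3/turbo3: +1.13
total = ppl_t3t2 - ppl_f16    # total cost vs f16: +1.44

print(f"{extra / total:.0%} - {extra / base:.0%}")  # 22% - 27%
```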

Throughput (3 runs each)

Qwen3 8B (Q4_K_M):

| Test | turbo3/turbo3 | turbo3/turbo2 | Delta |
|---|---|---|---|
| pp512 | 4569 t/s | 4477 t/s | -2.0% |
| pp8192 | 3990 t/s | 3888 t/s | -2.6% |
| pp32768 | 2569 t/s | 2583 t/s | +0.5% |
| tg128 | 113.9 t/s | 112.8 t/s | -1.0% |

Qwen3.5 9B (Q4_K_M):

| Test | turbo3/turbo3 | turbo3/turbo2 | Delta |
|---|---|---|---|
| pp512 | 3889 t/s | 3876 t/s | -0.3% |
| pp8192 | 3871 t/s | 3857 t/s | -0.4% |
| pp32768 | 3445 t/s | 3444 t/s | 0.0% |
| tg128 | 102.0 t/s | 102.2 t/s | +0.2% |

NeuralDaredevil 8B (Q8_0):

| Test | turbo3/turbo3 | turbo3/turbo2 | Delta |
|---|---|---|---|
| pp512 | 4768 t/s | 4786 t/s | +0.4% |
| pp8192 | 4174 t/s | 4168 t/s | -0.2% |
| pp32768 | 2813 t/s | 2827 t/s | +0.5% |
| tg128 | 84.0 t/s | 84.3 t/s | +0.3% |

Throughput impact ranges from -2.6% to +0.5% on prefill, -1% to +0.3% on decode.
At 32K context (where the memory savings matter most) all three models show neutral
or positive throughput.

Memory

| Cache | Bytes per KV position (d=128, 32 heads) | Compression |
|---|---|---|
| f16/f16 | 16,384 | 1.0x |
| turbo3/turbo3 | 3,584 | 4.57x |
| turbo3/turbo2 | 3,072 | 5.33x |

14.3% less KV memory than turbo3/turbo3.
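The table arithmetic checks out if turbo3 stores an effective ~3.5 bits per element and turbo2 ~2.5 (payload plus per-group scale overhead; these effective rates are inferred from the byte counts above, not read from the kernels):

```python
d, heads = 128, 32
elems = d * heads                  # elements per cache per KV position

f16_total = 2 * (elems * 2)        # K + V at 2 bytes/elem -> 16,384 bytes
t3 = elems * 3.5 / 8               # one turbo3 cache -> 1,792 bytes (assumed 3.5 eff. bits)
t2 = elems * 2.5 / 8               # one turbo2 cache -> 1,280 bytes (assumed 2.5 eff. bits)

print(f16_total / (2 * t3),        # 4.57x  (turbo3/turbo3)
      f16_total / (t3 + t2),       # 5.33x  (turbo3/turbo2)
      1 - (t3 + t2) / (2 * t3))    # 14.3% less than turbo3/turbo3
```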

Changes

  • Add fattn-vec-instance-turbo3_0-turbo2_0.cu and turbo2_0-turbo3_0.cu template instances
  • Add extern declarations in fattn-vec.cuh
  • Add dispatch cases and mixed-type validation in fattn.cu
  • Add source files to CMakeLists.txt

5 files changed, +38 lines. No algorithmic changes. All existing code paths unchanged.

Test plan

  • Builds clean (cmake CUDA SM 86)
  • --cache-type-k turbo3 --cache-type-v turbo2 runs without errors
  • PPL validated on 3 models, 2 architectures (Qwen + Llama)
  • Throughput benchmarked at 512, 8192, 32768 context (3 models)
  • No regression on existing turbo3/turbo3 or f16/f16 paths
  • NIAH validation (requesting TheTom verify)

🤖 Generated with Claude Code

Enable asymmetric KV cache precision with --cache-type-k turbo3
--cache-type-v turbo2. K cache uses 3-bit turbo3 (more sensitive —
affects shared softmax denominator), V cache uses 2-bit turbo2 (less
sensitive — contributes linearly, weighted by attention).

Compression: 5.33x vs f16 (up from 4.57x with turbo3/turbo3).

Benchmarked on Qwen3 8B Q4_K_M, RTX 3080 Ti (SM 86):

  Quality (wikitext-2 PPL):
    f16/f16:       10.44
    turbo3/turbo3: 11.57 (+1.13)
    turbo3/turbo2: 11.88 (+1.44, only +0.31 over turbo3/turbo3)

  Throughput:
    turbo3/turbo3 pp32768: 2569 t/s, tg128: 113.9 t/s
    turbo3/turbo2 pp32768: 2583 t/s, tg128: 112.8 t/s

K/V asymmetric precision is well-established in the literature (KIVI,
QAQ, KVTuner) — V is structurally less sensitive because each V vector
is consumed linearly in proportion to its attention weight, while K
vectors affect the global softmax distribution.

Changes:
- Add fattn-vec template instances for turbo3_0/turbo2_0 combinations
- Add extern declarations in fattn-vec.cuh
- Add dispatch cases in fattn.cu (vec kernel + mixed-type validation)
- Add source files to CMakeLists.txt

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
vonempalmeolmos pushed a commit to vonempalmeolmos/llama-cpp-turboquant that referenced this pull request Mar 29, 2026
Codex post-commit review found:
1. TURBO_D was QK_TURBO3 (now 32) — broke turbo4 C array sizes
2. SET_ROWS kernel turbo3-specific but instantiated for turbo4
3. Tail block drop for non-128 head dims

Fixed #3 (TURBO_D). #1 and #2 don't affect turbo3+dk128 path.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
vonempalmeolmos pushed a commit to vonempalmeolmos/llama-cpp-turboquant that referenced this pull request Mar 29, 2026
…ling (Issue TheTom#29)

Three bugs from the block-size-32 refactor:

1. kernel_set_rows_turbo hardcoded turbo3 packing for turbo4 — split into
   separate kernel_set_rows_turbo3 and kernel_set_rows_turbo4 kernels.
   turbo4 now correctly does 3-bit PolarQuant + QJL residual correction.

2. Integer division in n_groups = nk0 / blocks_per_group silently dropped
   tail blocks for non-128-aligned head dims (e.g. dk=192). Added ceiling
   division with tail-group bounds checking in turbo3, and GGML_ASSERT in
   WHT dispatch to catch non-128-aligned tensors.

3. TURBO_D constant was semantically coupled to QK_TURBO4 — replaced with
   TURBO_ROT_DIM (= QK_TURBO3_GROUP) and added static_assert that
   QK_TURBO4 == QK_TURBO3_GROUP to guard against future drift.

Closes TheTom#29

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
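Bug 2 above is the standard floor-division truncation; a minimal sketch of the failure and the ceiling-division fix (block and group sizes are illustrative, assuming 32-element blocks grouped four per 128-element rotation group as the commit message implies):

```python
def n_groups_floor(n_blocks, blocks_per_group):
    return n_blocks // blocks_per_group              # silently drops tail blocks

def n_groups_ceil(n_blocks, blocks_per_group):
    # round up so a partial tail group is still processed;
    # the kernel then bounds-checks inside that tail group
    return (n_blocks + blocks_per_group - 1) // blocks_per_group

# dk = 192 with 32-element blocks -> 6 blocks; 4 blocks (128 elems) per group
print(n_groups_floor(6, 4), n_groups_ceil(6, 4))     # 1 vs 2: floor loses 2 tail blocks
```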
vonempalmeolmos pushed a commit to vonempalmeolmos/llama-cpp-turboquant that referenced this pull request Mar 29, 2026
fix: turbo4 SET_ROWS, tail-block truncation, constant coupling, stack overflow (Issue TheTom#29)
seanrasch added a commit to seanrasch/llama-cpp-turboquant that referenced this pull request Mar 31, 2026
1. turbo_init_rotation() allocated float G[128*128] (64KB) on the stack
   then memcpy'd into the static turbo_rotation array. This segfaults on
   llama.cpp worker threads with reduced stack sizes (512KB macOS, 64KB
   some Linux). Fix: generate the Gaussian matrix directly into
   turbo_rotation, eliminating both the stack allocation and the memcpy.

2. TURBO_D and QK_TURBO3_GROUP are defined separately but must always
   match (both represent the rotation group size). Add static_assert to
   catch silent divergence between CPU reference and GPU kernels.

Fixes: TheTom#29 (remaining items from PR TheTom#18 review)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@seanrasch
Author

Closing — this is now fully covered by the merged work on feature/turboquant-kv-cache.

Thanks for merging these upstream — no action needed here.

@seanrasch seanrasch closed this Mar 31, 2026
mihai-chiorean pushed a commit to mihai-chiorean/turbo3-cuda that referenced this pull request Mar 31, 2026
TheTom added a commit that referenced this pull request Apr 2, 2026
TheTom pushed a commit that referenced this pull request Apr 2, 2026