Skip to content

fix: turbo4 SET_ROWS, tail-block truncation, constant coupling, stack overflow (Issue #29)#4

Merged
TheTom merged 3 commits intoTheTom:feature/turboquant-kv-cachefrom
seanrasch:feature/turboquant-kv-cache
Mar 27, 2026
Merged

fix: turbo4 SET_ROWS, tail-block truncation, constant coupling, stack overflow (Issue #29)#4
TheTom merged 3 commits intoTheTom:feature/turboquant-kv-cachefrom
seanrasch:feature/turboquant-kv-cache

Conversation

@seanrasch
Copy link
Copy Markdown

Summary

Fixes all three bugs from Issue #29, plus a bonus stack overflow fix found during testing.

  • Bug 1: kernel_set_rows_turbo hardcoded turbo3 packing for turbo4 — split into separate kernel_set_rows_turbo3 and kernel_set_rows_turbo4 kernels. turbo4 now correctly does 3-bit PolarQuant + QJL residual correction.
  • Bug 2: Integer division n_groups = nk0 / blocks_per_group silently dropped tail blocks for non-128-aligned head dims. Added ceiling division with tail-group bounds checking, and GGML_ASSERT in WHT dispatch.
  • Bug 3: TURBO_D was semantically coupled to QK_TURBO4 — replaced with TURBO_ROT_DIM (= QK_TURBO3_GROUP) and added static_assert guard.
  • Bonus: turbo_init_rotation() allocated a 64KB float G[128*128] on the stack, causing segfault on llama.cpp worker threads. Eliminated by generating directly into the static turbo_rotation[] array.

Test plan

  • C code compiles clean (gcc -fsyntax-only)
  • Metal shader validated (kernel split, host_names preserved, turbo4 has all 5 quantize steps)
  • turbo4 Python tests: 17/17 passed (turboquant_plus, separate PR)
  • Metal end-to-end (needs Apple Silicon — I don't have it)
  • Note: CPU SET_ROWS has no turbo dispatch — turbo types only work on Metal/CUDA backends

Ampere benchmark data (spiritbuun's CUDA fork, RTX 3080 Ti 12GB)

First published Ampere numbers for TurboQuant. Qwen3 8B Q4_K_M:

Cache pp512 (t/s) tg128 (t/s) vs f16
f16 4970 126.7 1.00x
q8_0 4909 125.4 0.99x
turbo3 4498 101.3 0.90x pp / 0.80x tg

Note: spiritbuun's fork needed PR #1's pool size fix (+2 in llama-kv-cache.cpp) to avoid crash on turbo cache init.

Closes #29

🤖 Generated with Claude Code

seanrasch and others added 2 commits March 26, 2026 22:13
…ling (Issue TheTom#29)

Three bugs from the block-size-32 refactor:

1. kernel_set_rows_turbo hardcoded turbo3 packing for turbo4 — split into
   separate kernel_set_rows_turbo3 and kernel_set_rows_turbo4 kernels.
   turbo4 now correctly does 3-bit PolarQuant + QJL residual correction.

2. Integer division in n_groups = nk0 / blocks_per_group silently dropped
   tail blocks for non-128-aligned head dims (e.g. dk=192). Added ceiling
   division with tail-group bounds checking in turbo3, and GGML_ASSERT in
   WHT dispatch to catch non-128-aligned tensors.

3. TURBO_D constant was semantically coupled to QK_TURBO4 — replaced with
   TURBO_ROT_DIM (= QK_TURBO3_GROUP) and added static_assert that
   QK_TURBO4 == QK_TURBO3_GROUP to guard against future drift.

Closes TheTom#29

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…stack

turbo_init_rotation() allocated a 128x128 float array (64KB) on the stack
to generate the random Gaussian matrix, then memcpy'd it to the static
turbo_rotation[]. llama.cpp worker threads have reduced stack sizes,
causing segfault on first turbo4 quantize call.

Fix: generate directly into the static turbo_rotation[] array, eliminating
the intermediate stack allocation entirely. The Gram-Schmidt QR
decomposition already runs in-place on turbo_rotation[].

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@seanrasch
Copy link
Copy Markdown
Author

RTX 3080 Ti (Ampere, SM 8.6) Benchmark Results

First published Ampere numbers for TurboQuant. Tested on spiritbuun's CUDA fork with the PR #1 pool size fix applied.

Model: Qwen3 8B Q4_K_M | GPU: EVGA RTX 3080 Ti 12GB | Flash attn: on | Full offload: ngl=99

Cache pp512 (t/s) tg128 (t/s) vs f16 pp vs f16 tg Compression
f16 4970 126.7 1.00x 1.00x 1.0x
q8_0 4909 125.4 0.99x 0.99x 2.0x
turbo3 4498 101.3 0.90x 0.80x 4.9x
turbo4 2259 91.3 0.45x 0.72x 3.8x

Notes

  • turbo4 prefill hit (0.45x) matches spiritbuun's known MMA-disabled issue
  • turbo3 decode (0.80x) is notably worse than spiritbuun's 3090 results (~0.97x), suggesting Ampere pays more for the turbo3 dequant path in flash attention
  • turbo4 decode (0.72x) is more competitive relative to turbo3 on this hardware
  • spiritbuun's fork required the +2 pool size fix from PR fix: account for TurboQuant rotation tensors in KV cache context pool #1 to avoid crash on turbo cache init

TheTom added a commit that referenced this pull request Mar 27, 2026
Complete experiment log:
  #1  4-mag LUT:           15.1 at 8K (BEST, +38%)
  #2  Batched extract:     13.7 (+25%)
  #3  Inline FA block:     13.5 (I-cache pressure)
  #4  Deferred norm:       12.9 (loses ILP)
  #5  2-pair half2:        12.0 (ternary overhead)
  #6  Select chain:        11.9 (branches kill)
  #7  Bit-arithmetic:      11.6 (ALU too heavy)
  #8  FMA branchless:      11.4 (ALU still too heavy)
  #9  Named-reg ternary:   10.3 (branches worst)
  #10 Main (8-LUT):        10.95 (baseline)
  #11 Non-vec FA:          10.2 (wrong kernel)
  Ceiling:                 24.5 (no dequant)

Apple8 hardware truth:
  1 divergent constant read < 7 ALU ops (even with fma)
  Branches cost MORE than divergent constant reads
  Array indexing ALWAYS spills on Metal
  4 constant addresses is the sweet spot

The 4-mag LUT is the dequant-level ceiling on Apple Silicon.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: tturney@psyguard.ai
@TheTom
Copy link
Copy Markdown
Owner

TheTom commented Mar 27, 2026

thanks for tackling all three bugs from #29 plus the stack overflow catch. the turbo4 kernel split and tail-block ceiling division look right to me. will pull this down and test locally on M5 Max this afternoon to verify no regressions on turbo3 (PPL + NIAH + speed) before merging. the Ampere numbers from @spiritbuun fork are great to have too.

@TheTom
Copy link
Copy Markdown
Owner

TheTom commented Mar 27, 2026

Test Results — M5 Max 128GB, Qwen3.5-35B-A3B Q8_0

Build: Clean compile, no warnings.

turbo3 (regression check)

  • PPL: 6.1756 ± 0.330 — identical to main branch ✅
  • Decode: no regression ✅
  • PR does not touch the turbo3 code path (separate kernel_set_rows_turbo3 now)

turbo4

  • PPL: 6.3823 ± 0.345 (4.4% above q8_0 baseline of 6.111)
  • Decode-only (-p 0 -n 128): 42.68 tok/s ✅
  • End-to-end generate: prefill FA kernels for turbo4 were missing (pre-existing, not this PR). Added non-vec prefill kernel instantiations locally to test — turbo4 runs but PPL is elevated. Needs further investigation.

Code review

  • SET_ROWS kernel split: correct, turbo4 has proper 3-bit PolarQuant + QJL packing ✅
  • Tail-block ceiling division: correct fix for non-128-aligned head dims ✅
  • TURBO_DTURBO_ROT_DIM + static_assert: clean ✅
  • Stack overflow fix: correct, turbo_init_rotation() no longer allocates 64KB on stack ✅

Recommendation

Good to merge — turbo3 has zero regression risk and the fixes are correct. However, turbo4 needs more testing before recommending it for use. The elevated PPL and missing prefill kernels are pre-existing issues, not introduced by this PR.

Recommendation: stay on turbo3 for now. turbo4 end-to-end validation (prefill FA kernels, quality investigation) tracked separately.

@TheTom TheTom merged commit 065ef53 into TheTom:feature/turboquant-kv-cache Mar 27, 2026
TheTom added a commit that referenced this pull request Mar 27, 2026
…che"

This reverts commit 065ef53, reversing
changes made to 7d1bd95.
@TheTom
Copy link
Copy Markdown
Owner

TheTom commented Mar 27, 2026

⚠️ REVERTED — turbo3 regression found

During extended validation, discovered that PR #4 breaks turbo3 on Metal:

State turbo3 PPL (c=512, 8ch)
Pre-merge (7d1bd95) 6.1756
Post-merge (065ef53) 181.5955
Post-revert (a52586e) 6.1756

The kernel split from the shared kernel_set_rows_turbo template into separate kernel_set_rows_turbo3 / kernel_set_rows_turbo4 functions introduced a regression in the turbo3 quantization path. The q8_0 baseline is unaffected (5.46-6.03 across all context lengths), confirming this is turbo3-specific.

Merge reverted in a52586e. The turbo4 fixes and stack overflow fix are good — the issue is specifically in the turbo3 kernel split. Likely a subtle difference in the template parameters or quantization logic when moving from the shared template to the dedicated function.

Happy to help debug the turbo3 path if you want to resubmit.

Madreag pushed a commit to Madreag/turbo3-cuda that referenced this pull request Mar 28, 2026
* FlashAttention (TheTom#13)

* Add inplace softmax

* Move rms_norm to split row approach

* Update debug for supports_op

* clean up debug statements

* neg f16xf32xip builds and runs, havent actually ran a model that uses neg kernel yet though

* neg passes backend test

* unary operators pass ggml tests

* rms_norm double declaration bug atoned

* abides by editor-config

* removed vestigial files

* fixed autoconfig

* All operators (inlcluding xielu) working

* removed unnecesarry checking if node->src[1] exists for unary operators

* responded and dealt with PR comments

* implemented REPL_Template support and removed bug in unary operators kernel

* formatted embed wgsl and ggml-webgpu.cpp

* Faster tensors (TheTom#8)

Add fast matrix and matrix/vector multiplication.

* Use map for shader replacements instead of pair of strings

* Wasm (TheTom#9)

* webgpu : fix build on emscripten

* more debugging stuff

* test-backend-ops: force single thread on wasm

* fix single-thread case for init_tensor_uniform

* use jspi

* add pthread

* test: remember to set n_thread for cpu backend

* Add buffer label and enable dawn-specific toggles to turn off some checks

* Intermediate state

* Fast working f16/f32 vec4

* Working float fast mul mat

* Clean up naming of mul_mat to match logical model, start work on q mul_mat

* Setup for subgroup matrix mat mul

* Basic working subgroup matrix

* Working subgroup matrix tiling

* Handle weirder sg matrix sizes (but still % sg matrix size)

* Working start to gemv

* working f16 accumulation with shared memory staging

* Print out available subgroup matrix configurations

* Vectorize dst stores for sg matrix shader

* Gemv working scalar

* Minor set_rows optimization (TheTom#4)

* updated optimization, fixed errors

* non vectorized version now dispatches one thread per element

* Simplify

* Change logic for set_rows pipelines

---------

Co-authored-by: Neha Abbas <nehaabbas@macbookpro.lan>
Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>

* Comment on dawn toggles

* Working subgroup matrix code for (semi)generic sizes

* Remove some comments

* Cleanup code

* Update dawn version and move to portable subgroup size

* Try to fix new dawn release

* Update subgroup size comment

* Only check for subgroup matrix configs if they are supported

* Add toggles for subgroup matrix/f16 support on nvidia+vulkan

* Make row/col naming consistent

* Refactor shared memory loading

* Move sg matrix stores to correct file

* Working q4_0

* Formatting

* Work with emscripten builds

* Fix test-backend-ops emscripten for f16/quantized types

* Use emscripten memory64 to support get_memory

* Add build flags and try ci

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>

* Remove extra whitespace

* Move wasm single-thread logic out of test-backend-ops for cpu backend

* Disable multiple threads for emscripten single-thread builds in ggml_graph_plan

* Refactored pipelines and workgroup calculations (TheTom#10)

* refactored pipelines

* refactored workgroup calculation

* removed commented out block of prior maps

* Clean up ceiling division pattern

---------

Co-authored-by: Neha Abbas <nehaabbas@eduroam-169-233-141-223.ucsc.edu>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>

* Start work on flash attention

* Shader structure set up (many bugs still)

* debugging

* Working first test

* Working with head grouping, head sizes to 128, logit softcap, mask/sinks enabled, f32

* Generalize softmax to work with multiple subgroups, f16 accumulation, mask shared memory tiling

* Start work on integrating pre-wgsl

* Separate structs/initial shader compilation library into separate files

* Work on compilation choices for flashattention

* Work on subgroup matrix/tile size portability

* subgroup size agnostic online softmax

* Cleanups, quantization types

* more cleanup

* fix wasm build

* Refactor flashattention to increase parallelism, use direct loads for KV in somce cases

* Checkpoint

* formatting

* Update to account for default kv cache padding

* formatting shader

* Add workflow for ggml-ci webgpu

* Try passing absolute path to dawn in ggml-ci

* Avoid error on device destruction, add todos for proper cleanup

* Fix unused warning

* Forgot one parameter unused

* Move some flashattn computation to f32 for correctness
aminya pushed a commit to aminya/llama-cpp-turboquant that referenced this pull request Mar 29, 2026
Add GGML_API to turbo quantize_row declarations in ggml-cpu/quants.h
to match ggml-quants.h. MSVC requires consistent __declspec linkage
when the same symbol is declared across compilation units.

Closes TheTom#4

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
TheTom added a commit that referenced this pull request Apr 2, 2026
Complete experiment log:
  #1  4-mag LUT:           15.1 at 8K (BEST, +38%)
  #2  Batched extract:     13.7 (+25%)
  #3  Inline FA block:     13.5 (I-cache pressure)
  #4  Deferred norm:       12.9 (loses ILP)
  #5  2-pair half2:        12.0 (ternary overhead)
  #6  Select chain:        11.9 (branches kill)
  #7  Bit-arithmetic:      11.6 (ALU too heavy)
  #8  FMA branchless:      11.4 (ALU still too heavy)
  #9  Named-reg ternary:   10.3 (branches worst)
  #10 Main (8-LUT):        10.95 (baseline)
  #11 Non-vec FA:          10.2 (wrong kernel)
  Ceiling:                 24.5 (no dequant)

Apple8 hardware truth:
  1 divergent constant read < 7 ALU ops (even with fma)
  Branches cost MORE than divergent constant reads
  Array indexing ALWAYS spills on Metal
  4 constant addresses is the sweet spot

The 4-mag LUT is the dequant-level ceiling on Apple Silicon.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: tturney@psyguard.ai
TheTom added a commit that referenced this pull request Apr 2, 2026
…che"

This reverts commit 065ef53, reversing
changes made to 7d1bd95.
TheTom added a commit that referenced this pull request Apr 2, 2026
Complete experiment log:
  #1  4-mag LUT:           15.1 at 8K (BEST, +38%)
  #2  Batched extract:     13.7 (+25%)
  #3  Inline FA block:     13.5 (I-cache pressure)
  #4  Deferred norm:       12.9 (loses ILP)
  #5  2-pair half2:        12.0 (ternary overhead)
  #6  Select chain:        11.9 (branches kill)
  #7  Bit-arithmetic:      11.6 (ALU too heavy)
  #8  FMA branchless:      11.4 (ALU still too heavy)
  #9  Named-reg ternary:   10.3 (branches worst)
  #10 Main (8-LUT):        10.95 (baseline)
  #11 Non-vec FA:          10.2 (wrong kernel)
  Ceiling:                 24.5 (no dequant)

Apple8 hardware truth:
  1 divergent constant read < 7 ALU ops (even with fma)
  Branches cost MORE than divergent constant reads
  Array indexing ALWAYS spills on Metal
  4 constant addresses is the sweet spot

The 4-mag LUT is the dequant-level ceiling on Apple Silicon.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: tturney@psyguard.ai
TheTom added a commit that referenced this pull request Apr 2, 2026
…che"

This reverts commit 065ef53, reversing
changes made to 7d1bd95.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants