
fix: CUDA warp-to-block mapping for block_size=128#32

Merged
TheTom merged 1 commit into TheTom:feature/turboquant-kv-cache from HyperionMS2040:fix/cuda-block-size-128
Mar 30, 2026
Conversation

@HyperionMS2040

Summary

The block_size=128 change (adac2c6) broke CUDA quantization in set-rows.cu. With QK_TURBO3=128, blocks_per_group = 1, but the warp-cooperative packing still computed blk = blk_base + warp_id — warps 1-3 wrote qs, signs, and norm out of bounds, corrupting adjacent KV cache memory.

Short llama-bench runs (pp512, tg128) could appear to pass because the OOB writes don't immediately affect the active attention window, but llama-perplexity over the full WikiText-2 set produces all-NaN perplexity or segfaults.

Fix

Compute element position within the block generically:

const int elem_in_block = j % QK_TURBO3;
block_turbo3_0 * blk = blk_base + (j / QK_TURBO3);
  • elem_in_block / 4 for qs byte offset (range 0..31 for QK=128, 0..7 for QK=32)
  • elem_in_block / 8 for signs byte offset (range 0..15 for QK=128, 0..3 for QK=32)
  • elem_in_block == 0 for norm write gate (one per block)

Backward compatible: produces identical results with QK=32. The same fix is applied to k_set_rows_turbo2.

Validation

RTX 3090 (sm_86), llama3.1:8b Q4_K_M, q8_0/turbo3, WikiText-2, 512 context:

| Condition | Block size | PPL |
| --- | --- | --- |
| spiritbuun build (pre-change baseline) | 32 | 7.587 |
| This branch, block_size reverted to 32 | 32 | 7.587 |
| This branch, block_size=128 with fix | 128 | 7.587 |

Zero deviation across all three conditions. The 5.12x compression ratio is now validated on CUDA Ampere.

Context

@seanrasch independently reported the same crash on SM 86 in Discussion ggml-org#20969.

The block_size=128 change (adac2c6) broke CUDA quantization:
with QK=128, blocks_per_group=1, but the warp-cooperative packing
still used blk_base+warp_id, causing warps 1-3 to write OOB.

Fix: compute elem_in_block = j % QK_TURBO_N and use it for block
pointer (j / QK_TURBO_N) and byte offsets (elem_in_block / 4 for qs,
elem_in_block / 8 for signs). Works for both QK=32 and QK=128.

Validated on RTX 3090 (sm_86), llama3.1:8b Q4_K_M, q8_0/turbo3:
PPL = 7.587 (matches QK=32 baseline exactly).
@TheTom
Owner

TheTom commented Mar 30, 2026

Looking now

@TheTom TheTom merged commit 7b75078 into TheTom:feature/turboquant-kv-cache Mar 30, 2026
1 check passed
@TheTom
Owner

TheTom commented Mar 30, 2026

Thank you for the contribution! Apologies for the regression.

mihai-chiorean pushed a commit to mihai-chiorean/turbo3-cuda that referenced this pull request Mar 31, 2026
…n data

Part of TheTom#32: turbo3 prefill degrades relative to q8_0 with context length.

Changes so far:
- Skip ggml_cont when tensors already contiguous (+1%, minimal)
- Generated 32x32 rotation matrices (turbo-rotation-data-32.h) for
  reduced group size approach (16x less matmul compute)
- Fixed V un-rotation to check v->type not k->type

Next: update QK_TURBO3_GROUP, Metal WHT kernel, and KV cache for d=32.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: tturney@psyguard.ai
terrysimons pushed a commit to terrysimons/llama-cpp-turboquant that referenced this pull request Mar 31, 2026
TheTom added a commit that referenced this pull request Apr 2, 2026