perf: FA kernel optimizations + auto-asymmetric KV + warp shuffle WHT#36
Open
signalnine wants to merge 7 commits intoTheTom:feature/turboquant-kv-cachefrom
Open
perf: FA kernel optimizations + auto-asymmetric KV + warp shuffle WHT#36signalnine wants to merge 7 commits intoTheTom:feature/turboquant-kv-cachefrom
signalnine wants to merge 7 commits intoTheTom:feature/turboquant-kv-cachefrom
Conversation
When user requests symmetric turbo K+V on a quantized model, auto- downgrade K to q8_0 while keeping V as turbo. This prevents catastrophic PPL on outlier-sensitive models (Qwen 2.5: 4015→8.85) and actually improves quality on all models tested (Qwen 3.5: 6.31→6.24). V compression is virtually lossless across all architectures (+0.3%). K compression is model-sensitive and compounds with weight quantization error. Asymmetric q8_0-K + turbo-V is the safe default. Detection: checks tok_embd tensor type. If quantized (Q2-Q6), auto- switches. F16/F32/BF16 models keep symmetric turbo (no stacking risk). Override: TURBO_SYMMETRIC=1 forces symmetric. PPL (wikitext-2, ctx=512): Qwen 2.5 Q2_K: 4015 symmetric → 8.85 auto-asymmetric (+0.3% vs q8_0) Qwen 3.5 Q4_K_M: 6.31 symmetric → 6.24 auto-asymmetric (+1.0% vs q8_0) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace shared-memory butterfly stages h=1,2,4,8,16 with __shfl_xor_sync() in turbo3, turbo2, and turbo4 SET_ROWS kernels. Pairs with distance < 32 are always in the same warp — no barrier needed. Only h=32 and h=64 (cross-warp) retain __syncthreads(). Mathematically identical: __shfl_xor_sync(mask, v, h) gives thread j the value from j^h, then (j & h) ? (other - v) : (v + other) is the exact butterfly computation. Zero PPL regression: turbo3 6.24, turbo4 6.23 (unchanged). Credit to seanrasch (perf/ftz-and-wht-shuffle branch). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Auto-asymmetric (K→q8_0) creates FA dimension mismatch on models where head_dim_k != head_dim_v (e.g. DeepSeek: K=192, V=128). The q8_0 K at non-standard D has no CUDA FA kernel, falling to slow CPU FA. Fix: when K and V head dims differ, keep symmetric turbo which pads both K and V consistently for CUDA FA. Note: turbo4 symmetric on DeepSeek still falls to CPU FA (21 t/s) because padded K D=256 != V D=128. This is a pre-existing FA D matching limitation, not a regression. turbo3 on DeepSeek works at full CUDA speed (172 t/s) because the VEC kernel handles asymmetric D natively. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
b00a40b to
cfb57af
Compare
PGCRT
pushed a commit
to PGCRT/llama-cpp-turboquant-cuda
that referenced
this pull request
Apr 1, 2026
Move Q forward rotation from graph-level ggml_turbo_wht op into FA kernels to eliminate a separate kernel launch per layer during decode: - Vec kernel (decode): shared memory FWHT with 64-thread parallel butterfly, zero extra kernel launches, CUDA graph compatible - Prefill MMA: separate k_turbo_fwht_forward kernel with persistent cudaMalloc buffer (avoids cudaMallocAsync NaN on graph replay) - V inverse rotation remains at graph level for CUDA graph compat Results: decode 30.14 tok/s (-0.4%), prefill 1146 tok/s (-0.3%), PPL identical to baseline (19.7152 on 10-chunk test). Also adds temporal decay test (experiment TheTom#36) and benchmarks. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Port initial optimizations from Madreag/turbo3-cuda: - Remove turbo types from K_is_unquantized — Q is now q8_1 quantized (int8 packed) for turbo types, reducing Q register footprint 4× - Keep nthreads_KQ=8 for turbo (same ILP as f16) via K_is_turbo flag - Rewrite turbo3/turbo2/turbo4 vec_dot_KQ to process 4 elements per iteration with q8_1 Q (packed int32 + scale) - Replace 5 expf() with __expf() in softmax (~3.7% at long context) PPL unchanged (6.31). Speed unchanged on this model — the main speedup in Madreag's fork comes from a shared-memory LUT approach (precompute Q×centroid for all positions) which eliminates the multiply in the hot loop. That optimization needs separate porting. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Precompute Q[d] × centroid[c] into shared memory LUT once per decode step, then the KQ inner loop does a single shmem read per element instead of centroid lookup + multiply. - turbo3: 127→146 t/s at 32K depth (+15%) - turbo2: 148→154 t/s at 32K depth (+4%) - Only for ncols==1 (decode path, not prefill) - turbo4 excluded: 16 centroids × D exceeds shmem budget - LUT stride = n_centroids+1 to avoid bank conflicts - 8-wide processing (2 qs bytes + 1 signs byte per iteration) - L2 prefetch hints for next K block PPL unchanged (6.31). Zero quality impact — the LUT stores the exact same Q×centroid values, just precomputed. Based on Madreag/turbo3-cuda (release/cuda-optimized branch). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Four general VEC flash attention optimizations: 1. __launch_bounds__ occupancy 1→3: allows 3 blocks per SM, better latency hiding. This is the biggest win (~16% on q8_0 baseline). 2. V_is_unquantized: remove turbo types from V unquantized path, matching the K change. Turbo V uses quantized dequant path. 3. Aggressive sparse V threshold: 5e-3 for turbo3/4 (was 1e-6), 1e-2 for turbo2. Validated zero PPL impact per Madreag. 4. L2 prefetch: add V block prefetch alongside existing K prefetch. Results (tg128 @ d32768, Qwen3.5 35B Q4_K_M, RTX 5090): | Type | Before | After | vs q8_0 | |------|--------|-------|---------| | q8_0 | 156 | 181 | baseline | | turbo2 | 148 | 185 | +2.5% faster | | turbo3 | 127 | 171 | -5.5% | | turbo4 | 101 | 117 | -35% (no LUT) | turbo2 at 7.5x compression is now FASTER than q8_0 at 32K context. PPL unchanged (6.31). Based on Madreag/turbo3-cuda optimizations. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
338990c to
f575431
Compare
- ops.cpp: add TURBO3_0/TURBO4_0/TURBO2_0 to clamp switch to fix -Werror=switch on GCC/Clang CI - ggml-rpc.h: bump RPC_PROTO_PATCH_VERSION 1→2 and update GGML_OP_COUNT assert 96→97 (GGML_OP_TURBO_WHT added) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
CUDA flash attention kernel optimizations, auto-asymmetric KV for
quantized models, and SET_ROWS warp shuffle WHT. Combined effect:
turbo2 now beats q8_0 decode speed at 32K context.
Performance (tg128 @ d32768, Qwen3.5 35B Q4_K_M, RTX 5090)
Changes
1. Shared-memory LUT for turbo KQ scoring (+15% turbo3, +4% turbo2)
Precompute
Q[d] × centroid[c]into shared memory once per decodestep. The KQ inner loop does a single shmem read per element instead
of centroid lookup + multiply. 8-wide processing with bank-conflict-
free stride. turbo4 excluded (16 centroids too large for shmem).
2. General VEC FA optimizations (+16% q8_0 baseline)
__launch_bounds__occupancy 1→3 (3 blocks/SM, better latency hiding)__expffast-math softmax (~3.7%)3. q8_1 Q path for turbo types
Turbo K types now use q8_1-quantized Q (int8 packed) instead of
float2, reducing Q register footprint 4×. KQ dot product rewritten
to process 4 elements per iteration with integer Q.
4. Auto-asymmetric KV for quantized models
When user requests symmetric turbo K+V on a quantized-weight model,
auto-downgrades K to q8_0. Prevents catastrophic PPL on outlier
models (Qwen 2.5: 4015→8.85). Skipped for mismatched K/V head dims.
Override:
TURBO_SYMMETRIC=1.5. Warp shuffle WHT (SET_ROWS)
Replace 5 of 7
__syncthreads()with__shfl_xor_sync()forintra-warp butterfly stages in all turbo SET_ROWS kernels.
Quality
PPL unchanged across all optimizations:
Test plan
Credit: FA optimizations based on Madreag/turbo3-cuda.
Warp shuffle WHT based on seanrasch/perf/ftz-and-wht-shuffle.
🤖 Generated with Claude Code