HIP/ROCm: two crash fixes for TurboQuant KV cache on RDNA #4

Merged
Ooooze merged 2 commits into
AtomicBot-ai:feature/turboquant-kv-cache from
dedesite:fix-hip-crash
May 7, 2026

@dedesite dedesite commented May 7, 2026

Overview

This PR fixes a link error and a runtime crash that occur on HIP/ROCm when using MTP models such as the Gemma 4 assistant.

Tested on a Ryzen AI HX 470 (gfx1150, RDNA3.5) running Gemma 4 E4B with:

./build/bin/llama-server \
    -m         ./models/gemma-4-E4B-it-Q4_K_M.gguf \
    --mtp-head ./models/gemma-4-E4B-it-assistant.Q4_K_M.gguf \
    --spec-type mtp \
    --draft-block-size 3 --draft-max 8 --draft-min 0 \
    -ngl 99 -ngld 99 \
    -ctk turbo3 -ctv turbo3 -ctkd turbo3 -ctvd turbo3 \
    -fa on -c 16384 --host 127.0.0.1 --port 8080

OS: Linux Mint 22.3 (based on Ubuntu 24.04) with ROCm 7.2.1 installed

Fix 1 — HIP linker error: missing fattn-vec template instances

ggml/src/ggml-hip/CMakeLists.txt was missing three cross-type flash-attention VEC instances (f16 key × turbo2/3/4 value) that were already present in ggml/src/ggml-cuda/CMakeLists.txt.

This produced link errors at the final llama-server link step:

  undefined reference to void ggml_cuda_flash_attn_ext_vec_case<
      256, (ggml_type)1, (ggml_type)42>
  undefined reference to void ggml_cuda_flash_attn_ext_vec_case<
      256, (ggml_type)1, (ggml_type)43>
  undefined reference to void ggml_cuda_flash_attn_ext_vec_case<
      256, (ggml_type)1, (ggml_type)44>

Fix: added the three files to the HIP CMake list.
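The numeric type IDs in the linker output are easier to read once mapped back to type names. The following is a minimal sketch, not part of the PR: it assumes that IDs 42, 43 and 44 correspond to turbo2, turbo3 and turbo4 (inferred from the order in which this PR lists them), and `missing_instances` is a hypothetical helper name.

```python
import re

# Assumed mapping from numeric ggml_type IDs to names. 1 is F16 in ggml;
# 42/43/44 -> turbo2/3/4 is inferred from this PR's description and is an
# assumption, not read from the fork's headers.
ASSUMED_TYPE_NAMES = {1: "f16", 42: "turbo2", 43: "turbo3", 44: "turbo4"}

# Matches the template arguments even when the linker wraps them onto a new line.
PATTERN = re.compile(
    r"ggml_cuda_flash_attn_ext_vec_case<\s*(\d+),"
    r"\s*\(ggml_type\)(\d+),\s*\(ggml_type\)(\d+)\s*>"
)


def missing_instances(linker_output: str) -> list[tuple[int, str, str]]:
    """Return sorted, unique (head_dim, K type, V type) triples named in the errors."""
    found = set()
    for head_dim, k_id, v_id in PATTERN.findall(linker_output):
        k_name = ASSUMED_TYPE_NAMES.get(int(k_id), f"type{k_id}")
        v_name = ASSUMED_TYPE_NAMES.get(int(v_id), f"type{v_id}")
        found.add((int(head_dim), k_name, v_name))
    return sorted(found)


SAMPLE = """
undefined reference to void ggml_cuda_flash_attn_ext_vec_case<
    256, (ggml_type)1, (ggml_type)42>
undefined reference to void ggml_cuda_flash_attn_ext_vec_case<
    256, (ggml_type)1, (ggml_type)43>
undefined reference to void ggml_cuda_flash_attn_ext_vec_case<
    256, (ggml_type)1, (ggml_type)44>
"""

for head_dim, k_name, v_name in missing_instances(SAMPLE):
    print(f"missing VEC instance: head_dim={head_dim}, K={k_name}, V={v_name}")
```

Each triple reported this way should correspond to one template-instance source file that has to appear in the HIP CMake list.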

Fix 2 — Runtime GGML_ABORT in fattn-tile.cuh for head_dim=512

Gemma 4 E4B has head_dim = 4096 / 8 = 512. For head_dim=512, all fast FA paths are excluded on AMD:

  • VEC — capped at head_dim ≤ 256
  • WMMA — explicitly excludes Q->ne[0] == 512
  • MFMA — explicitly excludes Q->ne[0] == 512

So TILE is always selected. Inside launch_fattn_tile_switch_ncols2<512, 512>, the DKQ ≤ 512 block only handled gqa_ratio % 4 == 0 and gqa_ratio % 8 == 0, then a DV ≤ 256 guard for smaller ratios. For DV=512 with gqa_ratio=2 (Gemma 4: 8 Q-heads / 4 KV-heads) the code fell straight through to GGML_ABORT("fatal error").
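The elimination argument above can be checked with a small model. This is only a sketch of the three exclusion rules quoted in this PR; the real selection logic in ggml also depends on hardware, data types and batch size, and the function names here are hypothetical.

```python
def eligible_fast_paths(head_dim: int) -> list[str]:
    """Fast flash-attention paths left after applying the rules quoted above."""
    paths = []
    if head_dim <= 256:      # VEC is capped at head_dim <= 256
        paths.append("VEC")
    if head_dim != 512:      # WMMA and MFMA explicitly exclude Q->ne[0] == 512
        paths.extend(["WMMA", "MFMA"])
    return paths


def selected_path(head_dim: int) -> str:
    """TILE is what remains when no fast path is eligible (simplified model)."""
    paths = eligible_fast_paths(head_dim)
    return paths[0] if paths else "TILE"


# Gemma 4 E4B numbers from this PR: n_embd = 4096, n_head = 8.
head_dim = 4096 // 8
print(head_dim, eligible_fast_paths(head_dim), selected_path(head_dim))
```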
Fix:

  1. Dispatch — added ncols2=2 (for gqa_ratio % 2 == 0) and ncols2=1 (unconditional fallback) inside the DKQ ≤ 512 block after the DV ≤ 256 guard, mirroring the pattern that already exists for DV ≤ 256.
  2. Kernel configs — added (512, 512, ncols=2, nthreads=64, occupancy=2, nbatch_fa=32, nbatch_K=64) to all four config tables (nvidia_fp16, nvidia_fp32, amd, amd_rdna). Without these entries the device-side static_assert inside flash_attn_tile<512,512,1,2> would fire at compile time.
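The dispatch change can be modeled in a few lines. The sketch below is a simplified stand-in for the branch structure described above, not the actual code in fattn-tile.cuh: `pick_ncols2` and its `patched` flag are hypothetical, and `RuntimeError` stands in for GGML_ABORT.

```python
def pick_ncols2(dkq: int, dv: int, gqa_ratio: int, patched: bool) -> int:
    """Simplified model of the ncols2 choice inside the DKQ <= 512 block."""
    if dkq <= 512:
        if gqa_ratio % 8 == 0:
            return 8
        if gqa_ratio % 4 == 0:
            return 4
        if dv <= 256:
            # Fallbacks that already existed for the smaller value head size.
            return 2 if gqa_ratio % 2 == 0 else 1
        if patched:
            # Fallbacks added by this PR for DV > 256.
            return 2 if gqa_ratio % 2 == 0 else 1
    raise RuntimeError("fatal error")  # stands in for GGML_ABORT


# Gemma 4 numbers from this PR: 8 Q-heads / 4 KV-heads, DKQ = DV = 512.
gqa_ratio = 8 // 4

try:
    pick_ncols2(512, 512, gqa_ratio, patched=False)
except RuntimeError as err:
    print("before the fix:", err)

print("after the fix: ncols2 =", pick_ncols2(512, 512, gqa_ratio, patched=True))
```

The model covers only the host-side choice; the matching kernel config entries are still needed so that the selected instantiation compiles.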

What this PR fixes

Before fix 1: linker error, binary not produced.
Before fix 2: crash during first decode step with fattn-tile.cuh:1263: fatal error.
After both fixes: server runs without crash, MTP speculative decoding functional.

Test procedure

Build

cmake -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1150 -DCMAKE_BUILD_TYPE=Release
cmake --build build --target llama-server -j$(nproc)

Start server

./build/bin/llama-server \
    -m         ./models/gemma-4-E4B-it-Q4_K_M.gguf \
    --mtp-head ./models/gemma-4-E4B-it-assistant.Q4_K_M.gguf \
    --spec-type mtp \
    --draft-block-size 3 --draft-max 8 --draft-min 0 \
    -ngl 99 -ngld 99 \
    -ctk turbo3 -ctv turbo3 -ctkd turbo3 -ctvd turbo3 \
    -fa on -c 16384 --host 127.0.0.1 --port 8080
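Since the crash occurred during the first decode step, a single completion request is enough to exercise the fixed path once the server is up. The following is a minimal sketch that assumes this fork keeps the upstream llama-server `/health` and `/completion` endpoints and the `content` field in the response; `build_completion_request` and `smoke_test` are hypothetical helper names.

```python
import json
import urllib.request

BASE_URL = "http://127.0.0.1:8080"  # host and port from the command above


def build_completion_request(prompt: str, n_predict: int = 32) -> urllib.request.Request:
    """Build a POST request for the (assumed) /completion endpoint."""
    body = json.dumps({"prompt": prompt, "n_predict": n_predict}).encode("utf-8")
    return urllib.request.Request(
        f"{BASE_URL}/completion",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def smoke_test(prompt: str = "Hello", timeout: float = 60.0) -> str:
    """Check that the server is healthy, then force at least one decode step."""
    with urllib.request.urlopen(f"{BASE_URL}/health", timeout=timeout) as resp:
        if resp.status != 200:
            raise RuntimeError(f"server not ready: HTTP {resp.status}")
    request = build_completion_request(prompt)
    with urllib.request.urlopen(request, timeout=timeout) as resp:
        return json.loads(resp.read())["content"]
```

Calling `smoke_test()` against a running server should return generated text; a server built without fix 2 would be expected to abort while handling the request.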

Benchmark

The table below summarizes all test results. The server was launched with the following command for each configuration:

  • Baseline (standard llama.cpp from llamacpp-rocm): ./llama-server -m ../atomic-llama-cpp-turboquant/models/gemma-4-E4B-it-Q4_K_M.gguf -ngl 99 -ngld 99 -fa on -c 16384 --host 127.0.0.1 --port 8080
  • KV Cache + MTP Head: ./build/bin/llama-server -m ./models/gemma-4-E4B-it-Q4_K_M.gguf --mtp-head ./models/gemma-4-E4B-it-assistant.Q4_K_M.gguf --spec-type mtp --draft-block-size 3 --draft-max 8 --draft-min 0 -ngl 99 -ngld 99 -ctk turbo3 -ctvd turbo3 -fa on -c 16384 --host 127.0.0.1 --port 8080
  • MTP Head Only: ./build/bin/llama-server -m ./models/gemma-4-E4B-it-Q4_K_M.gguf --mtp-head ./models/gemma-4-E4B-it-assistant.Q4_K_M.gguf --spec-type mtp --draft-block-size 3 --draft-max 8 --draft-min 0 -ngl 99 -ngld 99 -fa on -c 16384 --host 127.0.0.1 --port 8080
  • KV Cache Only: ./build/bin/llama-server -m ./models/gemma-4-E4B-it-Q4_K_M.gguf -ngl 99 -ngld 99 -ctk turbo3 -ctvd turbo3 -fa on -c 16384 --host 127.0.0.1 --port 8080

All tests were run with:

PORT=8080 PARALLEL=1 N_PREDICT=200 ./scripts/bench-parallel.sh
PORT=8080 PARALLEL=4 N_PREDICT=200 ./scripts/bench-parallel.sh
| Type | PARALLEL=1 (per_seq_tps) | Speed (compared to baseline) | PARALLEL=4 (per_seq_tps) | Speed (compared to baseline) |
|---|---|---|---|---|
| Baseline | 25.15 | x1.0 | 14.10 | x1.0 |
| KV Cache + MTP Head | 30.84 | x1.22 | 18.96 | x1.34 |
| MTP Head Only | 32.51 | x1.29 | 18.52 | x1.31 |
| KV Cache Only | 23.69 | x0.94 | 13.03 | x0.92 |
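The speed columns are each configuration's per_seq_tps divided by the baseline value for the same PARALLEL setting. A short sketch that recomputes them from the reported throughput figures (the `speedups` helper is hypothetical; recomputed ratios may differ from the reported ones in the last digit depending on rounding):

```python
# per_seq_tps values copied from the benchmark results: (PARALLEL=1, PARALLEL=4)
RESULTS = {
    "Baseline": (25.15, 14.10),
    "KV Cache + MTP Head": (30.84, 18.96),
    "MTP Head Only": (32.51, 18.52),
    "KV Cache Only": (23.69, 13.03),
}


def speedups(results: dict, baseline: str = "Baseline") -> dict:
    """Divide each throughput by the baseline throughput at the same parallelism."""
    base_p1, base_p4 = results[baseline]
    return {
        name: (p1 / base_p1, p4 / base_p4)
        for name, (p1, p4) in results.items()
    }


for name, (s1, s4) in speedups(RESULTS).items():
    print(f"{name:<20} PARALLEL=1: x{s1:.3f}  PARALLEL=4: x{s4:.3f}")
```

Read this way, the MTP head accounts for the gain in these runs, while the turbo3 KV cache on its own is slightly slower than the baseline.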

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: YES. To be honest, the development was done entirely by Claude Code. I have never worked on the llama.cpp code; I just wanted to run Gemma with the MTP assistant on my machine. If this doesn't meet the llama.cpp requirements, then at least the code will be public somewhere.

Andreas Livet and others added 2 commits May 7, 2026 13:59
The HIP fattn-vec build list was missing three cross-type instances
(f16 key + turbo2/3/4 value) that were already present in the CUDA
CMakeLists.  This caused linker errors of the form:

  undefined reference to void ggml_cuda_flash_attn_ext_vec_case<
      256, (ggml_type)1, (ggml_type)42/43/44>

when building llama-server with GGML_HIP=ON and TurboQuant KV cache
enabled.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Models with head_dim=512 (e.g. Gemma 4 E4B: n_embd=4096, n_head=8)
always use the TILE flash-attention path on AMD/HIP because VEC is
capped at head_dim<=256 and WMMA/MFMA explicitly exclude D=512.

Inside launch_fattn_tile_switch_ncols2<512,512>, the DKQ<=512 block
only had fallback cases for gqa_ratio divisible by 4 or 8, then a
DV<=256 guard for ratio=2/1.  For DV=512 with gqa_ratio=2 (Gemma 4:
8 Q-heads / 4 KV-heads) the code fell through to GGML_ABORT.

Fix two things:
1. Dispatch: add ncols2=2 and ncols2=1 fallbacks inside the DKQ<=512
   block for the DV>256 case, mirroring what already exists for DV<=256.
2. Kernel configs: add the missing ncols=2 entry for DKQ=DV=512 in all
   four config tables (nvidia_fp16, nvidia_fp32, amd, amd_rdna).
   Without these entries the device-side static_assert would fire at
   compile time for flash_attn_tile<512,512,{1,2},2,*>.

Tested on gfx1150 (Ryzen AI HX 470, RDNA3.5) running Gemma 4 E4B
with -ctk turbo3 -ctv turbo3 and --mtp-head speculative decoding.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@Ooooze Ooooze merged commit 2e81dc5 into AtomicBot-ai:feature/turboquant-kv-cache May 7, 2026
1 check passed

Ooooze commented May 7, 2026

Thanks for the fix and the detailed writeup — much appreciated!

dedesite commented May 9, 2026 via email
