
feat: HIP/ROCm support for turbo3/turbo2 (7900 XTX)#31

Merged
TheTom merged 1 commit into TheTom:feature/turboquant-kv-cache from apollosenvy:pr/rocm-hip-port
Mar 30, 2026

Conversation

@apollosenvy

Summary

HIP/ROCm port of the turbo3/turbo2 warp-cooperative kernels. Split out from PR #5 per review feedback.

Single commit, minimal surface area:

  • HIP vendor header (hip.h): Added cudaMemcpyToSymbol/FromSymbol mappings. Fixed __shfl_sync, __shfl_xor_sync, __shfl_up_sync, __shfl_down_sync to support 3-arg calls (CUDA defaults width to warpSize). Added __ballot_sync -> __ballot with uint32_t cast.
  • HIP CMakeLists: Added turbo3/turbo2 FA template instances. Excluded D>=576 fattn-tile kernels (exceed HIP's 64KB local memory limit).

Test Results (AMD 7900 XTX, ROCm 7.1)

| Model | KV Type | PPL | pp128 t/s | tg32 t/s |
| --- | --- | --- | --- | --- |
| Qwen3.5-27B Q4_K_M | turbo3 | 7.58 | 654 | 25.2 |
| Mistral-Small-24B Q4_K_S | turbo3 | 5.28 | 600 | 24.2 |

turbo3 runs at ~98% of F16 speed. Mistral-Small (head_dim=128) is confirmed working.

What's NOT in this PR

  • Temporal decay (separate PR, needs Metal kernel)
  • Non-128 head_dim fallback changes (separate discussion)

🤖 Generated with Claude Code

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

…rnels

Port TheTom's warp-cooperative turbo3 SET_ROWS kernel and turbo2/turbo3
flash attention templates to HIP/ROCm (7900 XTX, gfx1100).

HIP vendor header fixes:
- Add cudaMemcpyToSymbol/FromSymbol -> hipMemcpyToSymbol/FromSymbol
- Add cudaMemcpyHostToDevice/DeviceToHost mappings
- Fix __shfl_sync, __shfl_xor_sync, __shfl_up_sync, __shfl_down_sync
  to support both 3-arg and 4-arg calls (CUDA allows defaulting width
  to warpSize, HIP macros required 4 args)
- Add __ballot_sync -> __ballot with uint32_t cast (HIP returns 64-bit
  on wave64 platforms, turbo code expects 32-bit)

HIP CMakeLists:
- Add turbo3 and turbo2 flash attention template instances (same files
  as CUDA CMakeLists, were missing from HIP build)

Tested: Mistral-Small-24B turbo3 PPL = 5.28 (+2.4% vs F16 baseline 5.16)
Previously showed catastrophic PPL ~15000 due to CPU quantize stub bug
(fixed by TheTom in 53f1298).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@TheTom
Owner

TheTom commented Mar 30, 2026

Tested on M5 Max 128GB and M2 Pro 32GB (Metal). HIP-only changes, zero shared code — confirmed no Metal regressions.

M5 Max (Qwen3.5-35B-A3B Q8_0):

| Config | PPL | Baseline | pp512 t/s | tg128 t/s |
| --- | --- | --- | --- | --- |
| turbo3 | 6.1756 | 6.1756 (match) | 2726 | 76.69 |
| turbo4 | 6.1250 | 6.1250 (match) | | |
| q8_0/turbo4 | | | 2760 | 81.17 |

M2 Pro (Qwen2.5-7B Q4_K_M, asymmetric — correct config for Q4_K_M models):

| Config | PPL | vs q8_0 | pp512 t/s | tg128 t/s |
| --- | --- | --- | --- | --- |
| q8_0/q8_0 | 6.7938 | baseline | 334.91 | 32.51 |
| q8_0/turbo4 | 6.8281 | +0.5% | 332.72 | 28.39 |

Clean on both platforms. Nice minimal PR — the variadic shuffle macros and the D>=576 exclusion are both sensible. Thanks for the clean split from PR #5, this is exactly the right approach.

Merging.

@TheTom TheTom merged commit 64dd362 into TheTom:feature/turboquant-kv-cache Mar 30, 2026
1 check passed
shtaylor pushed a commit to shtaylor/llama-cpp-turboquant that referenced this pull request Mar 30, 2026
…heTom#31

Block 128: PPL=165.6 (same as block 32)
Disabled Q rotation: PPL=165.6 (same)
Root cause: dynamic_cast fails for MoE hybrid memory context.
Q rotation and V inverse rotation never execute.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
shtaylor pushed a commit to shtaylor/llama-cpp-turboquant that referenced this pull request Mar 30, 2026
…eTom#31 TheTom#30

ROOT CAUSE: pre-rotate-queries never executed because:
1. Q ne[0]=256 (GQA concatenated heads), rotation matrix ne[0]=128
2. mctx dynamic_cast failed for MoE hybrid memory

FIX: put inverse WHT rotation back in dequantize_full_block.
This is slower (10.7 tok/s vs 77.7) but produces CORRECT results.

PERPLEXITY RESULTS:
- f16:     6.121
- q8_0:    6.111
- q4_0:    6.142
- turbo3:  6.194 (+1.2% vs q8_0) ✅

The speed optimization (pre-rotate-queries) needs to be reimplemented
to work with GQA head layout and hybrid memory types.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mihai-chiorean pushed a commit to mihai-chiorean/turbo3-cuda that referenced this pull request Mar 31, 2026
…heTom#31

Block 128: PPL=165.6 (same as block 32)
Disabled Q rotation: PPL=165.6 (same)
Root cause: dynamic_cast fails for MoE hybrid memory context.
Q rotation and V inverse rotation never execute.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mihai-chiorean pushed a commit to mihai-chiorean/turbo3-cuda that referenced this pull request Mar 31, 2026
…eTom#31 TheTom#30

ROOT CAUSE: pre-rotate-queries never executed because:
1. Q ne[0]=256 (GQA concatenated heads), rotation matrix ne[0]=128
2. mctx dynamic_cast failed for MoE hybrid memory

FIX: put inverse WHT rotation back in dequantize_full_block.
This is slower (10.7 tok/s vs 77.7) but produces CORRECT results.

PERPLEXITY RESULTS:
- f16:     6.121
- q8_0:    6.111
- q4_0:    6.142
- turbo3:  6.194 (+1.2% vs q8_0) ✅

The speed optimization (pre-rotate-queries) needs to be reimplemented
to work with GQA head layout and hybrid memory types.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
PGCRT pushed a commit to PGCRT/llama-cpp-turboquant-cuda that referenced this pull request Apr 1, 2026
- turbo4 K+V results on Qwen3.5-27B (-0.32% vs q8_0) and Qwen3-14B (+6.3%)
- Sparse V dequant benchmarks: MoE native dequant +10.9% at 8K
- Gemma-3 turbo3 results post-iSWA fix (+3.3%)
- KVLinC no-K-rotation negative result
- Speculative decoding negative result
- CUDA 13.2 compatibility verified
- Experiments TheTom#31, TheTom#39, TheTom#42, TheTom#45, ggml-org#49, ggml-org#50, ggml-org#51 status updates

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
TheTom added a commit that referenced this pull request Apr 2, 2026
…31

Block 128: PPL=165.6 (same as block 32)
Disabled Q rotation: PPL=165.6 (same)
Root cause: dynamic_cast fails for MoE hybrid memory context.
Q rotation and V inverse rotation never execute.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
TheTom added a commit that referenced this pull request Apr 2, 2026
…#30

ROOT CAUSE: pre-rotate-queries never executed because:
1. Q ne[0]=256 (GQA concatenated heads), rotation matrix ne[0]=128
2. mctx dynamic_cast failed for MoE hybrid memory

FIX: put inverse WHT rotation back in dequantize_full_block.
This is slower (10.7 tok/s vs 77.7) but produces CORRECT results.

PERPLEXITY RESULTS:
- f16:     6.121
- q8_0:    6.111
- q4_0:    6.142
- turbo3:  6.194 (+1.2% vs q8_0) ✅

The speed optimization (pre-rotate-queries) needs to be reimplemented
to work with GQA head layout and hybrid memory types.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>