
ggml: CUDA MXFP flash attention — Blackwell MMA + VEC kernels#2

Open
timothyeburke wants to merge 10 commits into mxfp-flash-attention from mxfp-flash-attention-cuda

Conversation

@timothyeburke timothyeburke commented Mar 16, 2026

Port MXFP flash attention to CUDA using the shared infrastructure from the mxfp-flash-attention branch. Three types are supported: MXFP4 (E2M1), MXFP8 (E4M3), and MXFP6 (E2M3).

KV Cache Benchmark — gpt-oss-20b

Model: ggml-org/gpt-oss-20b/gpt-oss-20b-mxfp4.gguf
GPU: 5070 Ti 16GB
Backend: CUDA — 2026-03-22T22:53:44-04:00
Chunks: 16

| K type | V type | K GiB | V GiB | Total GiB | PPL | Δ vs F16 | pp512 | tg128 |
| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| mxfp6 | mxfp4 | 3.58 | 2.43 | 6.01 | 276.93 | -63.32 | 8,566 | 213.5 |
| mxfp6 | mxfp6 | 3.58 | 3.58 | 7.16 | 292.95 | -47.30 | 8,561 | 214.1 |
| q8_0 | q8_0 | 4.86 | 4.86 | 9.72 | 331.10 | -9.15 | 8,719 | 217.5 |
| mxfp8 | mxfp8 | 4.72 | 4.72 | 9.44 | 334.37 | -5.88 | 8,607 | 216.4 |
| q8_0 | q4_0 | 4.86 | 2.57 | 7.43 | 334.91 | -5.34 | 8,681 | 217.1 |
| mxfp8 | mxfp4 | 4.72 | 2.43 | 7.15 | 337.08 | -3.17 | 8,597 | 219.8 |
| f16 | f16 | 9.16 | 9.16 | 18.32 | 340.25 | - | 8,730 | 218.4 |
| mxfp4 | mxfp4 | 2.43 | 2.43 | 4.86 | 396.77 | +56.52 | 8,560 | 214.5 |
| q4_0 | q4_0 | 2.57 | 2.57 | 5.14 | 737.20 | +396.95 | 8,586 | 217.5 |

Memory columns are GiB per 100K tokens (K + V cache).

Add CUDA kernels for MXFP Struct-of-Arrays KV cache write and read paths,
implementing the CPU scalar reference from ops.cpp exactly.

set_rows (KV cache write):
  - New k_set_rows_mxfp_soa kernel: one warp per 32-element block
  - Fused Hadamard + quantize matching quantize_row_mxfp*_soa_hadamard
  - MXFP4 nibble packing, MXFP6 6-bit packing via shared ggml_mxfp_pack_fp6x4
  - E8M0 via shared ggml_mxfp_e8m0_base_estimate (no MSE search)
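
The "no MSE search" base estimate can be sketched in scalar form. This is a hypothetical reconstruction, not the actual `ggml_mxfp_e8m0_base_estimate` source: the function name, the rounding, and the clamping here are assumptions. The idea is to derive the shared power-of-two block scale directly from the block's absolute maximum, offset by the element format's maximum exponent.

```cpp
#include <cmath>
#include <cstdint>

// Hypothetical sketch of an E8M0 base-scale estimate (names and exact
// rounding are assumptions, not the real ggml helper): pick a power-of-two
// scale from the block amax in one step, without an MSE search.
static uint8_t e8m0_base_estimate(float amax, int emax_offset /* e.g. 2 for E2M1 */) {
    if (amax == 0.0f) {
        return 127; // biased exponent 0 -> scale 1.0
    }
    int e = (int) floorf(log2f(amax)) - emax_offset;
    if (e < -127) e = -127;
    if (e >  127) e =  127;
    return (uint8_t) (e + 127); // E8M0 stores only a biased exponent
}

// Decoding an E8M0 scale is just exponentiation of the unbiased value.
static float e8m0_to_fp32(uint8_t e) {
    return ldexpf(1.0f, (int) e - 127);
}
```

Because E8M0 carries no mantissa, the whole quantize/dequantize scale path reduces to exponent arithmetic.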

flash attention (KV cache read):
  - VEC kernel: inline MXFP K dequant+dot and V dequant with multihead
    SoA support matching mxfp_row_ptr + mxfp_dequant_head exactly
  - MMA kernel: SoA-to-F16 pre-conversion with multihead head extraction
  - Q Hadamard+roundtrip in VEC kernel (warp-cooperative) and MMA
    pre-kernel, using shared ggml_mxfp_hadamard_32_inplace butterfly
  - Multihead detection: nb[2] == ggml_row_size(type, D), matching
    mxfp_kv_params_init

All per-element math calls shared GGML_MXFP_FUNC constructs from
ggml-common.h. Only warp shuffle operations (Hadamard, amax) are
CUDA-specific in mxfp-common.cuh.
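
The 32-element Hadamard butterfly behind the warp-shuffle version can be written in scalar form as five stages of stride-1/2/4/8/16 add/sub pairs. This is a sketch only; whether `ggml_mxfp_hadamard_32_inplace` normalizes by 1/sqrt(32) as done here is an assumption.

```cpp
#include <cmath>

// Scalar sketch of a 32-point Walsh-Hadamard butterfly. The CUDA version
// in mxfp-common.cuh uses warp shuffles instead of array indexing; the
// 1/sqrt(32) normalization is an assumption.
static void hadamard_32_inplace(float * v) {
    for (int stride = 1; stride < 32; stride <<= 1) {
        for (int i = 0; i < 32; ++i) {
            if ((i & stride) == 0) {       // i is the "upper" lane of the pair
                const float a = v[i];
                const float b = v[i + stride];
                v[i]          = a + b;
                v[i + stride] = a - b;
            }
        }
    }
    const float norm = 1.0f / sqrtf(32.0f);
    for (int i = 0; i < 32; ++i) {
        v[i] *= norm;
    }
}
```

With this normalization the transform is orthonormal and self-inverse, so applying it to both Q and the KV cache leaves dot products unchanged while flattening outliers before quantization.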

Passes 51 SET_ROWS + 1836 FLASH_ATTN_EXT tests on RTX 5070 Ti (sm_120).

- Hoist E8M0 scale decode: elem pairs always share a block (elem is
  even), so compute scale once per pair instead of twice
- Hardware intrinsics (CUDA 12.8+): replace LUT lookups with single-
  instruction CVT for FP8 E4M3, FP4 E2M1, FP6 E2M3 dequant; add
  cuda_fp6.h include; portable fallback for older toolkits
- Inline V multihead dequant: read directly from multihead SoA layout
  using computed offsets, eliminating the per-KV-position buffer copy
- Dockerfile: BuildKit cache mount for incremental CUDA builds

Benchmark (gpt-oss-20b, 2x RTX 5070 Ti, tg128):
  mxfp8+mxfp8: 192 -> 219 t/s (+14.2%, now exceeds f16)
  mxfp6+mxfp6: 194 -> 204 t/s (+5.0%)
  mxfp8+mxfp4: 195 -> 206 t/s (+5.8%)
  mxfp4+mxfp4: 205 -> 206 t/s (+0.6%)
  PPL unchanged across all configs.

Extract reusable helpers into mxfp-common.cuh that exploit the invariant
that elem is always even in the VEC kernel dequant loops:

- mxfp4_extract_nibble_pair: branchless nibble extraction using
  shift = (pos >= 16) ? 4 : 0 instead of branchy pos < 16 logic.
  Both nibbles always in same half (low or high) since pos is even.

- mxfp6_unpack_pair: single ggml_mxfp_unpack_fp6x4 call for both
  elements. pos0 even → pos0%4 ∈ {0,2} → grp0 == grp1 always,
  so the group-boundary branch is provably dead code. Eliminated.
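
The branchless extraction can be sketched as follows, assuming the usual ggml Q4-style packing where byte `i` holds element `i` in the low nibble and element `i+16` in the high nibble (the packing and function shape are assumptions about this code, not copied from it). Because `pos` is even, `pos` and `pos+1` always fall in the same half, so one shift serves both elements.

```cpp
#include <cstdint>

// Sketch of branchless even-position nibble pair extraction. Assumes byte i
// packs element i (low nibble) and element i+16 (high nibble).
static void mxfp4_extract_nibble_pair(const uint8_t * qs, int pos,
                                      uint8_t * n0, uint8_t * n1) {
    const int shift = (pos >= 16) ? 4 : 0;     // both elements share this half
    *n0 = (uint8_t) ((qs[pos       % 16] >> shift) & 0x0F);
    *n1 = (uint8_t) ((qs[(pos + 1) % 16] >> shift) & 0x0F);
}
```

The conditional compiles to a select rather than a branch, which is what makes this cheaper than the `pos < 16` two-path logic inside a hot dequant loop.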

Applied to all K dot, V dequant, and inline V multihead paths.
Net -39 lines (83 added, 122 removed).

Benchmark (gpt-oss-20b, 2x RTX 5070 Ti, tg128):
  mxfp8+mxfp4: 206 -> 219 t/s (exceeds f16 218.6)
  mxfp4+mxfp4: 206 -> 214 t/s (-2.0% vs f16)
  mxfp6+mxfp4: 197 -> 214 t/s (-2.1% vs f16)
  mxfp6+mxfp6: 204 -> 214 t/s (-2.2% vs f16)

Extract the repeated pattern of "unpack + scale + dequant a pair of
SoA elements" into a single template function in mxfp-common.cuh,
specialized for MXFP4, MXFP8, and MXFP6.

Each specialization encapsulates:
  - Type-specific data extraction (nibble pair / byte pair / fp6 unpack)
  - E8M0 scale decode (full or half depending on intrinsic availability)
  - Hardware intrinsic vs portable LUT fallback (#if CUDART_VERSION)
  - Returns float2{elem0, elem1}

This collapses 6 call sites (K dot, K inline, V half2, V float2,
vec_dot_mxfp4/8/6) from ~25 lines each to a single function call.
All 16 #if CUDART_VERSION guards in fattn-vec.cuh eliminated (now 0).

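
A scalar sketch of what the MXFP4 specialization of that pair helper does (the function name and `std::pair` return are illustrative stand-ins for the real template and its `float2`). The 16-entry value table is the standard FP4 E2M1 set {0, 0.5, 1, 1.5, 2, 3, 4, 6} with a sign bit.

```cpp
#include <cstdint>
#include <utility>

// E2M1 code -> value, sign in the high bit of the 4-bit code.
static const float kE2M1[16] = {
     0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
};

// Illustrative pair helper: decode two E2M1 codes and apply the shared
// E8M0 block scale in one step (the real template returns float2 and also
// covers the hardware-intrinsic path).
static std::pair<float, float> dequant_pair_mxfp4(uint8_t n0, uint8_t n1, float scale) {
    return { kE2M1[n0 & 0x0F] * scale, kE2M1[n1 & 0x0F] * scale };
}
```

Collapsing the six call sites onto one such function is what lets the `#if CUDART_VERSION` intrinsic-vs-LUT choice live in a single place instead of being repeated per call site.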
Add q8_0+q4_0 mixed-type FA to the default CUDA build (previously
required GGML_CUDA_FA_ALL_QUANTS). This is a common high-quality
config: q8_0 keys preserve attention pattern fidelity while q4_0
values save memory with minimal quality loss.

Changes:
- fattn.cu: add q8_0+q4_0 to default VEC dispatch and mixed-type guard
- CMakeLists.txt: compile fattn-vec-instance-q8_0-q4_0.cu by default

Before: pp=616, tg=125 (no FA, CPU fallback)
After:  pp=8691, tg=217 (full FA, GPU-accelerated)

@timothyeburke timothyeburke force-pushed the mxfp-flash-attention-cuda branch from a53a13a to e45b51e Compare March 23, 2026 00:26
@timothyeburke timothyeburke marked this pull request as ready for review March 23, 2026 00:28
Copilot AI review requested due to automatic review settings March 23, 2026 00:28

Copilot AI left a comment


Pull request overview

Ports MXFP (MXFP4/MXFP6/MXFP8) flash attention to CUDA, adding SoA KV cache handling, vector-kernel support for decode, and MMA F16 pre-conversion plumbing for prefill.

Changes:

  • Add MXFP SoA set-rows quantization kernel + shared CUDA helpers (Hadamard, warp amax, pack/unpack, intrinsic dequant).
  • Extend FA VEC kernel and common utilities to support MXFP SoA (including mixed K/V MXFP8/MXFP6 with MXFP4 V) and mixed Q8_0 K with Q4_0 V selection.
  • Update build plumbing (template instances, CUDA headers, Docker build caching) to compile new kernels.

Reviewed changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| ggml/src/ggml-cuda/vendors/cuda.h | Adds CUDA FP6 header for 12.8+ intrinsic support. |
| ggml/src/ggml-cuda/template-instances/fattn-vec-instance-*.cu | Adds explicit template instantiations for MXFP VEC FA variants. |
| ggml/src/ggml-cuda/set-rows.cu | Adds MXFP SoA set_rows CUDA kernel and dispatch. |
| ggml/src/ggml-cuda/mxfp-common.cuh | Introduces CUDA-side warp helpers + MXFP SoA dequant/conversion kernels. |
| ggml/src/ggml-cuda/ggml-cuda.cu | Declares set_rows support for MXFP dst types. |
| ggml/src/ggml-cuda/fattn.cu | Enables MXFP/mixed-type paths in kernel selection and dispatch cases. |
| ggml/src/ggml-cuda/fattn-vec.cuh | Implements MXFP-aware Q preprocessing, K/V SoA dequant, and mixed K/V logic. |
| ggml/src/ggml-cuda/fattn-common.cuh | Adds MXFP SoA dot/dequant helpers; adds MXFP SoA→F16 + Q roundtrip for MMA path. |
| ggml/src/ggml-cuda/CMakeLists.txt | Adds globs for new VEC template instance translation units. |
| .devops/cuda-new.Dockerfile | Improves build caching + enables MXFP variants in CI image build. |


…s (-66 lines)

Consolidate all MXFP type-specific logic into mxfp-common.cuh shared constructs:

- mxfp_type_traits<type>: compile-time emax_offset and qs_per_blk per type
- mxfp_compute_e8m0<type>(): shared E8M0 scale computation (was 3 copies)
- mxfp_quantize_roundtrip<type>(): shared quantize→dequant (was 3 copies)
- mxfp_hadamard_roundtrip<type>(): fused Q preprocessing (VEC: 32→1 line)
- mxfp_multihead_ptrs<type>(): shared multihead offset computation
- MXFP_ROW_BYTES_EXPLICIT(): multihead detection macro
- vec_dot_fattn_vec_KQ_mxfp<type>: unified K dot (was 3 functions)
- dequantize_V_mxfp_D<type>: unified V dequant (was 3 functions)
- Templatized k_mxfp_soa_to_f16 and k_mxfp_q_hadamard_roundtrip kernels

Hardware intrinsics on CUDA 12.8+:
- x2 paired dequant: __nv_cvt_fp8x2/fp4x2/fp6x2_to_halfraw2
- Quantize: __nv_cvt_float_to_fp8/fp6/fp4
- E8M0 decode: ggml_cuda_e8m0_to_fp32 (was portable ggml_mxfp_*)
- Replaced log2f SFU in compute_e8m0_scale with integer bit extraction
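
The log2f replacement mentioned above amounts to reading the exponent field of the IEEE-754 representation directly. A scalar sketch (the exact ggml helper may differ, and this version assumes a normal positive input):

```cpp
#include <cstdint>
#include <cstring>

// floor(log2(x)) for a normal positive float via integer bit extraction,
// replacing a log2f SFU call. memcpy avoids strict-aliasing UB.
static int floor_log2_bits(float x) {
    uint32_t bits;
    memcpy(&bits, &x, sizeof bits);
    return (int) ((bits >> 23) & 0xFF) - 127;  // biased exponent -> unbiased
}
```

On the GPU this trades a special-function-unit op for two integer instructions, and it is exact where log2f-plus-floor can round unfortunately near powers of two.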
@timothyeburke timothyeburke force-pushed the mxfp-flash-attention-cuda branch from 52220eb to 43fcc52 Compare March 23, 2026 01:53
…cale convention

Replace CUDA-local mxfp4_quantize_elem (reimplemented decision boundaries) with
ggml_mxfp_float_to_fp4_e2m1 from ggml-common.h. Replace kvalues_mxfp4 LUT dequant
with ggml_mxfp_fp4_e2m1_to_float. All MXFP4 portable paths now use full E8M0 scale
and shared reference functions — no more half-scale convention in CUDA MXFP code.
…es (-185 lines)

Replace 9 explicit MXFP4/8/6 template specializations with if-constexpr
single-body templates. Add MXFP_DISPATCH macro for runtime type→template
dispatch, eliminating 3-way switch blocks in set-rows.cu and
mxfp-common.cuh host functions. Consolidate compiler-macro sites
(#if CUDART_VERSION) from 9 down to 4.
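
The if-constexpr consolidation and the runtime dispatch can be illustrated with a reduced example (the enum, trait values, and names here are stand-ins for the real ggml types, and this is not the actual MXFP_DISPATCH macro; requires C++17):

```cpp
#include <type_traits>

enum class mxfp_kind { FP4, FP6, FP8 };

// One if-constexpr body replaces three explicit specializations.
template <mxfp_kind K>
constexpr int bits_per_elem() {
    if constexpr (K == mxfp_kind::FP4)      { return 4; }
    else if constexpr (K == mxfp_kind::FP6) { return 6; }
    else                                    { return 8; }
}

// A single switch maps the runtime type onto the compile-time template,
// in the spirit of an MXFP_DISPATCH macro: f receives the kind as an
// integral_constant so it can forward it as a template argument.
template <typename F>
auto mxfp_dispatch(mxfp_kind k, F && f) {
    switch (k) {
        case mxfp_kind::FP4: return f(std::integral_constant<mxfp_kind, mxfp_kind::FP4>{});
        case mxfp_kind::FP6: return f(std::integral_constant<mxfp_kind, mxfp_kind::FP6>{});
        default:             return f(std::integral_constant<mxfp_kind, mxfp_kind::FP8>{});
    }
}
```

Usage looks like `mxfp_dispatch(kind, [](auto k) { return bits_per_elem<decltype(k)::value>(); })`, which keeps the 3-way switch in one place instead of at every host-function call site.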
@timothyeburke timothyeburke force-pushed the mxfp-flash-attention-cuda branch from c0c76b3 to 69a6525 Compare March 23, 2026 02:30
timothyeburke and others added 2 commits March 22, 2026 22:48
… lines)

Replace 9 standalone mxfp{4,6,8}_{dequant,quantize,dequant_x2}_intrinsic
wrappers with direct intrinsic calls at their single call sites. Extract
shared halfraw→float conversion into two small helpers.
Three changes for MXFP flash attention on CUDA:

1. Remove Q quantization roundtrip — Q is now Hadamard-rotated only,
   not quantized/dequantized. Q is computed fresh every token and never
   stored, so injecting quantization noise was unnecessary and hurts PPL.
   Rotate-only beats f16 baseline on some models. Applied pre-kernel for
   both MMA and VEC paths (VEC no longer does internal Q rotation).

2. Add GGML_MXFP_NO_HADAMARD=1 runtime toggle — disables Hadamard
   rotation for all MXFP types via env var. Needed because Hadamard
   benefit is architecture-dependent (helps Qwen3, hurts gpt-oss SWA).

3. Add MXFP6_E3M2 and MXFP8_E5M2 element formats — the M2-mantissa
   MXFP subspecies with wider dynamic range but coarser mantissa.
   Full wiring: type enums, traits, CUDA quantize/dequant, SoA layout,
   set_rows, flash attention VEC dispatch, CLI args, CPU fallback.
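
The runtime toggle in item 2 is presumably read along these lines (a sketch; the real ggml check may cache the result or accept other spellings):

```cpp
#include <cstdlib>

// Sketch of a GGML_MXFP_NO_HADAMARD=1 style env-var toggle: any value
// other than empty or "0" disables the Hadamard rotation.
static bool mxfp_hadamard_disabled() {
    const char * v = std::getenv("GGML_MXFP_NO_HADAMARD");
    return v != nullptr && v[0] != '\0' && v[0] != '0';
}
```

An env var rather than a build flag fits the architecture-dependent behavior described above: the same binary can run Qwen3 with rotation on and gpt-oss with it off.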

Also fix int32 overflow in llama-perplexity --save-all-logits for
contexts >= 16K with large vocab models.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>