ggml: CUDA MXFP flash attention — Blackwell MMA + VEC kernels #2
Open
timothyeburke wants to merge 10 commits into mxfp-flash-attention from
Conversation
Add CUDA kernels for MXFP Struct-of-Arrays KV cache write and read paths,
implementing the CPU scalar reference from ops.cpp exactly.
set_rows (KV cache write):
- New k_set_rows_mxfp_soa kernel: one warp per 32-element block
- Fused Hadamard + quantize matching quantize_row_mxfp*_soa_hadamard
- MXFP4 nibble packing, MXFP6 6-bit packing via shared ggml_mxfp_pack_fp6x4
- E8M0 via shared ggml_mxfp_e8m0_base_estimate (no MSE search)
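The 6-bit packing named above fits four FP6 codes into three bytes. A minimal host-side sketch of the idea, under the assumption of a little-endian 24-bit layout — the actual bit order used by ggml_mxfp_pack_fp6x4 in ggml-common.h may differ:

```cpp
#include <cstdint>

// Hedged sketch: pack four 6-bit FP6 codes into 3 bytes (4 x 6 = 24 bits)
// and unpack them again. Layout assumption: code i occupies bits [6i, 6i+6).
static inline void fp6x4_pack(const uint8_t v[4], uint8_t out[3]) {
    uint32_t bits = (uint32_t)(v[0] & 0x3F)
                  | (uint32_t)(v[1] & 0x3F) << 6
                  | (uint32_t)(v[2] & 0x3F) << 12
                  | (uint32_t)(v[3] & 0x3F) << 18;
    out[0] =  bits        & 0xFF;
    out[1] = (bits >> 8)  & 0xFF;
    out[2] = (bits >> 16) & 0xFF;
}

static inline void fp6x4_unpack(const uint8_t in[3], uint8_t v[4]) {
    uint32_t bits = (uint32_t)in[0] | (uint32_t)in[1] << 8 | (uint32_t)in[2] << 16;
    for (int i = 0; i < 4; ++i) v[i] = (bits >> (6 * i)) & 0x3F;
}
```

Whatever the concrete bit order, the invariant is the same: pack followed by unpack is the identity on 6-bit codes.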
flash attention (KV cache read):
- VEC kernel: inline MXFP K dequant+dot and V dequant with multihead
SoA support matching mxfp_row_ptr + mxfp_dequant_head exactly
- MMA kernel: SoA-to-F16 pre-conversion with multihead head extraction
- Q Hadamard+roundtrip in VEC kernel (warp-cooperative) and MMA
pre-kernel, using shared ggml_mxfp_hadamard_32_inplace butterfly
- Multihead detection: nb[2] == ggml_row_size(type, D), matching
mxfp_kv_params_init
All per-element math calls shared GGML_MXFP_FUNC constructs from
ggml-common.h. Only warp shuffle operations (Hadamard, amax) are
CUDA-specific in mxfp-common.cuh.
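The warp-cooperative Hadamard mentioned above is the 32-point fast Walsh–Hadamard butterfly. A portable sketch of the same math (the name mirrors ggml_mxfp_hadamard_32_inplace but is illustrative, not the ggml API):

```cpp
// Hedged sketch: unnormalized 32-point Hadamard transform, in place.
// Stage s pairs indices i and i^s; each pair becomes (a+b, a-b).
static void hadamard_32_inplace(float x[32]) {
    for (int s = 1; s < 32; s <<= 1) {
        for (int i = 0; i < 32; ++i) {
            if ((i & s) == 0) {          // visit each pair once
                float a = x[i], b = x[i ^ s];
                x[i]     = a + b;
                x[i ^ s] = a - b;
            }
        }
    }
}
```

In the CUDA version each of the 32 warp lanes holds one element and the partner exchange is a `__shfl_xor_sync` per stage, which is why only this part is CUDA-specific. Since H·H = 32·I for the unnormalized transform, applying it twice recovers the input scaled by 32.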
Passes 51 SET_ROWS + 1836 FLASH_ATTN_EXT tests on RTX 5070 Ti (sm_120).
- Hoist E8M0 scale decode: elem pairs always share a block (elem is even),
  so compute the scale once per pair instead of twice
- Hardware intrinsics (CUDA 12.8+): replace LUT lookups with single-instruction
  CVT for FP8 E4M3, FP4 E2M1, FP6 E2M3 dequant; add cuda_fp6.h include;
  portable fallback for older toolkits
- Inline V multihead dequant: read directly from the multihead SoA layout
  using computed offsets, eliminating the per-KV-position buffer copy
- Dockerfile: BuildKit cache mount for incremental CUDA builds
Benchmark (gpt-oss-20b, 2x RTX 5070 Ti, tg128):
mxfp8+mxfp8: 192 -> 219 t/s (+14.2%, now exceeds f16)
mxfp6+mxfp6: 194 -> 204 t/s (+5.0%)
mxfp8+mxfp4: 195 -> 206 t/s (+5.8%)
mxfp4+mxfp4: 205 -> 206 t/s (+0.6%)
PPL unchanged across all configs.
Extract reusable helpers into mxfp-common.cuh that exploit the invariant
that elem is always even in the VEC kernel dequant loops:
- mxfp4_extract_nibble_pair: branchless nibble extraction using
shift = (pos >= 16) ? 4 : 0 instead of branchy pos < 16 logic.
Both nibbles always in same half (low or high) since pos is even.
- mxfp6_unpack_pair: single ggml_mxfp_unpack_fp6x4 call for both
elements. pos0 even → pos0%4 ∈ {0,2} → grp0 == grp1 always,
so the group-boundary branch is provably dead code. Eliminated.
Applied to all K dot, V dequant, and inline V multihead paths.
Net -39 lines (83 added, 122 removed).
Benchmark (gpt-oss-20b, 2x RTX 5070 Ti, tg128):
mxfp8+mxfp4: 206 -> 219 t/s (exceeds f16 218.6)
mxfp4+mxfp4: 206 -> 214 t/s (-2.0% vs f16)
mxfp6+mxfp4: 197 -> 214 t/s (-2.1% vs f16)
mxfp6+mxfp6: 204 -> 214 t/s (-2.2% vs f16)
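The branchless extraction above relies on the even-`pos` invariant. A hedged host-side sketch, assuming the common ggml nibble layout (elements 0..15 in the low nibbles of bytes 0..15, elements 16..31 in the high nibbles) — the real mxfp4_extract_nibble_pair may differ in detail:

```cpp
#include <cstdint>

// Hedged sketch: extract the element pair (pos, pos+1), pos even, from a
// 32-element MXFP4 nibble block. Because pos is even, pos+1 can never cross
// the low/high-nibble boundary at 16, so one shift serves both elements.
static inline void mxfp4_extract_nibble_pair(const uint8_t qs[16], int pos,
                                             uint8_t *n0, uint8_t *n1) {
    const int shift = (pos >= 16) ? 4 : 0;   // same half for pos and pos+1
    *n0 = (qs[(pos + 0) & 15] >> shift) & 0x0F;
    *n1 = (qs[(pos + 1) & 15] >> shift) & 0x0F;
}
```

The same invariant kills the group-boundary branch in the FP6 path: with pos even, pos%4 is 0 or 2, so both elements always fall inside one fp6x4 group.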
Extract the repeated pattern of "unpack + scale + dequant a pair of
SoA elements" into a single template function in mxfp-common.cuh,
specialized for MXFP4, MXFP8, and MXFP6.
Each specialization encapsulates:
- Type-specific data extraction (nibble pair / byte pair / fp6 unpack)
- E8M0 scale decode (full or half depending on intrinsic availability)
- Hardware intrinsic vs portable LUT fallback (#if CUDART_VERSION)
- Returns float2{elem0, elem1}
This collapses 6 call sites (K dot, K inline, V half2, V float2,
vec_dot_mxfp4/8/6) from ~25 lines each to a single function call.
All 16 #if CUDART_VERSION guards in fattn-vec.cuh eliminated (now 0).
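The shape of that pair-dequant template can be sketched on the host as follows. The enum, traits, and decode details are illustrative stand-ins for the real mxfp-common.cuh code (which returns a CUDA `float2`; `std::pair` stands in here), and the MXFP8 branch uses a placeholder decode rather than real E4M3 conversion:

```cpp
#include <cstdint>
#include <utility>

enum mxfp_type { MXFP4, MXFP8 };   // illustrative subset

// Hedged sketch: one templated "unpack + scale + dequant a pair" helper
// replacing per-type copies at every call site. pos is assumed even.
template <mxfp_type T>
std::pair<float, float> dequant_pair(const uint8_t *qs, int pos, float scale) {
    if constexpr (T == MXFP4) {
        // E2M1 magnitudes 0,.5,1,1.5,2,3,4,6; bit 3 of the nibble is sign
        static const float lut[16] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
                                      -0.0f,-0.5f,-1.0f,-1.5f,-2.0f,-3.0f,-4.0f,-6.0f};
        const int shift = (pos >= 16) ? 4 : 0;      // pos even: same half
        return { scale * lut[(qs[(pos + 0) & 15] >> shift) & 0xF],
                 scale * lut[(qs[(pos + 1) & 15] >> shift) & 0xF] };
    } else {                                         // MXFP8: one byte/elem
        // placeholder byte->float decode; the real code converts E4M3
        return { scale * (float)(int8_t)qs[pos],
                 scale * (float)(int8_t)qs[pos + 1] };
    }
}
```

On device, the `#if CUDART_VERSION` choice between the hardware CVT intrinsic and this LUT fallback lives inside the specialization, which is what lets the 16 guards in fattn-vec.cuh collapse to zero.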
Add q8_0+q4_0 mixed-type FA to the default CUDA build (previously required
GGML_CUDA_FA_ALL_QUANTS). This is a common high-quality config: q8_0 keys
preserve attention pattern fidelity while q4_0 values save memory with
minimal quality loss.
Changes:
- fattn.cu: add q8_0+q4_0 to the default VEC dispatch and mixed-type guard
- CMakeLists.txt: compile fattn-vec-instance-q8_0-q4_0.cu by default
Before: pp=616, tg=125 (no FA, CPU fallback)
After: pp=8691, tg=217 (full FA, GPU-accelerated)
Pull request overview
Ports MXFP (MXFP4/MXFP6/MXFP8) flash attention to CUDA, adding SoA KV cache handling, vector-kernel support for decode, and MMA F16 pre-conversion plumbing for prefill.
Changes:
- Add MXFP SoA set-rows quantization kernel + shared CUDA helpers (Hadamard, warp amax, pack/unpack, intrinsic dequant).
- Extend FA VEC kernel and common utilities to support MXFP SoA (including mixed K/V MXFP8/MXFP6 with MXFP4 V) and mixed Q8_0 K with Q4_0 V selection.
- Update build plumbing (template instances, CUDA headers, Docker build caching) to compile new kernels.
Reviewed changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| ggml/src/ggml-cuda/vendors/cuda.h | Adds CUDA FP6 header for 12.8+ intrinsic support. |
| ggml/src/ggml-cuda/template-instances/fattn-vec-instance-*.cu | Adds explicit template instantiations for MXFP VEC FA variants. |
| ggml/src/ggml-cuda/set-rows.cu | Adds MXFP SoA set_rows CUDA kernel and dispatch. |
| ggml/src/ggml-cuda/mxfp-common.cuh | Introduces CUDA-side warp helpers + MXFP SoA dequant/conversion kernels. |
| ggml/src/ggml-cuda/ggml-cuda.cu | Declares set_rows support for MXFP dst types. |
| ggml/src/ggml-cuda/fattn.cu | Enables MXFP/mixed-type paths in kernel selection and dispatch cases. |
| ggml/src/ggml-cuda/fattn-vec.cuh | Implements MXFP-aware Q preprocessing, K/V SoA dequant, and mixed K/V logic. |
| ggml/src/ggml-cuda/fattn-common.cuh | Adds MXFP SoA dot/dequant helpers; adds MXFP SoA→F16 + Q roundtrip for MMA path. |
| ggml/src/ggml-cuda/CMakeLists.txt | Adds globs for new VEC template instance translation units. |
| .devops/cuda-new.Dockerfile | Improves build caching + enables MXFP variants in CI image build. |
…s (-66 lines)
Consolidate all MXFP type-specific logic into mxfp-common.cuh shared constructs:
- mxfp_type_traits<type>: compile-time emax_offset and qs_per_blk per type
- mxfp_compute_e8m0<type>(): shared E8M0 scale computation (was 3 copies)
- mxfp_quantize_roundtrip<type>(): shared quantize→dequant (was 3 copies)
- mxfp_hadamard_roundtrip<type>(): fused Q preprocessing (VEC: 32→1 line)
- mxfp_multihead_ptrs<type>(): shared multihead offset computation
- MXFP_ROW_BYTES_EXPLICIT(): multihead detection macro
- vec_dot_fattn_vec_KQ_mxfp<type>: unified K dot (was 3 functions)
- dequantize_V_mxfp_D<type>: unified V dequant (was 3 functions)
- Templatized k_mxfp_soa_to_f16 and k_mxfp_q_hadamard_roundtrip kernels
Hardware intrinsics on CUDA 12.8+:
- x2 paired dequant: __nv_cvt_fp8x2/fp4x2/fp6x2_to_halfraw2
- Quantize: __nv_cvt_float_to_fp8/fp6/fp4
- E8M0 decode: ggml_cuda_e8m0_to_fp32 (was portable ggml_mxfp_*)
- Replaced the log2f SFU call in compute_e8m0_scale with integer bit extraction
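The bit-extraction trick that replaces log2f works because an E8M0 scale is nothing but a bare IEEE-754 biased exponent. A hedged host-side sketch (names illustrative, not the ggml API; valid for normal values — E8M0 code 0, the subnormal 2^-127, and code 255 need separate handling):

```cpp
#include <cstdint>
#include <cstring>

// Hedged sketch: decode an E8M0 scale without exp2f. Placing the biased
// exponent e in the exponent field of a float yields 2^(e-127) directly.
static inline float e8m0_to_fp32(uint8_t e) {
    uint32_t bits = (uint32_t)e << 23;   // sign = 0, mantissa = 0
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// Inverse direction used when quantizing: floor(log2(x)) for normal x > 0
// by reading the exponent field instead of calling log2f.
static inline int floor_log2(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return (int)((bits >> 23) & 0xFF) - 127;
}
```

On device the same decode is available as a single intrinsic path (ggml_cuda_e8m0_to_fp32 per the commit above); this is the portable equivalent.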
…cale convention
Replace the CUDA-local mxfp4_quantize_elem (reimplemented decision boundaries)
with ggml_mxfp_float_to_fp4_e2m1 from ggml-common.h. Replace the kvalues_mxfp4
LUT dequant with ggml_mxfp_fp4_e2m1_to_float. All MXFP4 portable paths now use
the full E8M0 scale and shared reference functions — no more half-scale
convention in CUDA MXFP code.
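What the shared reference centralizes are the E2M1 decision boundaries. A hedged sketch of round-to-nearest-even FP4 E2M1 quantization by midpoint boundaries — illustrative, not the actual body of ggml_mxfp_float_to_fp4_e2m1:

```cpp
#include <cstdint>
#include <cmath>

// Hedged sketch: quantize a float to a 4-bit E2M1 code (bit 3 = sign).
// Representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6. Boundaries are
// midpoints between adjacent values; ties round to the even code.
static inline uint8_t float_to_fp4_e2m1(float x) {
    const uint8_t sign = (x < 0.0f) ? 8 : 0;
    const float a = std::fabs(x);
    uint8_t m;
    if      (a <= 0.25f) m = 0;   // tie at 0.25 -> even code 0
    else if (a <  0.75f) m = 1;
    else if (a <= 1.25f) m = 2;   // tie at 1.25 -> even code 2
    else if (a <  1.75f) m = 3;
    else if (a <= 2.5f)  m = 4;
    else if (a <  3.5f)  m = 5;
    else if (a <= 5.0f)  m = 6;
    else                 m = 7;   // saturate at 6.0
    return sign | m;
}
```

Getting these boundaries subtly wrong in a reimplementation is exactly the class of drift the commit eliminates by calling the one shared function everywhere.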
…es (-185 lines)
Replace 9 explicit MXFP4/8/6 template specializations with if-constexpr
single-body templates. Add an MXFP_DISPATCH macro for runtime type→template
dispatch, eliminating 3-way switch blocks in set-rows.cu and mxfp-common.cuh
host functions. Consolidate compiler-macro sites (#if CUDART_VERSION) from 9
down to 4.
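The runtime-to-compile-time dispatch pattern looks roughly like this. The enum values and the callee are illustrative; only the MXFP_DISPATCH shape matches the commit:

```cpp
// Hedged sketch: one switch, written once as a macro, turns a runtime type
// tag into a compile-time template argument at every call site.
enum mxfp_type { MXFP4_T, MXFP6_T, MXFP8_T };

#define MXFP_DISPATCH(type, fn, ...)                         \
    do {                                                     \
        switch (type) {                                      \
            case MXFP4_T: fn<MXFP4_T>(__VA_ARGS__); break;   \
            case MXFP6_T: fn<MXFP6_T>(__VA_ARGS__); break;   \
            case MXFP8_T: fn<MXFP8_T>(__VA_ARGS__); break;   \
        }                                                    \
    } while (0)

// Example callee: one if-constexpr body instead of 3 specializations.
template <mxfp_type T>
void record_bits(int *out) {
    if constexpr (T == MXFP4_T)      *out = 4;
    else if constexpr (T == MXFP6_T) *out = 6;
    else                             *out = 8;
}
```

Each host function that previously carried its own 3-way switch becomes a single `MXFP_DISPATCH(type, kernel_launcher, args...)` line.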
… lines)
Replace 9 standalone mxfp{4,6,8}_{dequant,quantize,dequant_x2}_intrinsic
wrappers with direct intrinsic calls at their single call sites. Extract
shared halfraw→float conversion into two small helpers.
Three changes for MXFP flash attention on CUDA:
1. Remove the Q quantization roundtrip — Q is now Hadamard-rotated only, not
   quantized/dequantized. Q is computed fresh every token and never stored,
   so injecting quantization noise was unnecessary and hurt PPL. Rotate-only
   beats the f16 baseline on some models. Applied pre-kernel for both MMA
   and VEC paths (VEC no longer does internal Q rotation).
2. Add a GGML_MXFP_NO_HADAMARD=1 runtime toggle — disables Hadamard rotation
   for all MXFP types via env var. Needed because the Hadamard benefit is
   architecture-dependent (it helps Qwen3 but hurts gpt-oss SWA).
3. Add MXFP6_E3M2 and MXFP8_E5M2 element formats — the M2-mantissa MXFP
   subspecies with wider dynamic range but a coarser mantissa. Full wiring:
   type enums, traits, CUDA quantize/dequant, SoA layout, set_rows, flash
   attention VEC dispatch, CLI args, CPU fallback.
Also fix an int32 overflow in llama-perplexity --save-all-logits for
contexts >= 16K with large-vocab models.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
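A runtime env-var toggle of this kind is usually read once and cached. A hedged sketch of that pattern for GGML_MXFP_NO_HADAMARD — the exact parsing ggml uses may differ (POSIX setenv/getenv assumed):

```cpp
#include <cstdlib>
#include <cstring>

// Hedged sketch: read GGML_MXFP_NO_HADAMARD once per process and cache it.
// Only the exact value "1" disables the rotation in this sketch.
static bool mxfp_hadamard_disabled(void) {
    static const bool disabled = [] {
        const char *v = std::getenv("GGML_MXFP_NO_HADAMARD");
        return v != nullptr && std::strcmp(v, "1") == 0;
    }();
    return disabled;
}
```

Caching in a function-local static keeps the per-token hot path free of getenv calls while still letting the user flip the behavior before launch.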
Port MXFP flash attention to CUDA using the shared infrastructure from the mxfp-flash-attention branch. Three types supported: MXFP4 E2M1, MXFP8 E4M3, MXFP6 E2M3.
KV Cache Benchmark — gpt-oss-20b
Model:
ggml-org/gpt-oss-20b/gpt-oss-20b-mxfp4.gguf
GPU: RTX 5070 Ti 16GB
Backend: CUDA — 2026-03-22T22:53:44-04:00
Chunks: 16
Memory: GiB per 100K tokens (K + V cache)