Skip to content

perf: D=512 quantized-KV FA vec kernels (gated gqa_ratio<=4) + server logprobs partial-sort#102

Merged
marksverdhei merged 4 commits into
htfrom
perf/fattn-vec-512-logprobs
Jun 12, 2026
Merged

perf: D=512 quantized-KV FA vec kernels (gated gqa_ratio<=4) + server logprobs partial-sort#102
marksverdhei merged 4 commits into
htfrom
perf/fattn-vec-512-logprobs

Conversation

@marksverdhei

@marksverdhei marksverdhei commented Jun 11, 2026

Copy link
Copy Markdown

Two perf changes from a fresh review of the hot paths, now fully validated on both fleet architectures (RTX 3090 sm_86, Quadro P5200 sm_61). The deployment-shape validation materially changed the story for change 1 — see below. Change 2 (server logprobs) stands as originally measured.

1. CUDA: FA vec kernel instances for D == 512 with matched quantized KV — now gated to gqa_ratio <= 4

Gemma 4 global-attention layers (head size 512) always dispatch to the tile (pre-Volta) or MMA (Turing+) FA kernels — both require K/V in F16, so with a quantized KV cache the entire per-layer global KV is dequantized into an F16 staging buffer on every decode step. The vec kernel reads q4_0/q8_0 directly but had no D=512 instances (only 64/128/256). This PR adds matched q4_0/q8_0 D=512 vec instances.

What deployment-shape validation found: the original perf cases used gqa_ratio = 4. The real Gemma 4 global layers are MQA (head_count_kv = 1, gqa_ratio = 16). At that shape the vec kernel re-reads the single KV head once per Q head (16×), while TILE/MMA pay the dequant staging once and amortize K/V reads across Q heads via the GQA optimization (8 Q-heads per pass). The vec kernel loses there — on both architectures:

Per-op, deployment shape (MQA, nh_kv=1, ratio 16, nb=1, q8_0):

kv sm_86 MMA (baseline) sm_86 vec (ungated) sm_61 TILE (baseline) sm_61 vec (ungated)
4096 58.5 us 64.9 us (−11%) 231.0 us 322.8 us (−40%)
8192 92.4 us 125.4 us (−36%) 404.2 us 590.3 us (−46%)
16384 158.7 us 238.0 us (−50%) 765.5 us 1165.7 us (−52%)

(q4_0 is worse still: up to −159% on sm_61.) End-to-end this was a measurable ~4% tg regression at d32768 on the 3090, reproduced in both run orders.

Per-op, gqa_ratio = 4 (the shape where vec wins), q8_0:

kv sm_86 MMA sm_86 vec speedup sm_61 TILE sm_61 vec speedup
4096 152.4 us 74.6 us 2.04x 522.4 us 383.1 us 1.36x
8192 291.2 us 143.2 us 2.03x 1014.8 us 795.0 us 1.28x
16384 575.2 us 286.2 us 2.01x 1990.5 us 1438.0 us 1.38x

(q4_0 at ratio 4: sm_86 keeps a 1.65–1.79x win; sm_61 is a wash, 1.00–1.04x — the in-kernel q4_0 dequant costs roughly what the staging pass saved on Pascal.)

Resolution (3c8a368): the D=512 vec dispatch requires gqa_ratio <= 4. The deployment shape falls back to the baseline kernels — verified per-op parity (≤2%) and e2e parity on both arches. Net effect for the current fleet: neutral on FA (no regression, no win), with the kernels + dispatch + test coverage in place for low-ratio D=512 shapes and as the substrate for the real fix.

Follow-up that would capture the actual prize: a quantized-KV-direct kernel with GQA packing (vec + ncols2, or TILE reading q8_0/q4_0 natively) would cut Gemma 4 global-layer FA traffic to ~1/4 of either current path — the per-op headroom at MQA-16 is ~4x. That is a kernel-engineering task, not a dispatch tweak.

2. Server: avoid full-vocab sort in get_token_probabilities

When a client requests logprobs (n_probs > 0), the server did a full std::sort of the entire vocabulary (262k entries for Gemma) per emitted token, then the caller linearly re-scanned the sorted vector to find the sampled token. Softmax normalization only needs max+sum (O(V), no ordering); top-k comes from std::partial_sort. O(V log V + V) → O(V + k log k) per token, identical output.

Validation

  • test-backend-ops test -o FLASH_ATTN_EXT vs CPU reference: 2903/2903 pass on sm_86 (RTX 3090) and full sweep pass on sm_61 (Quadro P5200) — including new hs=512 quantized-KV cases at both gqa_ratio=4 and the deployment MQA-16 shape (q8_0/q4_0, nb={1,2,3,32}, sinks variant). Quantized-KV FA previously had zero hs=512 coverage (only 64/72).

  • sm_61 run exercises the cols_per_block=2 vec instantiation (pre-Volta dispatch, ne1<=2); sm_86 exercises cols_per_block=1.

  • E2e llama-bench A/B (gemma-4-12B QAT Q4_K_XL, -fa 1 -ctk q8_0 -ctv q8_0): gated build ≡ baseline within noise on sm_86; ungated build reproduced the predicted regression (3090: 58.6→56.3 tg @d32k). sm_61 gated per-op parity confirmed (228/397/752 us vs 231/404/766 baseline).

  • sm_61 3-way e2e (amethyst, idle P5200, ngl=99, r=2 — stddev ≤0.31 throughout):

    tg32 @ depth baseline ungated vec gated (this PR)
    d0 24.70 24.35 (−1.4%) 24.54
    d4096 22.39 21.77 (−2.8%) 22.24
    d16384 20.36 18.88 (−7.3%) 20.30
    d32768 18.11 16.02 (−11.5%) 18.12 (parity)

    The ungated dispatch would have cost the Pascal fleet 11.5% decode at 32k depth; the gate restores exact parity.

  • Deployment-shape perf cases added to make_test_cases_perf so this dispatch decision stays measurable.

  • generate_cu_files.py updated so regeneration reproduces the instance files byte-for-byte.

Numbers provenance

  • sm_86: local RTX 3090 (desktop-loaded; per-op numbers are thousands-of-runs averages, e2e cross-checked in both run orders).
  • sm_61: amethyst (idle Quadro P5200), binaries cross-built on crystal with the scripts/build-pascal-p5200.md recipe, static-linked so the deployed v3 libs can't shadow.

🤖 Generated with Claude Code

marksverdhei added 4 commits June 11, 2026 23:46
… KV types

Gemma 4 global attention layers (head size 512) previously always dispatched
to the tile (pre-Volta) or MMA (Turing+) kernels, both of which require K/V
dequantized to F16 -- with a quantized KV cache that staging pass re-reads
and re-writes the entire per-layer KV every decode step.

Add vec kernel instances for D == 512 with matched q4_0/q8_0 KV types (the
vec kernel reads quantized KV directly) and dispatch to them for the small
batch sizes the vec kernel already owns on each arch. Gated on matched
quantized types and logit_softcap == 0 (vec only compiles softcap variants
for D == 128/256).

test-backend-ops previously had no quantized-KV FA coverage above head size
72; add Gemma4-shaped hs=512 cases (q8_0/q4_0, GQA, nb 1/2/3/32, sinks).
All 2899 FLASH_ATTN_EXT cases pass on CUDA (sm_86) vs CPU reference.
get_token_probabilities() sorted the entire vocabulary (262k entries for
Gemma) by logit on every emitted token when n_probs > 0, and the caller
then linearly scanned the sorted vector again to find the sampled token's
probability.

The softmax normalization only needs the max and the sum of the logits --
both O(n) without sorting. Select the top n_probs tokens with a partial
sort and return the sampled token's probability directly from the same
pass: O(V log V + V) per token becomes O(V + k log k).

No output change: same top-k ordering, same normalization over the full
candidate set.
Gemma 4 global-attention decode shapes (D=512, GQA=4, nb=1, q8_0/q4_0 KV)
for test-backend-ops perf mode. RTX 3090 (sm_86), MMA+dequant -> vec:

  kv=4096  q8_0: 155.0 -> 76.5 us/run (2.03x)   q4_0: 150.4 -> 90.9 us/run (1.65x)
  kv=8192  q8_0: 302.5 -> 145.3 us/run (2.08x)  q4_0: 277.8 -> 163.1 us/run (1.70x)
  kv=16384 q8_0: 558.5 -> 286.1 us/run (1.95x)  q4_0: 533.5 -> 298.3 us/run (1.79x)
The vec kernel re-reads K/V once per Q head; tile/MMA amortize K/V
reads across Q heads via the GQA optimization (at the cost of a
dequant-to-F16 staging pass for quantized KV). Measured crossover on
both sm_61 (TILE baseline) and sm_86 (MMA baseline): vec wins ~1.4-2.0x
per-op at gqa_ratio <= 4, but loses 1.1-2.5x at gqa_ratio == 16 -- the
Gemma 4 global-attention deployment shape (MQA, n_head_kv == 1).

Adds the deployment shape (nh=1, nr=16) to correctness and perf test
cases so the dispatch decision stays measurable.
@marksverdhei marksverdhei changed the title perf: FA vec kernels for D=512 quantized KV (~2x per-op) + server logprobs partial-sort perf: D=512 quantized-KV FA vec kernels (gated gqa_ratio<=4) + server logprobs partial-sort Jun 12, 2026
@marksverdhei marksverdhei merged commit 737e236 into ht Jun 12, 2026
1 of 12 checks passed
@marksverdhei marksverdhei deleted the perf/fattn-vec-512-logprobs branch June 12, 2026 18:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant