perf: D=512 quantized-KV FA vec kernels (gated gqa_ratio<=4) + server logprobs partial-sort#102
Merged
Merged
Conversation
added 4 commits
June 11, 2026 23:46
… KV types Gemma 4 global attention layers (head size 512) previously always dispatched to the tile (pre-Volta) or MMA (Turing+) kernels, both of which require K/V dequantized to F16 -- with a quantized KV cache that staging pass re-reads and re-writes the entire per-layer KV every decode step. Add vec kernel instances for D == 512 with matched q4_0/q8_0 KV types (the vec kernel reads quantized KV directly) and dispatch to them for the small batch sizes the vec kernel already owns on each arch. Gated on matched quantized types and logit_softcap == 0 (vec only compiles softcap variants for D == 128/256). test-backend-ops previously had no quantized-KV FA coverage above head size 72; add Gemma4-shaped hs=512 cases (q8_0/q4_0, GQA, nb 1/2/3/32, sinks). All 2899 FLASH_ATTN_EXT cases pass on CUDA (sm_86) vs CPU reference.
get_token_probabilities() sorted the entire vocabulary (262k entries for Gemma) by logit on every emitted token when n_probs > 0, and the caller then linearly scanned the sorted vector again to find the sampled token's probability. The softmax normalization only needs the max and the sum of the logits -- both O(n) without sorting. Select the top n_probs tokens with a partial sort and return the sampled token's probability directly from the same pass: O(V log V + V) per token becomes O(V + k log k). No output change: same top-k ordering, same normalization over the full candidate set.
Gemma 4 global-attention decode shapes (D=512, GQA=4, nb=1, q8_0/q4_0 KV) for test-backend-ops perf mode. RTX 3090 (sm_86), MMA+dequant -> vec: kv=4096 q8_0: 155.0 -> 76.5 us/run (2.03x) q4_0: 150.4 -> 90.9 us/run (1.65x) kv=8192 q8_0: 302.5 -> 145.3 us/run (2.08x) q4_0: 277.8 -> 163.1 us/run (1.70x) kv=16384 q8_0: 558.5 -> 286.1 us/run (1.95x) q4_0: 533.5 -> 298.3 us/run (1.79x)
The vec kernel re-reads K/V once per Q head; tile/MMA amortize K/V reads across Q heads via the GQA optimization (at the cost of a dequant-to-F16 staging pass for quantized KV). Measured crossover on both sm_61 (TILE baseline) and sm_86 (MMA baseline): vec wins ~1.4-2.0x per-op at gqa_ratio <= 4, but loses 1.1-2.5x at gqa_ratio == 16 -- the Gemma 4 global-attention deployment shape (MQA, n_head_kv == 1). Adds the deployment shape (nh=1, nr=16) to correctness and perf test cases so the dispatch decision stays measurable.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Two perf changes from a fresh review of the hot paths, now fully validated on both fleet architectures (RTX 3090 sm_86, Quadro P5200 sm_61). The deployment-shape validation materially changed the story for change 1 — see below. Change 2 (server logprobs) stands as originally measured.
1. CUDA: FA vec kernel instances for D == 512 with matched quantized KV — now gated to
gqa_ratio <= 4Gemma 4 global-attention layers (head size 512) always dispatch to the tile (pre-Volta) or MMA (Turing+) FA kernels — both require K/V in F16, so with a quantized KV cache the entire per-layer global KV is dequantized into an F16 staging buffer on every decode step. The vec kernel reads q4_0/q8_0 directly but had no D=512 instances (only 64/128/256). This PR adds matched q4_0/q8_0 D=512 vec instances.
What deployment-shape validation found: the original perf cases used
gqa_ratio = 4. The real Gemma 4 global layers are MQA (head_count_kv = 1,gqa_ratio = 16). At that shape the vec kernel re-reads the single KV head once per Q head (16×), while TILE/MMA pay the dequant staging once and amortize K/V reads across Q heads via the GQA optimization (8 Q-heads per pass). The vec kernel loses there — on both architectures:Per-op, deployment shape (MQA, nh_kv=1, ratio 16, nb=1, q8_0):
(q4_0 is worse still: up to −159% on sm_61.) End-to-end this was a measurable ~4% tg regression at d32768 on the 3090, reproduced in both run orders.
Per-op, gqa_ratio = 4 (the shape where vec wins), q8_0:
(q4_0 at ratio 4: sm_86 keeps a 1.65–1.79x win; sm_61 is a wash, 1.00–1.04x — the in-kernel q4_0 dequant costs roughly what the staging pass saved on Pascal.)
Resolution (3c8a368): the D=512 vec dispatch requires
gqa_ratio <= 4. The deployment shape falls back to the baseline kernels — verified per-op parity (≤2%) and e2e parity on both arches. Net effect for the current fleet: neutral on FA (no regression, no win), with the kernels + dispatch + test coverage in place for low-ratio D=512 shapes and as the substrate for the real fix.Follow-up that would capture the actual prize: a quantized-KV-direct kernel with GQA packing (vec + ncols2, or TILE reading q8_0/q4_0 natively) would cut Gemma 4 global-layer FA traffic to ~1/4 of either current path — the per-op headroom at MQA-16 is ~4x. That is a kernel-engineering task, not a dispatch tweak.
2. Server: avoid full-vocab sort in get_token_probabilities
When a client requests logprobs (
n_probs > 0), the server did a fullstd::sortof the entire vocabulary (262k entries for Gemma) per emitted token, then the caller linearly re-scanned the sorted vector to find the sampled token. Softmax normalization only needs max+sum (O(V), no ordering); top-k comes fromstd::partial_sort. O(V log V + V) → O(V + k log k) per token, identical output.Validation
test-backend-ops test -o FLASH_ATTN_EXTvs CPU reference: 2903/2903 pass on sm_86 (RTX 3090) and full sweep pass on sm_61 (Quadro P5200) — including new hs=512 quantized-KV cases at both gqa_ratio=4 and the deployment MQA-16 shape (q8_0/q4_0, nb={1,2,3,32}, sinks variant). Quantized-KV FA previously had zero hs=512 coverage (only 64/72).sm_61 run exercises the
cols_per_block=2vec instantiation (pre-Volta dispatch, ne1<=2); sm_86 exercisescols_per_block=1.E2e llama-bench A/B (gemma-4-12B QAT Q4_K_XL,
-fa 1 -ctk q8_0 -ctv q8_0): gated build ≡ baseline within noise on sm_86; ungated build reproduced the predicted regression (3090: 58.6→56.3 tg @d32k). sm_61 gated per-op parity confirmed (228/397/752 us vs 231/404/766 baseline).sm_61 3-way e2e (amethyst, idle P5200, ngl=99, r=2 — stddev ≤0.31 throughout):
The ungated dispatch would have cost the Pascal fleet 11.5% decode at 32k depth; the gate restores exact parity.
Deployment-shape perf cases added to
make_test_cases_perfso this dispatch decision stays measurable.generate_cu_files.pyupdated so regeneration reproduces the instance files byte-for-byte.Numbers provenance
🤖 Generated with Claude Code