Skip to content

[Spec Decode] Fix Gemma4 DFlash batched verification#41703

Open
jianc99 wants to merge 12 commits into
vllm-project:mainfrom
jianc99:dflash-gemma4-fix
Open

[Spec Decode] Fix Gemma4 DFlash batched verification#41703
jianc99 wants to merge 12 commits into
vllm-project:mainfrom
jianc99:dflash-gemma4-fix

Conversation

@jianc99

@jianc99 jianc99 commented May 5, 2026

Copy link
Copy Markdown

Purpose

Fix the remaining Gemma4-specific DFlash issues on top of the generic DFlash SWA/shared-KV work in #40898.

This PR is stacked in git history on #40898. The Gemma4-only review delta against the SWA branch is intentionally small: 4 files, 57 insertions, 15 deletions.

Review delta: jianc99/vllm@dflash-swa-support...dflash-gemma4-fix

Changes

  1. Gemma4-compatible DFlash draft embeddings and logits

    DFlash shares target embeddings. For Gemma4 targets, the draft path now applies the target embedding normalization (sqrt(hidden_size)) and passes final_logit_softcapping into LogitsProcessor.

  2. Triton metadata uses concrete KV cache geometry

    Triton decode metadata now sizes KV heads and head dimension from the actual kv_cache_spec, which avoids assuming all attention groups share the model-wide KV geometry.

  3. Rejected-token handling for DFlash batch verification

    copy_and_expand_dflash_inputs_kernel now masks rejected context slots, avoids writing invalid context slots into draft KV cache, and computes query positions from the last valid accepted context token.

  4. Text-only DFlash draft attention with Gemma4 target

    DFlash draft attention runs over text/query tokens with prewritten K/V, so it should not inherit the Gemma4 target's multimodal-prefix backend restriction. The draft attention layer now opts out of use_mm_prefix, allowing the requested attention_backend=flash_attn drafter path.

DFlash vs Gemma4 MTP Comparison

Benchmarked on B200 with google/gemma-4-26B-A4B-it, num_speculative_tokens=15, max_concurrency=32, --max-num-batched-tokens 32768, and --num-warmups 32.

HumanEval

Method Output tok/s Mean TPOT Median TPOT Mean TTFT Median TTFT Acceptance length Acceptance rate
DFlash 6820.79 3.34 ms 3.17 ms 108.35 ms 79.33 ms 7.73 44.88%
Gemma4 MTP 6372.44 4.31 ms 4.28 ms 130.93 ms 110.95 ms 7.95 46.35%

MT-Bench

Method Output tok/s Mean TPOT Median TPOT Mean TTFT Median TTFT Acceptance length Acceptance rate
DFlash 5250.12 5.15 ms 5.14 ms 94.88 ms 76.60 ms 4.25 21.68%
Gemma4 MTP 4216.70 6.52 ms 6.35 ms 142.34 ms 105.05 ms 4.83 25.56%

DFlash is faster on both warmed workloads, despite Gemma4 MTP having slightly higher acceptance.

Latest exact HumanEval regression run

Command matched the PR repro command without benchmark warmups, using /home/zlab/workspace/jianc/repo/dflash/cache/humaneval.vllm.jsonl:

Method Output tok/s Mean TPOT Median TPOT Mean TTFT Median TTFT Acceptance length Acceptance rate Failed
DFlash 6122.98 2.97 ms 2.80 ms 2760.48 ms 69.53 ms 7.70 44.69% 0

This matches the previous reference acceptance profile (44.59%, acceptance length 7.69) while keeping the shared target/draft raw KV tensor path from #40898.

DFlash target-layer offset check

The checkpoint-native dflash_config.target_layer_ids path was also checked against a temporary no-shift run. The shifted path matches HF DFlash semantics and gives the expected acceptance profile.

Aux layer semantics Aux layers used Output tok/s Acceptance length Acceptance rate Failed
Shifted +1 (2, 7, 12, 18, 23, 28) 6122.98 7.70 44.69% 0
No shift (1, 6, 11, 17, 22, 27) 5270.36 6.60 37.30% 0

Test Plan

Unit tests:

PATH=/home/zlab/miniconda3/envs/vllm-dflash/bin:$PATH python -m pytest \
  tests/v1/spec_decode/test_eagle.py \
  tests/v1/spec_decode/test_dflash_swa.py \
  tests/v1/core/test_kv_sharing.py -q

Syntax/whitespace hygiene:

PATH=/home/zlab/miniconda3/envs/vllm-dflash/bin:$PATH python -m py_compile \
  vllm/model_executor/layers/attention/attention.py \
  vllm/model_executor/models/qwen3_dflash.py \
  vllm/v1/attention/backends/triton_attn.py \
  vllm/v1/spec_decode/utils.py

git diff --check refs/remotes/jianc99/dflash-swa-support...HEAD
git diff --check origin/main...HEAD

Manual HumanEval validation:

vllm serve google/gemma-4-26B-A4B-it \
  --speculative-config '{"method": "dflash", "model": "z-lab/gemma-4-26B-A4B-it-DFlash", "num_speculative_tokens": 15, "attention_backend": "flash_attn"}' \
  --attention-backend triton_attn \
  --max-num-batched-tokens 32768

vllm bench serve \
  --backend openai-chat \
  --base-url http://127.0.0.1:8000 \
  --endpoint /v1/chat/completions \
  --dataset-name custom \
  --dataset-path /home/zlab/workspace/jianc/repo/dflash/cache/humaneval.vllm.jsonl \
  --custom-output-len 4096 \
  --num-prompts 164 \
  --max-concurrency 32 \
  --model google/gemma-4-26B-A4B-it \
  --temperature 0.0 \
  --skip-chat-template \
  --extra-body '{"chat_template_kwargs":{"enable_thinking":true}}'

Test Result

  • Pushed head: 8cb2db16072cebbb944564f84f21045a90151ad1; includes [Spec Decode] Add Sliding Window Attention support to DFlash drafter #40898 head 23002d3f368a5a24641301bc71e4ae15dae89a24.
  • Re-stacked branch on [Spec Decode] Add Sliding Window Attention support to DFlash drafter #40898 head: Gemma4-only delta is 4 files, 57 insertions, 15 deletions.
  • Focused checks: 54 passed, 6 skipped, 20 warnings for test_eagle.py, test_dflash_swa.py, and test_kv_sharing.py.
  • pre-commit run --files passed for the Gemma4 delta files, including mypy.
  • Syntax checks passed with py_compile for the Gemma4 delta files.
  • git diff --check refs/remotes/jianc99/dflash-swa-support...HEAD and git diff --check origin/main...HEAD passed.
  • Manual HumanEval serving benchmark completed with 0 failed requests; see the exact regression table above.
  • Earlier warmed HumanEval and MT-Bench serving benchmarks both completed with 0 failed requests; see comparison tables above.

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@github-actions

github-actions Bot commented May 5, 2026

Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

Agent Guidelines

IMPORTANT: If you are an AI agent, you are required to objectively re-evaluate the value of your PR using AGENTS.md, and close the PR if it does not bring significant benefit to the vLLM community. Failure to do so may result in an immediate ban.

🚀

@mergify mergify Bot added qwen Related to Qwen models speculative-decoding v1 labels May 5, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements support for Sliding Window Attention (SWA) within the DFlash speculative decoding framework. Key changes include updating the Qwen3 DFlash model implementation to handle per-layer attention types, modifying the KV cache allocation logic to isolate DFlash draft layers from target layers to prevent overwriting, and updating the Triton input expansion kernel to correctly manage rejected tokens. Additionally, it introduces support for logit soft-capping and embedding normalization for specific model architectures. I have no feedback to provide as there were no review comments to assess.

@jianc99 jianc99 force-pushed the dflash-gemma4-fix branch 2 times, most recently from 9949831 to 60e9025 Compare May 5, 2026 07:15
@mergify

mergify Bot commented May 5, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jianc99.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@hnt2601

hnt2601 commented May 9, 2026

Copy link
Copy Markdown
Contributor

Why does attention-backend choose triton_attn while speculative-config uses flash_attn? As I understand it, vllm's default for attention backend is flash_attn, right?

benchislett and others added 12 commits May 10, 2026 09:59
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
Signed-off-by: Jian Chen <jianchen0311@gmail.com>
@jianc99 jianc99 force-pushed the dflash-gemma4-fix branch from 1a85980 to 8cb2db1 Compare May 10, 2026 10:06
@benchislett

Copy link
Copy Markdown
Member

It's because DFlash uses non-causal attention

@mergify

mergify Bot commented May 23, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jianc99.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@Kush0610

Copy link
Copy Markdown

Hi @jianc99, thank you for this PR. We were exploring options to speed up Gemma4-26B-A4B inference, and your work helped us a lot in getting it running with DFlash.

While testing higher-concurrency workloads, we ran into an issue that CC was able to fix locally. I’m not sufficiently familiar with the DFlash internals to confidently analyze whether this is truly a bug or whether we may have misunderstood something, so I’m just pasting CC’s analysis here as-is in case it is useful or points to something real.


Workload Setup

We were running concurrent requests against the vLLM server with:

--max-model-len 80000
--gpu-memory-utilization 0.92

using DFlash g6.


Error Encountered

RuntimeError: block_table must have shape (batch_size, max_num_blocks_per_seq)

Root Cause Analysis

1. Block table stored while metadata is still padded

gpu_model_runner.py — line 2345

# cm.block_table_tensor has shape (num_reqs_padded, max_blocks)
# e.g. 17 real requests → padded to 24 (nearest CUDA graph capture size)
dflash_drafter.set_draft_block_table(kv_cache_gid, cm.block_table_tensor)

At this point the block table is stored using the padded batch size.


2. Metadata later gets unpadded

gpu_model_runner.py — line 2406

# Now shape is (num_reqs, max_blocks) = (17, max_blocks)
spec_decode_common_attn_metadata = (
    spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs)
)

Now the attention metadata uses the actual request count (17).


3. DFlash proposer returns the original padded block table

dflash.py — line 164

block_table = self._draft_block_tables.get(kv_cache_gid)
if block_table is not None:
    return block_table

This returns a tensor with shape:

(24, max_blocks)

even though the active batch size is:

17

The FA2 kernel validates:

page_table.shape[0] == batch_size

which fails because:

24 != 17

leading to the crash.


Why Benchmarks Didn’t Trigger This

From what we understood, your benchmark used:

--max-num-batched-tokens 32768
--max-concurrency 32

Since 32 is already an exact CUDA graph capture size (from the default capture list like [1, 2, 4, 8, 16, 24, 32, ...]), we have:

num_reqs_padded == num_reqs

So the mismatch never occurs.

The issue seems to appear only with non-capture-size batch counts such as:

  • 17
  • 23
  • 37
  • etc.

which naturally occur under variable concurrency and larger batching windows.


Suggested Fix

Possible one-line fix in:

dflash.py::_get_dflash_block_table() — line 164

if block_table is not None:
    return block_table[:cad.num_reqs]

which slices the padded block table back to the actual batch size before returning it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants