[Spec Decode] Fix Gemma4 DFlash batched verification #41703
Code Review
This pull request implements support for Sliding Window Attention (SWA) within the DFlash speculative decoding framework. Key changes include updating the Qwen3 DFlash model implementation to handle per-layer attention types, modifying the KV cache allocation logic to isolate DFlash draft layers from target layers to prevent overwriting, and updating the Triton input expansion kernel to correctly manage rejected tokens. Additionally, it introduces support for logit soft-capping and embedding normalization for specific model architectures. I have no feedback to provide as there were no review comments to assess.
|
This pull request has merge conflicts that must be resolved before it can be |
|
Why does attention-backend choose triton_attn while speculative-config uses flash_attn? As I understand it, vllm's default for attention backend is flash_attn, right? |
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Jian Chen <jianchen0311@gmail.com>
|
It's because DFlash uses non-causal attention |
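For readers unfamiliar with the term, the difference between causal and non-causal attention over a block of draft query tokens can be sketched as below. This is only an illustration of the masking difference, assuming every query also sees the full context; the function name `attention_mask` is hypothetical and nothing here reproduces vLLM's backend-selection logic.

```python
def attention_mask(num_query, num_context, causal):
    """Build a boolean mask of shape (num_query, num_context + num_query).

    Row i marks which keys query i may attend to. The first num_context
    columns are context keys; the remaining columns are the query block.
    """
    mask = []
    for i in range(num_query):
        row = [True] * num_context  # assumption: the whole context is visible
        for j in range(num_query):
            # Causal: query i sees query-block keys 0..i only.
            # Non-causal: query i sees every key in the query block.
            row.append(j <= i if causal else True)
        mask.append(row)
    return mask


causal = attention_mask(num_query=3, num_context=2, causal=True)
non_causal = attention_mask(num_query=3, num_context=2, causal=False)

for name, mask in (("causal", causal), ("non-causal", non_causal)):
    print(name)
    for row in mask:
        print("  " + " ".join("1" if allowed else "0" for allowed in row))
```

In the causal case the first query cannot see the later query tokens; in the non-causal case every query token sees the whole block.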
|
This pull request has merge conflicts that must be resolved before it can be |
|
Hi @jianc99, thank you for this PR. We were exploring options to speed up While testing higher-concurrency workloads, we ran into an issue that CC was able to fix locally. I’m not sufficiently familiar with the DFlash internals to confidently analyze whether this is truly a bug or whether we may have misunderstood something, so I’m just pasting CC’s analysis here as-is in case it is useful or points to something real.

Workload Setup

We were running concurrent requests against the vLLM server with:

    --max-model-len 80000
    --gpu-memory-utilization 0.92

using DFlash g6.

Error Encountered

    RuntimeError: block_table must have shape (batch_size, max_num_blocks_per_seq)

Root Cause Analysis

1. Block table stored while metadata is still padded

    # cm.block_table_tensor has shape (num_reqs_padded, max_blocks)
    # e.g. 17 real requests → padded to 24 (nearest CUDA graph capture size)
    dflash_drafter.set_draft_block_table(kv_cache_gid, cm.block_table_tensor)

At this point the block table is stored using the padded batch size.

2. Metadata later gets unpadded

    # Now shape is (num_reqs, max_blocks) = (17, max_blocks)
    spec_decode_common_attn_metadata = (
        spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs)
    )

Now the attention metadata uses the actual request count.

3. DFlash proposer returns the original padded block table

    block_table = self._draft_block_tables.get(kv_cache_gid)
    if block_table is not None:
        return block_table

This returns a tensor with shape:

    (24, max_blocks)

even though the active batch size is:

    17

The FA2 kernel validates:

    page_table.shape[0] == batch_size

which fails because:

    24 != 17

leading to the crash.

Why Benchmarks Didn’t Trigger This

From what we understood, your benchmark used:

    --max-num-batched-tokens 32768
    --max-concurrency 32

Since

    num_reqs_padded == num_reqs

So the mismatch never occurs. The issue seems to appear only with non-capture-size batch counts such as:

which naturally occur under variable concurrency and larger batching windows.

Suggested Fix

Possible one-line fix in:

    if block_table is not None:
        return block_table[:cad.num_reqs]

which slices the padded block table back to the actual batch size before returning it. |
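The padding mismatch described in the comment above can be reproduced without a GPU. The sketch below uses plain Python lists as a stand-in for the block table tensor. The capture-size list and the helper names (`padded_batch_size`, `make_block_table`, `get_draft_block_table`) are assumptions chosen only so that 17 requests pad to 24, as in the report; they are not vLLM APIs.

```python
from bisect import bisect_left

# Hypothetical capture sizes, picked so that 17 pads to 24 as in the report.
CAPTURE_SIZES = [1, 2, 4, 8, 16, 24, 32]


def padded_batch_size(num_reqs, capture_sizes=CAPTURE_SIZES):
    """Round num_reqs up to the next capture size, if one is large enough."""
    idx = bisect_left(capture_sizes, num_reqs)
    return capture_sizes[idx] if idx < len(capture_sizes) else num_reqs


def make_block_table(num_rows, max_blocks):
    """Stand-in for a (num_rows, max_blocks) block table tensor."""
    return [[0] * max_blocks for _ in range(num_rows)]


def get_draft_block_table(stored, num_reqs):
    """Trim the stored (padded) table to the active batch size."""
    return stored[:num_reqs]


num_reqs = 17
stored = make_block_table(padded_batch_size(num_reqs), max_blocks=4)
trimmed = get_draft_block_table(stored, num_reqs)

# The untrimmed table has more rows than the active batch, which is the
# condition the reported shape check rejects; the trimmed table matches.
print(len(stored) == num_reqs, len(trimmed) == num_reqs)
```

With 32 requests the padded and real sizes coincide under these assumed capture sizes, which is consistent with the benchmark configuration not hitting the error.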
Purpose
Fix the remaining Gemma4-specific DFlash issues on top of the generic DFlash SWA/shared-KV work in #40898.
This PR is stacked in git history on #40898. The Gemma4-only review delta against the SWA branch is intentionally small: 4 files, 57 insertions, 15 deletions.
Review delta: jianc99/vllm@dflash-swa-support...dflash-gemma4-fix
Changes
Gemma4-compatible DFlash draft embeddings and logits
DFlash shares target embeddings. For Gemma4 targets, the draft path now applies the target embedding normalization (sqrt(hidden_size)) and passes final_logit_softcapping into LogitsProcessor.

Triton metadata uses concrete KV cache geometry
Triton decode metadata now sizes KV heads and head dimension from the actual kv_cache_spec, which avoids assuming all attention groups share the model-wide KV geometry.

Rejected-token handling for DFlash batch verification
copy_and_expand_dflash_inputs_kernel now masks rejected context slots, avoids writing invalid context slots into draft KV cache, and computes query positions from the last valid accepted context token.

Text-only DFlash draft attention with Gemma4 target
DFlash draft attention runs over text/query tokens with prewritten K/V, so it should not inherit the Gemma4 target's multimodal-prefix backend restriction. The draft attention layer now opts out of use_mm_prefix, allowing the requested attention_backend=flash_attn drafter path.

DFlash vs Gemma4 MTP Comparison
Benchmarked on B200 with google/gemma-4-26B-A4B-it, num_speculative_tokens=15, max_concurrency=32, --max-num-batched-tokens 32768, and --num-warmups 32.

HumanEval
MT-Bench
DFlash is faster on both warmed workloads, despite Gemma4 MTP having slightly higher acceptance.
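For reference, the two Gemma4-specific numerical steps listed under Changes (scaling embeddings by sqrt(hidden_size) and soft-capping the final logits) can be sketched in plain Python. This is a minimal illustration, not the PR's implementation: the function names are hypothetical, the soft-cap uses the common cap * tanh(x / cap) form, and the cap of 30.0 and hidden size of 16 are illustrative values only.

```python
import math


def scale_embedding(embedding, hidden_size):
    """Multiply an embedding vector by sqrt(hidden_size)."""
    scale = math.sqrt(hidden_size)
    return [x * scale for x in embedding]


def softcap_logits(logits, cap):
    """Squash logits smoothly into (-cap, cap); a cap of None disables it."""
    if cap is None:
        return list(logits)
    return [cap * math.tanh(x / cap) for x in logits]


# Illustrative values, not taken from any model config.
scaled = scale_embedding([1.0, -2.0], hidden_size=16)
capped = softcap_logits([100.0, -100.0, 0.5], cap=30.0)

print(scaled)  # [4.0, -8.0]
print(all(abs(v) < 30.0 for v in capped))
```

Small logits pass through almost unchanged while large ones saturate near the cap, so a draft model that skips this step would produce logits on a different scale from the target.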
Latest exact HumanEval regression run
Command matched the PR repro command without benchmark warmups, using /home/zlab/workspace/jianc/repo/dflash/cache/humaneval.vllm.jsonl:

This matches the previous reference acceptance profile (44.59%, acceptance length 7.69) while keeping the shared target/draft raw KV tensor path from #40898.

DFlash target-layer offset check
The checkpoint-native dflash_config.target_layer_ids path was also checked against a temporary no-shift run. The shifted path matches HF DFlash semantics and gives the expected acceptance profile.

+1 (2, 7, 12, 18, 23, 28) (1, 6, 11, 17, 22, 27)

Test Plan
Unit tests:

    PATH=/home/zlab/miniconda3/envs/vllm-dflash/bin:$PATH python -m pytest \
      tests/v1/spec_decode/test_eagle.py \
      tests/v1/spec_decode/test_dflash_swa.py \
      tests/v1/core/test_kv_sharing.py -q

Syntax/whitespace hygiene:

    PATH=/home/zlab/miniconda3/envs/vllm-dflash/bin:$PATH python -m py_compile \
      vllm/model_executor/layers/attention/attention.py \
      vllm/model_executor/models/qwen3_dflash.py \
      vllm/v1/attention/backends/triton_attn.py \
      vllm/v1/spec_decode/utils.py
    git diff --check refs/remotes/jianc99/dflash-swa-support...HEAD
    git diff --check origin/main...HEAD

Manual HumanEval validation:
Test Result
8cb2db16072cebbb944564f84f21045a90151ad1; includes [Spec Decode] Add Sliding Window Attention support to DFlash drafter #40898 head 23002d3f368a5a24641301bc71e4ae15dae89a24.
54 passed, 6 skipped, 20 warnings for test_eagle.py, test_dflash_swa.py, and test_kv_sharing.py.
pre-commit run --files passed for the Gemma4 delta files, including mypy.
py_compile for the Gemma4 delta files.
git diff --check refs/remotes/jianc99/dflash-swa-support...HEAD and git diff --check origin/main...HEAD passed.