Skip to content

[DCP] Support Decode Context Parallel (DCP) for GQA with Flashinfer#25438

Merged
LucasWilkinson merged 18 commits intovllm-project:mainfrom
gjc0824:dcp-gqa-flashinfer
Nov 14, 2025
Merged

[DCP] Support Decode Context Parallel (DCP) for GQA with Flashinfer#25438
LucasWilkinson merged 18 commits intovllm-project:mainfrom
gjc0824:dcp-gqa-flashinfer

Conversation

@gjc0824
Copy link
Copy Markdown
Contributor

@gjc0824 gjc0824 commented Sep 23, 2025

Purpose

This PR adds Decode Context Parallel (DCP) support for GQA follwing PR #23734 and PR #24864. Current implementation based on FlashInfer Attention.

FlashInfer inserts the current query KV into the cache before computation. Each query then attends to both its own KV and the context KV on the local device, with LSE applied to correct the attention outputs.

  • In the prefill/partial-prefill stage, custom mask is added to support interleaved KV cache with FlashInfer.
q_lens = 8, total_lens = 25 , group_size = 4, local_rank = 0

stored kv cache
rank0: 0 4 8 12 16 20 24
rank1: 1 5 9 13 17 21
rank2: 2 6 10 14 18 22
rank3: 3 7 11 15 19 23

rank0 custom mask
q\kv    0      4      8     12    16      20      24
17   True,  True,  True,  True,  True,  False, False
18   True,  True,  True,  True,  True,  False, False
19   True,  True,  True,  True,  True,  False, False
20   True,  True,  True,  True,  True,  True,  False
21   True,  True,  True,  True,  True,  True,  False
22   True,  True,  True,  True,  True,  True,  False
23   True,  True,  True,  True,  True,  True,  False
24   True,  True,  True,  True,  True,  True,  True
  • In the decode stage, this PR follows the DCP decode approach from MLA, i.e., all-gathering Q and lse, then correcting the attn out before performing reduce-scatter.

Test Plan

Qwen/Qwen3-235B-A22B

export VLLM_ATTENTION_BACKEND='FLASHINFER'
vllm serve Qwen/Qwen3-235B-A22B --gpu-memory-utilization 0.9 --tensor-parallel-size 8 --decode-context-parallel-size 2

Test Result

  • gsm8k eval
dcp=1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8578|±  |0.0068|
|     |       |strict-match    |     5|exact_match|↑  |0.8415|±  |0.0071|

dcp=2
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8613|±  |0.0067|
|     |       |strict-match    |     5|exact_match|↑  |0.8469|±  |0.0070|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces Decode Context Parallel (DCP) support for Grouped-Query Attention (GQA) with the FlashInfer backend, which is a valuable enhancement for distributed inference performance. The changes are comprehensive, covering configuration validation, modifications to the attention backend to support DCP-specific logic like query head gathering and LSE-based output correction, and the implementation of a custom attention mask for prefills. The addition of tests for a GQA model using the new functionality is also a great inclusion. The overall implementation is well-executed. I have a couple of suggestions to enhance code quality by addressing a dynamically assigned attribute and removing duplicated code.

@github-actions
Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors.

You ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@gjc0824 gjc0824 force-pushed the dcp-gqa-flashinfer branch 9 times, most recently from 540c862 to b9e9b41 Compare September 24, 2025 03:44
continue
K = ((rightmost - r) // p) + 1
j = torch.arange(K)
t = torch.arange(Q)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we generally avoid single character variable names; theyre ok though if there is supporting comment, can you please add comments explaining what the mask looks like and how it is constructed?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your review. We have added the comment about mask examples and algorithm explanation after vectorized improvements.

torch.int64).tolist()
r = self.dcp_rank
p = self.dcp_world_size
for i in range(num_prefills):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: is there a way we can vectorize this loop or replace it with a triton kernel? ideally we avoid python loops as they can be very slow and create GPU bubbles

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your valuable review. We have vectorized the "num_prefills" loop to avoid GPU bubbles. Looking forward to your further review.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if self.dcp_world_size > 1:
    # init custom mask for interleave kv cache
    # |-------total_lens----------|
    # |--context_lens--|--q_lens--|
    # Example: dcp_size=2, dcp_rank=0
    # For a SINGLE prefill seq, q_lens=3, total_lens=5
    # k_lens on RANK1 is (5 - 1 - 0) // 2 + 1 = 3
    # mask.shape = [q_lens, k_lens] = [3,3]
    # mask [[True, True, False],
    #       [True, True, False],
    #       [True, True, True]]
    dcp_rank = self.dcp_rank
    dcp_size = self.dcp_world_size

    q_lens = (qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]).to(
            dtype=torch.int64, device=self.device)
    total_lens = seq_lens_cpu[prefill_start:prefill_start +
                num_prefills].to(dtype=torch.int64,
                device=self.device)
    context_lens = total_lens - q_lens
    # max indices for global sequences
    max_indices = total_lens - 1
    # if max_indices are smaller than dcp_rank,
    # current rank has no kv cache, is invalid,
    # the mask is skipped
    valid = (max_indices >= dcp_rank)
    assert torch.any(valid), "There is no valid sequence"

    # local kv lens on current dcp_rank
    k_lens = torch.div(max_indices - dcp_rank, 
                        dcp_size, 
                        rounding_mode="floor") + 1
    k_lens = torch.where(
        valid,
        k_lens,
        torch.zeros_like(k_lens))
    # vectorize operation
    # obtain the max length of all prefill reqs
    max_q = int(q_lens[valid].max().item())
    max_k = int(k_lens[valid].max().item())
    # generate local q and k indices
    q_indices = torch.arange(max_q, device=self.device)
    k_indices = torch.arange(max_k, device=self.device)
    # valid q and k indices of each reqs
    valid_q = valid[:, None] & \
        (q_indices[None, :] < q_lens[:, None])
    valid_k = valid[:, None] & \
        (k_indices[None, :] < k_lens[:, None])
    # where global q_indices >= global k_indices,
    # the mask is True
    # global q_indices = context_lens + local q_indices
    # global k_indices = local k_indcies * dcp_size + dcp_rank
    # ====> local k_indcies must be smaller or equal k_upper
    # k_upper=(context_lens + local q_indices - dcp_rank) // dcp_size
    k_upper = torch.div(
        context_lens[:, None] + q_indices - dcp_rank,
        dcp_size, rounding_mode="floor")
    k_upper = torch.where(
            valid_q,
            torch.clamp(k_upper, min=-1),
            k_upper.new_full(k_upper.shape, -1))
    mask = (k_indices[None, None, :] <= k_upper[:, :, None]) \
            & (k_upper[:, :, None] >= 0)
    valid_positions = valid_q[:, :, None] & valid_k[:, None, :]
    # flashinfer backend needs flattened format
    custom_mask = torch.masked_select(mask, valid_positions)

@LucasWilkinson
Copy link
Copy Markdown
Collaborator

Apologies for the delayed review! left a couple nits; overall its looking pretty good though

@mergify
Copy link
Copy Markdown

mergify bot commented Oct 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @gjc0824.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 9, 2025
@gjc0824
Copy link
Copy Markdown
Contributor Author

gjc0824 commented Oct 14, 2025

Apologies for the delayed review! left a couple nits; overall its looking pretty good though

Hi @LucasWilkinson . Could you re-review this PR and give the final sign off ? Thanks!

@gjc0824 gjc0824 reopened this Oct 14, 2025
@github-project-automation github-project-automation bot moved this from Done to To Triage in gpt-oss Issues & Enhancements Oct 14, 2025
@mergify
Copy link
Copy Markdown

mergify bot commented Oct 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @gjc0824.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Copy link
Copy Markdown
Collaborator

@LucasWilkinson LucasWilkinson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies for the delay! Overall looks pretty good so far but I think we should land #26696 first (seems more important and this can build on that), thoughts?


self.num_qo_heads = self.model_config.get_num_attention_heads(
self.vllm_config.parallel_config
try:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor Author

@gjc0824 gjc0824 Nov 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for delay. We were blocked by #26696 which will effect our local kv lengths. Now #26696 has been merged and we can continue to improve this work.
Compared to the previous commit, we have made significant improvements and employed the similar implementation as #24864 The main reason for making this improvement is that we found that inducing a custom mask greatly slows down the prefill_wrapper.run() operator (2ms -> 10ms when seq_len=32k). For avoid the custom mask, we divide the computation of prefill stage into context and new tokens.

  # |---------- context_len ----------|--- query_len---|
  # |------------ context-- ----------|-- newtokens ---|
  • For newtokens, query can compute with kv in causal=True mode without other communications in DCP group.
  • For context, the KV is distributed across different DCP ranks and causal mask is not required. We follows the #24864, i.e., all-gathering Q and lse, then correcting the attn out before performing reduce-scatter.
    This implementation obtains the memory space by splitting kvcache and not impair performance much.

block_table_tensor = common_attn_metadata.block_table_tensor

if self.dcp_world_size > 1:
seq_lens_np = seq_lens_np // self.dcp_world_size + (
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we land #26696 first and then update this to use the dcp_local_seq_lens computed in the model runner?

Copy link
Copy Markdown
Contributor Author

@gjc0824 gjc0824 Nov 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. We can obtain the local seq_lens by get_dcp_local_seq_lens at this.

if self.dcp_world_size > 1:
prefill_query = get_dcp_group().all_gather(
prefill_query.contiguous(), dim=1
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I guess this is fine but I guess the name "decode context parallel" is falling apart a bit here 😞

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In last implementation, this issue does exist. But in current version, additional DCP communication operations during the prefill phase only occur when context tokens are present. So I think this may be suitable.

],
"bigcode/gpt_bigcode-santacoder": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to keep the default backend for CI.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for delay. We were blocked by #26696 which will effect our local kv lengths. Now #26696 has been merged and we can continue to improve this work.
Sure. We should keep the default backend. Thanks!

class CPTestOptions(NamedTuple):
multi_node_only: bool
load_format: str | None = None
attn_backend: str = "FLASH_ATTN"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MLA can't use "FLASH_ATTN" backend, so the default value should not be set.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. We have improved it. Thanks for your comment.

gjc0824 and others added 7 commits November 9, 2025 21:47
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
fixed_split_size=self.prefill_fixed_split_size,
disable_split_kv=self.disable_split_kv,
)
kv_query_indptr_cpu = qo_indptr_cpu.clone()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is a clone needed? can't we just do kv_indptr=qo_indptr_cpu.to(self.device)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. It is not needed. We verified that removing it does not affect the model precision.

assert not isinstance(attn_metadata.prefill_wrapper, dict)
attn_metadata.prefill_wrapper.plan(
qo_indptr_cpu.to(self.device),
paged_kv_indptr_cpu.to(self.device),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why add .to(self.device); last i checked FlashInfer prefers CPU tensors otherwise we can get D2H copies in the plan: #21137

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was necessary in our last implementation with the custom mask, otherwise a device error would occur in the BatchPrefillWithPagedKVCacheWrapper. Now we can freely remove it.

BatchPrefillWithPagedKVCacheWrapper | BatchPrefillWithRaggedKVCacheWrapper,
]
| None
) = None
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this type signature is complicated and repeated alot 😞 ; maybe we could make our wrapper shim? like

class BatchDCPPrefillWrapper:
       self._new_tokens: BatchPrefillWithRaggedKVCacheWrapper
       self._context: BatchPrefillWithPagedKVCacheWrapper

       def plan(....):
             self._new_tokens.plan(...)
             self._context.plan(...)
      
       def run(...):
             new = self._new_tokens.run(...)
             context = self._context.run(...)
             return merge_attn_states(new, context)

then we can make this prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper

thoughts?

Copy link
Copy Markdown
Contributor Author

@gjc0824 gjc0824 Nov 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for you valuable comment. We added the new class BatchDCPPrefillWrapper for concise wrapper. The main improvement is here.

@@ -679,24 +756,61 @@ def build(
attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item())

if not attn_metadata.prefill_use_trtllm:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to force prefill_use_trtllm to False when DCP is enabled?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trllm cannot return lse so it is not supported in DCP. Now we directly disable it in vllm/utils/flashinfer.py.

@mergify
Copy link
Copy Markdown

mergify bot commented Nov 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @gjc0824.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

gjc0824 and others added 3 commits November 11, 2025 21:43
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Copy link
Copy Markdown
Collaborator

@LucasWilkinson LucasWilkinson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overall is looking much better thank you! what is the issue with block interleave_size > 1 ?

# Decode context parallel is not supported
if dcp_world_size > 1:
logger.warning_once(
"Trtllm not support lse, please use flash attention or FlashInfer backend."
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Trtllm does not support returning LSE and as a result does not support DCP, reverting to FlashInfer"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! We update the warning information for a clearer explanation.

@pisceskkk
Copy link
Copy Markdown
Contributor

pisceskkk commented Nov 12, 2025

what is the issue with block interleave_size > 1 ?

When handling contexts for chunked prefill, we split contexts into chunks based on the workspace. However, the recent refactoring of the reorg_kvcache function for adapting to interleave_size > 1overlooked this aspect, resulting in incorrect chunk sizes being extracted. I have submitted a new PR #28526 to address these issues.

Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Copy link
Copy Markdown
Collaborator

@LucasWilkinson LucasWilkinson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM; thanks for the cleanups!

(please resolve conflicts)

@mergify
Copy link
Copy Markdown

mergify bot commented Nov 13, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @gjc0824.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build nvidia qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

Status: Done
Status: Done
Status: Done
Status: Done

Development

Successfully merging this pull request may close these issues.

4 participants