[DCP] Support Decode Context Parallel (DCP) for GQA with Flashinfer #25438
LucasWilkinson merged 18 commits into vllm-project:main from
Conversation
Code Review
This pull request introduces Decode Context Parallel (DCP) support for Grouped-Query Attention (GQA) with the FlashInfer backend, which is a valuable enhancement for distributed inference performance. The changes are comprehensive, covering configuration validation, modifications to the attention backend to support DCP-specific logic like query head gathering and LSE-based output correction, and the implementation of a custom attention mask for prefills. The addition of tests for a GQA model using the new functionality is also a great inclusion. The overall implementation is well-executed. I have a couple of suggestions to enhance code quality by addressing a dynamically assigned attribute and removing duplicated code.
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset runs automatically. You can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add … If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
540c862 to b9e9b41
continue
K = ((rightmost - r) // p) + 1
j = torch.arange(K)
t = torch.arange(Q)
nit: we generally avoid single-character variable names; they're OK though if there is a supporting comment. Can you please add comments explaining what the mask looks like and how it is constructed?
Thank you for your review. We have added comments with a mask example and an explanation of the algorithm along with the vectorization improvements.
torch.int64).tolist()
r = self.dcp_rank
p = self.dcp_world_size
for i in range(num_prefills):
nit: is there a way we can vectorize this loop or replace it with a Triton kernel? Ideally we avoid Python loops as they can be very slow and create GPU bubbles.
Thank you for your valuable review. We have vectorized the "num_prefills" loop to avoid GPU bubbles. Looking forward to your further review.
if self.dcp_world_size > 1:
    # build the custom mask for the interleaved kv cache
    # |-------total_lens----------|
    # |--context_lens--|--q_lens--|
    # Example: dcp_size=2, dcp_rank=0
    # For a SINGLE prefill seq with q_lens=3, total_lens=5,
    # k_lens on rank 0 is (5 - 1 - 0) // 2 + 1 = 3
    # mask.shape = [q_lens, k_lens] = [3, 3]
    # mask = [[True, True, False],
    #         [True, True, False],
    #         [True, True, True]]
    dcp_rank = self.dcp_rank
    dcp_size = self.dcp_world_size
    q_lens = (qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]).to(
        dtype=torch.int64, device=self.device)
    total_lens = seq_lens_cpu[prefill_start:prefill_start +
                              num_prefills].to(dtype=torch.int64,
                                               device=self.device)
    context_lens = total_lens - q_lens
    # max indices of the global sequences
    max_indices = total_lens - 1
    # if max_indices is smaller than dcp_rank,
    # the current rank holds no kv cache for that sequence,
    # so the sequence is invalid and its mask is skipped
    valid = (max_indices >= dcp_rank)
    assert torch.any(valid), "There is no valid sequence"
    # local kv lens on the current dcp_rank
    k_lens = torch.div(max_indices - dcp_rank,
                       dcp_size,
                       rounding_mode="floor") + 1
    k_lens = torch.where(
        valid,
        k_lens,
        torch.zeros_like(k_lens))
    # vectorized operation:
    # obtain the max q/k lengths over all prefill reqs
    max_q = int(q_lens[valid].max().item())
    max_k = int(k_lens[valid].max().item())
    # generate local q and k indices
    q_indices = torch.arange(max_q, device=self.device)
    k_indices = torch.arange(max_k, device=self.device)
    # valid q and k indices of each req
    valid_q = valid[:, None] & \
        (q_indices[None, :] < q_lens[:, None])
    valid_k = valid[:, None] & \
        (k_indices[None, :] < k_lens[:, None])
    # where global q_indices >= global k_indices,
    # the mask is True
    # global q_indices = context_lens + local q_indices
    # global k_indices = local k_indices * dcp_size + dcp_rank
    # ====> local k_indices must be smaller than or equal to k_upper
    # k_upper = (context_lens + local q_indices - dcp_rank) // dcp_size
    k_upper = torch.div(
        context_lens[:, None] + q_indices - dcp_rank,
        dcp_size, rounding_mode="floor")
    k_upper = torch.where(
        valid_q,
        torch.clamp(k_upper, min=-1),
        k_upper.new_full(k_upper.shape, -1))
    mask = (k_indices[None, None, :] <= k_upper[:, :, None]) \
        & (k_upper[:, :, None] >= 0)
    valid_positions = valid_q[:, :, None] & valid_k[:, None, :]
    # the flashinfer backend needs the mask in flattened format
    custom_mask = torch.masked_select(mask, valid_positions)
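For illustration, here is a small standalone sketch (not part of the PR; names chosen only for readability) that reproduces the mask from the example in the comment above using the same interleaved-KV indexing:

# Standalone sketch reproducing the example: dcp_size=2, dcp_rank=0,
# one prefill request with q_lens=3, total_lens=5.
import torch

dcp_size, dcp_rank = 2, 0
q_len, total_len = 3, 5
context_len = total_len - q_len                     # 2 context tokens
k_len = (total_len - 1 - dcp_rank) // dcp_size + 1  # 3 local kv entries on rank 0

q_idx = torch.arange(q_len)                         # local query positions
k_idx = torch.arange(k_len)                         # local kv slots on this rank
global_q = context_len + q_idx                      # [2, 3, 4]
global_k = k_idx * dcp_size + dcp_rank              # [0, 2, 4] (interleaved layout)
mask = global_q[:, None] >= global_k[None, :]       # causal w.r.t. global positions
print(mask)
# tensor([[ True,  True, False],
#         [ True,  True, False],
#         [ True,  True,  True]])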
Apologies for the delayed review! Left a couple of nits; overall it's looking pretty good though.
This pull request has merge conflicts that must be resolved before it can be merged.
Hi @LucasWilkinson. Could you re-review this PR and give the final sign-off? Thanks!
This pull request has merge conflicts that must be resolved before it can be merged.
Apologies for the delay! Overall looks pretty good so far but I think we should land #26696 first (seems more important and this can build on that), thoughts?
self.num_qo_heads = self.model_config.get_num_attention_heads(
    self.vllm_config.parallel_config
try:
Sorry for the delay. We were blocked by #26696, which affects our local KV lengths. Now that #26696 has been merged, we can continue to improve this work.
Compared to the previous commit, we have made significant improvements and adopted a similar implementation to #24864. The main reason for this change is that we found that introducing a custom mask greatly slows down the prefill_wrapper.run() operator (2 ms -> 10 ms at seq_len=32k). To avoid the custom mask, we split the prefill-stage computation into context and new tokens.
# |---------- context_len ----------|--- query_len ---|
# |------------- context -----------|-- new tokens ---|
- For new tokens, the query can attend to its KV in causal=True mode without any extra communication within the DCP group.
- For context, the KV is distributed across the DCP ranks and no causal mask is required. We follow #24864, i.e., all-gather Q and the LSE, then correct the attention output before performing the reduce-scatter.
This implementation saves memory by splitting the KV cache across ranks and does not impair performance much.
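As a rough illustration of the split described above (hypothetical helper names, not the PR's actual code), the two partial attention results can be combined with their LSEs like this:

# Hedged sketch: combining the new-token (causal, local) and context
# (non-causal, DCP-distributed) partial attention outputs using their LSEs.
import torch

def merge_with_lse(out_a, lse_a, out_b, lse_b):
    # out_*: [tokens, heads, head_dim], lse_*: [tokens, heads], natural log.
    lse = torch.logsumexp(torch.stack([lse_a, lse_b]), dim=0)
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)
    return w_a * out_a + w_b * out_b

# new_out, new_lse: causal attention of new tokens vs. new-token KV (no DCP comm)
# ctx_out, ctx_lse: attention of queries vs. the context KV, corrected and
#                   reduce-scattered across the DCP group as in #24864
# prefill_out = merge_with_lse(new_out, new_lse, ctx_out, ctx_lse)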
block_table_tensor = common_attn_metadata.block_table_tensor

if self.dcp_world_size > 1:
    seq_lens_np = seq_lens_np // self.dcp_world_size + (
if self.dcp_world_size > 1:
    prefill_query = get_dcp_group().all_gather(
        prefill_query.contiguous(), dim=1
    )
nit: I guess this is fine, but the name "decode context parallel" is falling apart a bit here 😞
In the last implementation this issue did exist, but in the current version the additional DCP communication during the prefill phase only happens when context tokens are present. So I think the name is still suitable.
],
"bigcode/gpt_bigcode-santacoder": [
    CPTestSettings.detailed(),
    CPTestSettings.detailed(tp_base=2),
I think it's better to keep the default backend for CI.
class CPTestOptions(NamedTuple):
    multi_node_only: bool
    load_format: str | None = None
    attn_backend: str = "FLASH_ATTN"
MLA can't use "FLASH_ATTN" backend, so the default value should not be set.
Sure. We have improved it. Thanks for your comment.
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
fixed_split_size=self.prefill_fixed_split_size,
disable_split_kv=self.disable_split_kv,
)
kv_query_indptr_cpu = qo_indptr_cpu.clone()
why is a clone needed? can't we just do kv_indptr = qo_indptr_cpu.to(self.device)?
Yes. It is not needed. We verified that removing it does not affect the model precision.
assert not isinstance(attn_metadata.prefill_wrapper, dict)
attn_metadata.prefill_wrapper.plan(
    qo_indptr_cpu.to(self.device),
    paged_kv_indptr_cpu.to(self.device),
why add .to(self.device)? Last I checked, FlashInfer prefers CPU tensors, otherwise we can get D2H copies in the plan: #21137
This was necessary in our last implementation with the custom mask, otherwise a device error would occur in the BatchPrefillWithPagedKVCacheWrapper. Now we can freely remove it.
BatchPrefillWithPagedKVCacheWrapper | BatchPrefillWithRaggedKVCacheWrapper,
]
| None
) = None
this type signature is complicated and repeated a lot 😞; maybe we could make our own wrapper shim? Like:
class BatchDCPPrefillWrapper:
    self._new_tokens: BatchPrefillWithRaggedKVCacheWrapper
    self._context: BatchPrefillWithPagedKVCacheWrapper
    def plan(...):
        self._new_tokens.plan(...)
        self._context.plan(...)
    def run(...):
        new = self._new_tokens.run(...)
        context = self._context.run(...)
        return merge_attn_states(new, context)
then we can make this prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper
thoughts?
Thanks for your valuable comment. We added the new class BatchDCPPrefillWrapper as a concise wrapper. The main improvement is here.
@@ -679,24 +756,61 @@ def build(
attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item())

if not attn_metadata.prefill_use_trtllm:
do we need to force prefill_use_trtllm to False when DCP is enabled?
trtllm cannot return the LSE, so it is not supported with DCP. We now directly disable it in vllm/utils/flashinfer.py.
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
LucasWilkinson left a comment
Overall it's looking much better, thank you! What is the issue with block interleave_size > 1?
vllm/utils/flashinfer.py
Outdated
# Decode context parallel is not supported
if dcp_world_size > 1:
    logger.warning_once(
        "Trtllm not support lse, please use flash attention or FlashInfer backend."
"Trtllm does not support returning LSE and as a result does not support DCP, reverting to FlashInfer"
Thanks! We updated the warning message to give a clearer explanation.
When handling contexts for chunked prefill, we split the contexts into chunks based on the workspace. However, the recent refactoring of the …
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Purpose
This PR adds Decode Context Parallel (DCP) support for GQA, following PR #23734 and PR #24864. The current implementation is based on the FlashInfer attention backend.
FlashInfer inserts the KV of the current query tokens into the cache before computation. Each query then attends to both its own KV and the context KV on the local device, with the LSE used to correct the attention outputs.
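For reference, a simplified sketch of the per-rank LSE correction (illustrative only; the names here are hypothetical, not the PR's exact code):

# Each DCP rank attends against its local KV shard, the per-rank LSEs are
# gathered, each partial output is rescaled, and a sum (the reduce-scatter
# across the group) then equals attention over the full KV.
import torch

def correct_local_output(local_out, local_lse, all_lse):
    # local_out: [T, H, D]; local_lse: [T, H]; all_lse: [R, T, H] from all ranks.
    global_lse = torch.logsumexp(all_lse, dim=0)
    return torch.exp(local_lse - global_lse).unsqueeze(-1) * local_out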
Test Plan
Qwen/Qwen3-235B-A22B
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.