## Checklist

## Motivation
DCP (#12196) partitions KV cache along the sequence dimension to enable long-context decode. Active PRs (#14982, #14194, #18167) implement this with AllGather+ReduceScatter (AG+RS) communication. This proposal extends that infrastructure with two improvements from Helix:
- **A2A communication backend** — All-to-All as an alternative post-attention pattern. Both AG+RS and A2A use 3 NCCL ops/layer, but A2A replaces the network AllReduce with a local Triton combine kernel, reducing latency at long contexts.
- **Decoupled attention/FFN parallelism** — Today, a single `--tp-size` controls both attention and FFN sharding. When `TP > num_kv_heads`, the KV cache is replicated (8× for MLA with TP=8, 2× for GQA-8 with TP=16). Decoupling lets attention use fewer GPUs (`attn_tp_size ≤ num_kv_heads`), eliminating KV replication while FFN uses the full TP group.
**Benchmark (vLLM, GB200 NVL72 16-GPU, DeepSeek-V2-Lite)** — from PR #34883:
| Context | Concurrency | Pure TP TPOT | DCP (AG+RS) TPOT | DCP (A2A) TPOT |
|---------|-------------|--------------|------------------|----------------|
| 256K | 8 | 22.29 ms | 15.36 ms | 14.27 ms (-36% vs TP) |
| 512K | 8 | 33.77 ms | 15.97 ms | 15.05 ms (-55% vs TP) |
| 1M | 8 | 160.92 ms | 17.57 ms | 16.64 ms (-90% vs TP) |
A2A consistently outperforms AG+RS by 3–8% TPOT. At 1M tokens, DCP reduces TPOT by up to 90% vs pure TP.
## Proposed Parameters
Building on the --dcp-size parameter from existing DCP PRs:
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `--dcp-comm-backend` | `{ag_rs, a2a}` | `ag_rs` | Post-attention DCP communication pattern |
| `--dcp-replicate-q-proj` | `bool` | `false` | Replicate Q projection weights to eliminate the AllGather of Q (3→2 NCCL ops/layer) |
| `--attention-tensor-parallel-size` | `int` | same as TP | Attention head-sharding parallelism (alias: `--attn-tp-size`) |
Naming follows SGLang conventions: `--attention-tensor-parallel-size` / `--attn-tp-size` pairs with the existing `--attention-context-parallel-size` / `--attn-cp-size`. The `--dcp-` prefix extends the namespace from existing DCP PRs.
## Usage Examples

```bash
# A2A communication backend (on top of existing DCP)
python -m sglang.launch_server --model deepseek-ai/DeepSeek-V2-Lite --tp 8 \
    --dcp-size 8 --dcp-comm-backend a2a

# Decoupled attention/FFN (attention on 2 GPUs, FFN on 8 — no KV replication)
python -m sglang.launch_server --model deepseek-ai/DeepSeek-V2-Lite --tp 8 \
    --attn-tp-size 2
```
All defaults preserve current behavior.
## High-Level Proposal

- **A2A communication backend:**
  - Add `all_to_all_single()` to PyNccl/GroupCoordinator (ncclSend/Recv pattern); see the first sketch after this list
  - Port the Triton LSE combine kernel for exact softmax reconstruction from partial attention outputs; the reference math is sketched below
  - Dispatch in attention backends based on `--dcp-comm-backend`
  - No model code changes — A2A is encapsulated in the attention backend
- **Q-proj replication (`--dcp-replicate-q-proj`):**
  - Trade redundant Q projection compute for eliminating the AllGather of Q (reduces NCCL ops from 3 to 2 per layer); see the side-by-side sketch below
  - Applies to both the AG+RS and A2A backends
  - Overhead is modest for MLA (compact absorbed Q weights)
- **Decoupled attention/FFN parallelism (`--attn-tp-size`):**
  - Introduce `_ATTN_TP` and `_KVP` process groups alongside the existing `_TP`; one possible construction is sketched below
  - Attention shards heads across `attn_tp_size` GPUs; FFN uses the full `tp_size` group
  - KVP (= `tp_size / attn_tp_size`) provides the DCP sequence-sharding dimension
  - No changes to FFN, scheduler, or KV cache pool API
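To make the exchange concrete, here is a minimal sketch of the post-attention all-to-all, using `torch.distributed.all_to_all_single` as a stand-in for the proposed PyNccl/GroupCoordinator method. The function name, tensor layout, and shapes are illustrative assumptions, not the final API:

```python
import torch
import torch.distributed as dist

def dcp_post_attention_a2a(partial_out: torch.Tensor, dcp_group) -> torch.Tensor:
    """Exchange partial attention outputs across the DCP group (sketch).

    partial_out is laid out so the i-th contiguous chunk along dim 0 holds
    this rank's partial results for rank i's decode tokens (each rank
    attended over only its own KV shard). After the exchange, this rank
    holds one partial output per peer for its own tokens, ready for the
    local LSE combine.
    """
    recv = torch.empty_like(partial_out)
    # The proposal implements this inside PyNccl as paired ncclSend/ncclRecv;
    # torch.distributed exposes the equivalent collective directly.
    dist.all_to_all_single(recv, partial_out, group=dcp_group)
    return recv
```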
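The combine step is ordinary log-sum-exp algebra; the Triton kernel is an optimized version of the following PyTorch reference (shapes illustrative). Each rank's partial output was normalized by a softmax restricted to its own KV shard, so rescaling by `exp(lse_i - global_lse)` and summing reconstructs the exact full-context softmax:

```python
import torch

def lse_combine(partial_outs: torch.Tensor, lses: torch.Tensor) -> torch.Tensor:
    """Reference LSE combine (sketch, not the ported Triton kernel).

    partial_outs: [dcp_size, tokens, heads, head_dim] - attention outputs
                  computed against each rank's local KV shard
    lses:         [dcp_size, tokens, heads] - log-sum-exp of the local
                  attention scores on each rank
    """
    global_lse = torch.logsumexp(lses, dim=0)              # [tokens, heads]
    # Weight each partial by the probability mass its KV shard contributed.
    weights = torch.exp(lses - global_lse).unsqueeze(-1)   # [dcp, tokens, heads, 1]
    return (weights * partial_outs).sum(dim=0)             # [tokens, heads, head_dim]
```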
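For Q-proj replication the trade-off is easiest to see side by side. This sketch assumes the default path produces Q head-sharded and assembles it with an AllGather over the DCP group; the function names, shapes, and head-major layout are illustrative assumptions:

```python
import torch
import torch.distributed as dist

def q_proj_sharded(x: torch.Tensor, w_q_shard: torch.Tensor, group):
    # Default path: each rank projects only its slice of the Q heads, then
    # AllGathers the slices so every rank sees all heads (one of the three
    # NCCL ops per layer).
    q_local = x @ w_q_shard                      # [tokens, local_heads * head_dim]
    chunks = [torch.empty_like(q_local) for _ in range(dist.get_world_size(group))]
    dist.all_gather(chunks, q_local, group=group)
    return torch.cat(chunks, dim=-1)             # [tokens, total_heads * head_dim]

def q_proj_replicated(x: torch.Tensor, w_q_full: torch.Tensor) -> torch.Tensor:
    # --dcp-replicate-q-proj path: every rank holds the full (absorbed) Q
    # weights and redundantly computes all heads, so the AllGather vanishes
    # (3 -> 2 NCCL ops per layer). For MLA the absorbed W_q is compact,
    # keeping the redundant FLOPs modest.
    return x @ w_q_full
```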
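For the process groups, one plausible construction of `_ATTN_TP` and `_KVP` is below; the rank layout (consecutive ranks share heads, strided ranks share the KVP/DCP dimension) is an assumption for illustration. Note that `dist.new_group` must be called in the same order on every rank, which the loops guarantee:

```python
import torch.distributed as dist

def init_attn_groups(tp_size: int, attn_tp_size: int):
    """Build head-sharding (_ATTN_TP) and sequence-sharding (_KVP) groups."""
    assert tp_size % attn_tp_size == 0
    kvp_size = tp_size // attn_tp_size
    rank = dist.get_rank()
    attn_tp_group = kvp_group = None
    # _ATTN_TP: blocks of attn_tp_size consecutive ranks shard the heads.
    for i in range(kvp_size):
        ranks = list(range(i * attn_tp_size, (i + 1) * attn_tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            attn_tp_group = group
    # _KVP: strided ranks that hold the same heads but different sequence
    # shards of the KV cache; this is the DCP dimension.
    for i in range(attn_tp_size):
        ranks = list(range(i, tp_size, attn_tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            kvp_group = group
    return attn_tp_group, kvp_group
```

With `tp_size=8` and `attn_tp_size=2`, this yields four 2-way head-sharding groups and two 4-way KVP groups; the KVP size of 4 is the sequence-sharding (`--dcp-size`) dimension.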
## Roadmap
This work builds on the existing DCP infrastructure from PRs #14982, #14194, and #18167. We plan to split the implementation into small PRs to make review easier:

1. **A2A backend** (`--dcp-comm-backend a2a`, PyNccl A2A, Triton LSE combine, backend dispatch)
2. **Q-proj replication** (`--dcp-replicate-q-proj`)
3. **Decoupled parallelism** (`--attn-tp-size`, `_ATTN_TP`/`_KVP` process groups)

The first PR (A2A backend) depends on the existing DCP PRs being merged. Q-proj replication and decoupled parallelism can follow as independent PRs.
Future:
## Related resources