[Feature] DCP: Decode Context Parallelism with A2A and FA3 Backend Support #21637

Open

thanhhao98 wants to merge 16 commits into sgl-project:main from thanhhao98:htphan/dcp-helix

Conversation

@thanhhao98 commented Mar 29, 2026

Motivation

This PR extends the DCP (Decode Context Parallelism) feature from #14194 by @staugust with two major additions:

  1. A2A (All-to-All) communication backend -- an alternative to AllGather+ReduceScatter that reduces NCCL calls per layer from 2 to 1 by fusing output+LSE into a single exchange.
  2. FA3 (FlashAttention-3) backend support -- enables DCP with the FA3 attention backend, not just FlashInfer.

These extensions make DCP production-ready with multiple communication strategies and attention backends, giving users flexibility to choose the best combination for their hardware.

Background: What is DCP?

DCP splits the KV cache across ranks within a TP group during decode. Each rank stores only 1/dcp_size of the KV cache tokens (interleaved by position), computes partial attention over its local shard, then combines results using LSE-weighted merging to produce the correct full-attention output.

This allows serving much longer contexts (e.g., 256K-1M tokens) on the same hardware by distributing KV cache memory across GPUs, at the cost of additional communication during decode.
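
For intuition, here is a rough sketch of the interleaved layout (the names below are illustrative, not the allocator's actual API): token position i lands on DCP rank i % dcp_size, so each rank keeps roughly an equal slice of every sequence's KV cache.

```python
# Illustrative sketch of position-interleaved KV sharding; owner_rank and
# local_token_count are hypothetical names, not functions from this PR.
dcp_size = 8  # DCP ranks sharing one sequence's KV cache

def owner_rank(position: int) -> int:
    """DCP rank that stores the KV entry for this token position."""
    return position % dcp_size

def local_token_count(seq_len: int, rank: int) -> int:
    """How many KV tokens of a seq_len-token sequence live on a given rank."""
    # Rank `rank` holds positions rank, rank + dcp_size, rank + 2*dcp_size, ...
    return max(0, (seq_len - rank + dcp_size - 1) // dcp_size)

assert sum(local_token_count(1000, r) for r in range(dcp_size)) == 1000
```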

The base DCP implementation in #14194 supports FlashInfer with AllGather+ReduceScatter (AG+RS) communication. This PR adds the A2A backend and FA3 support on top of that foundation, as detailed in the modifications below.

Modifications

1. A2A Communication Backend (dcp_a2a.py)

New file: python/sglang/srt/layers/attention/dcp_a2a.py

Instead of AllGather(Q) -> Attention -> AllGather(output) -> LSE-correct -> ReduceScatter(output), the A2A backend:

  • Runs attention with all heads on local KV shard (heads already distributed by TP)
  • Packs output + LSE into a fused buffer ([N, B, H_per_rank, D + lse_pack_dim])
  • Executes a single all_to_all_single to exchange head partials between ranks
  • Combines received partials with a Triton LSE-weighted combine kernel

This halves NCCL collective calls per MLA layer (from 2 to 1), reducing communication overhead.
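
Roughly, the fused exchange amounts to the following. This is only a hedged sketch: it uses torch.distributed.all_to_all_single for readability, whereas the PR routes the exchange through PyNcclCommunicator.all_to_all_single to stay CUDA-graph capturable and packs the LSE into lse_pack_dim slots rather than a single extra column; lse_weighted_combine is sketched after the Key components list below.

```python
# Hedged sketch of the fused output+LSE all-to-all exchange (not the PR's dcp_a2a.py code).
import torch
import torch.distributed as dist

def a2a_exchange_partials(out: torch.Tensor, lse: torch.Tensor, group) -> torch.Tensor:
    """out: [N, B, H_per_rank, D] partial outputs, lse: [N, B, H_per_rank]; N = dcp_size."""
    N, B, H, D = out.shape
    # Pack output and LSE into one buffer so a single collective replaces AG + RS.
    send = torch.cat([out, lse.unsqueeze(-1).to(out.dtype)], dim=-1)   # [N, B, H, D + 1]
    recv = torch.empty_like(send)
    dist.all_to_all_single(recv.view(N, -1), send.view(N, -1), group=group)
    recv_out, recv_lse = recv[..., :D], recv[..., D].float()
    # Each rank now holds all N partials for its own heads; merge them locally.
    return lse_weighted_combine(recv_out, recv_lse)                    # [B, H_per_rank, D]
```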

Key components:

  • dcp_lse_combine_triton: Triton kernel for LSE-weighted output combination (supports both base-e FA3 and base-2 FlashInfer LSE conventions)
  • dcp_a2a_lse_reduce: Fused A2A exchange + local combine with CUDA graph buffer support
  • _lse_weighted_combine_cpu: CPU reference implementation for testing (the sketch after this list mirrors its math)
  • PyNcclCommunicator.all_to_all_single: NCCL-based A2A using ncclGroupStart/End for graph-capturability
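
The merge itself is ordinary log-sum-exp weighting. Below is a minimal sketch of the math that _lse_weighted_combine_cpu verifies, assuming base-e LSE unless is_base_e is False, in which case the FlashInfer base-2 convention is rescaled first:

```python
import math
import torch

def lse_weighted_combine(partial_out: torch.Tensor, partial_lse: torch.Tensor,
                         is_base_e: bool = True) -> torch.Tensor:
    """partial_out: [N, B, H, D] per-shard outputs, partial_lse: [N, B, H] per-shard LSE.

    Each shard's output is a softmax over only its local KV tokens; weighting shards by
    softmax of their LSE along the shard axis recovers the exact full-attention output.
    """
    lse = partial_lse if is_base_e else partial_lse * math.log(2.0)  # base-2 -> base-e
    weights = torch.softmax(lse, dim=0)                              # [N, B, H]
    return (weights.unsqueeze(-1) * partial_out).sum(dim=0)          # [B, H, D]
```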

2. FA3 Backend Support (flashattention_backend.py)

Extended the FlashAttention backend to handle DCP decode and extend paths:

  • DCP metadata computation (head counts, group references, local-shard page table and cache seqlens -- see the sketch after this list)
  • Q AllGather across DCP group for decode
  • KV prefix AllGather for extend
  • LSE-weighted output correction after attention (AG+RS or A2A)
  • Cascade attention guard when DCP > 1
  • CUDA graph buffer pre-allocation for DCP A2A
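
For the metadata part, the per-rank cache lengths follow directly from the interleaved layout; here is a hedged sketch (dcp_local_cache_seqlens is an illustrative name, not the backend's actual helper):

```python
import torch

def dcp_local_cache_seqlens(cache_seqlens: torch.Tensor, dcp_rank: int, dcp_size: int) -> torch.Tensor:
    """Per-sequence KV token count stored on this DCP rank under interleaved storage.

    cache_seqlens: [B] global sequence lengths. Equals ceil((S - dcp_rank) / dcp_size),
    which is also the CPU-side upper bound used to avoid a GPU sync on the decode path.
    """
    return torch.clamp((cache_seqlens - dcp_rank + dcp_size - 1) // dcp_size, min=0)
```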

3. Server Args and Configuration

  • --dcp-size N: DCP world size (replaces SGLANG_DCP env var)
  • --dcp-comm-backend {ag_rs, a2a}: Communication backend choice
  • Validation: A2A requires dcp_size > 1; tp_size must be divisible by dcp_size
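
A rough sketch of what that validation amounts to (a hypothetical helper, not the actual ServerArgs code):

```python
def validate_dcp_args(tp_size: int, dcp_size: int, dcp_comm_backend: str) -> None:
    """Mirrors the constraints listed above; names here are illustrative."""
    if dcp_size > 1 and tp_size % dcp_size != 0:
        raise ValueError(f"tp_size ({tp_size}) must be divisible by dcp_size ({dcp_size})")
    if dcp_comm_backend == "a2a" and dcp_size <= 1:
        raise ValueError("--dcp-comm-backend a2a requires --dcp-size > 1")
```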

4. Symmetric Memory Support

  • Extended pynccl_allocator.py to support multiple symmetric memory groups (TP + DCP)
  • DCP group uses symmetric memory for AllGather/ReduceScatter under CUDA graph capture
  • SGLANG_DCP_SYMM_ONLY env var to enable symmetric memory exclusively for DCP group

5. CI-Registered Tests

Added 7 test files under test/registered/ for CI auto-discovery:

| File | Suite | Tests |
|---|---|---|
| dcp/test_dcp_accuracy.py | stage-c-test-8-gpu-h200 | 4 E2E configs (FlashInfer/FA3 x AG+RS/A2A) |
| kernels/test_dcp_lse_combine.py | stage-b-test-1-gpu-large | 21 Triton kernel correctness tests |
| kernels/test_dcp_interleaved.py | stage-b-test-1-gpu-small | 11 KV allocator interleaved storage tests |
| kernels/test_dcp_fa3_standalone.py | stage-b-test-1-gpu-large | 6 FA3 MLA + simulated DCP tests |
| unit/server_args/test_dcp_config.py | stage-a-test-cpu | 8 ServerArgs validation tests |
| unit/layers/test_dcp_cascade_guard.py | stage-a-test-cpu | 15 cascade attention guard tests |
| unit/layers/test_dcp_need_lse.py | stage-a-test-cpu | 4 need_lse logic tests |

6. Symmetric Memory Benchmark

benchmark/kernels/all_reduce/benchmark_symm_mem.py: Benchmarks AllGather, ReduceScatter, and All-to-All collectives, comparing torch eager execution against PyNccl symmetric memory under CUDA graph capture.

Benchmarking and Profiling

Serving Performance: DCP vs TP8 Baseline

Benchmarked with DeepSeek-V2 on 8x H100, using bench_serving with random dataset (input ~4000 tokens, output ~1500 tokens) across concurrency levels 1-512.

[Benchmark comparison chart: bench_comparison]

Key finding: DCP enables 2.2x higher throughput at high concurrency by distributing KV cache across GPUs, allowing the system to serve more concurrent requests before running out of memory.

Output Token Throughput (tok/s)

| Config | cc1 | cc8 | cc32 | cc64 | cc128 | cc256 | cc512 |
|---|---|---|---|---|---|---|---|
| TP8 FlashInfer (baseline) | 103 | 498 | 1060 | 1385 | 1374 | 1403 | 1400 |
| TP8 FA3 (baseline) | 103 | 492 | 1058 | 1370 | 1358 | 1384 | 1382 |
| DCP8 AG+RS FlashInfer | 89 | 425 | 947 | 1336 | 1945 | 2568 | 3107 |
| DCP8 AG+RS FA3 | 90 | 425 | 953 | 1358 | 1961 | 2572 | 2930 |
| DCP8 A2A FlashInfer | 86 | 413 | 929 | 1320 | 1919 | 2559 | 3126 |
| DCP8 A2A FA3 | 87 | 413 | 936 | 1341 | 1933 | 2570 | 2951 |

  • TP8 plateaus at ~1400 tok/s around cc64 -- KV cache is full, no more requests can be served concurrently.
  • DCP8 continues scaling to 2900-3100 tok/s at cc512 -- 8x more KV cache capacity from interleaved distribution.
  • At cc512: DCP delivers 2.2x the throughput of TP8.

Mean TTFT (ms) -- Time to First Token

| Config | cc32 | cc64 | cc128 | cc256 | cc512 |
|---|---|---|---|---|---|
| TP8 FlashInfer | 413 | 5,744 | 39,672 | 106,844 | 239,577 |
| DCP8 AG+RS FlashInfer | 420 | 655 | 1,095 | 1,970 | 28,734 |

  • TP8 TTFT explodes to 240 seconds at cc512 due to request queuing when KV cache is full.
  • DCP8 TTFT stays at ~28 seconds -- an 8.3x improvement because more requests fit in memory.

Per-Token Latency (TPOT/ITL)

At low concurrency (cc1-cc32), TP8 has ~15% lower per-token latency since DCP adds communication overhead per decode step. However, TP8 cannot sustain higher concurrency -- it queues requests instead, making the latency comparison moot above cc64.

Accuracy Tests

GSM8K few-shot accuracy with DeepSeek-V2 on 8x H100, TP8, DCP8 (200 questions):

| Configuration | Accuracy |
|---|---|
| TP8 (baseline, no DCP) | 0.805-0.810 |
| DCP8 FlashInfer + AG+RS | 0.810 |
| DCP8 FlashInfer + A2A | 0.800 |
| DCP8 FA3 + AG+RS | 0.790 |
| DCP8 FA3 + A2A | 0.800 |

All DCP configurations match baseline TP8 accuracy within noise margin.

Symmetric Memory Benchmark (H100 8-GPU)

| msg_size | AG eager (us) | AG symm graph (us) | RS eager (us) | RS symm graph (us) | A2A eager (us) | A2A symm graph (us) |
|---|---|---|---|---|---|---|
| 2 KiB | 14.57 | 2.68 | 16.06 | 2.82 | 18.34 | 5.45 |
| 8 KiB | 14.42 | 3.11 | 15.82 | 3.00 | 17.84 | 5.57 |
| 32 KiB | 14.37 | 5.07 | 18.74 | 3.25 | 17.90 | 6.25 |
| 128 KiB | 18.14 | 6.90 | 17.12 | 4.24 | 21.18 | 7.00 |

Symmetric memory CUDA graph speedup: AG 2.6-5.4x, RS 4.0-6.2x, A2A 3.0-3.6x.

Usage

# DCP with AG+RS (default, compatible with CUDA graph)
python -m sglang.launch_server \
    --model-path deepseek-ai/DeepSeek-V2 \
    --tp-size 8 --dcp-size 8 \
    --dcp-comm-backend ag_rs \
    --attention-backend flashinfer \
    --enable-symm-mem --disable-radix-cache \
    --trust-remote-code

# DCP with A2A communication
SGLANG_DCP_SYMM_ONLY=true python -m sglang.launch_server \
    --model-path deepseek-ai/DeepSeek-V2 \
    --tp-size 8 --dcp-size 8 \
    --dcp-comm-backend a2a \
    --attention-backend fa3 \
    --enable-symm-mem --disable-radix-cache \
    --trust-remote-code

# DCP with AG+RS communication on the FA3 backend
SGLANG_DCP_SYMM_ONLY=true python -m sglang.launch_server \
    --model-path deepseek-ai/DeepSeek-V2 \
    --tp-size 8 --dcp-size 8 \
    --dcp-comm-backend ag_rs \
    --attention-backend fa3 \
    --enable-symm-mem --disable-radix-cache \
    --trust-remote-code

Acknowledgment

This work builds on the DCP implementation by @staugust in #14194, which introduced the core DCP infrastructure: distributed group setup, interleaved KV cache storage, FlashInfer MLA backend integration, and the LSE correction kernel. Our contributions extend this foundation with A2A communication, FA3 backend support, symmetric memory optimization, proper CLI args, and comprehensive CI-registered tests.

Checklist

  • Format code with pre-commit
  • Add unit tests (69 tests across 7 files, all passing)
  • Accuracy benchmark (GSM8K: 0.79-0.81 across all configs)
  • Performance benchmark (symmetric memory collectives)
  • Follow SGLang code style

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces Decode Context Parallelism (DCP) to the SGLang runtime, enabling distributed attention computation and interleaved KV cache storage across multiple GPUs to handle long-context scenarios more efficiently. The changes include a new All-to-All communication backend with optimized Triton kernels for LSE-weighted merging, integration across FlashAttention, FlashInfer, and FlashMLA backends, and extensive updates to the DeepSeek-V2 model implementation. Feedback identifies a design concern regarding the use of mutable global state for managing attention backends and a bug in the argument parsing logic of the symmetric-memory benchmark script.

Outdated review threads: python/sglang/srt/models/deepseek_v2.py, benchmark/kernels/all_reduce/bench_symm.sh
staugust and others added 9 commits March 30, 2026 13:03
Implement Decode Context Parallelism (DCP) to support longer context
windows with TP 8 under 8xH20. DCP splits KV cache across DCP ranks
using interleaved storage, allowing virtual capacity expansion of
real_kv_size * dcp_world_size.

Key changes:
- Add DCP distributed group infrastructure and parallel state management
- Implement DcpTokenToKVPoolAllocator with interleaved KV cache storage
- Modify FlashInfer MLA and FlashMLA attention backends for DCP support
- Add Triton kernel for attention output correction using LSE
- Support all-gather for absorbed queries and reduce-scatter for outputs
- Support chunked-prefill, decode cuda graph, and prefix cache with DCP
- Add symmetric memory support for DCP communication operations
- Add comprehensive unit tests for interleaved storage logic

Compatible with: flashinfer attention backend, chunked-prefill,
decode cuda graph, prefix cache, full cuda graph, PP.
Not yet supported: radix-cache (partial), PD disaggregation, MTP.

Squashed from 21 commits on yjh/dcp-dev-main branch.
Reference: sgl-project#14194
Add DCP (Decode Context Parallelism) support for the FlashInfer backend:
- allocator.py: DCP interleaved token allocation with rank-based filtering
- forward_batch_info.py: missing import for create_chunked_prefix_cache_kv_indices
- deepseek_v2.py: FORWARD_ABSORB_CORE_ATTENTION_BACKENDS, merge_state_v2 import
- scheduler.py: DCP rank in process title

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…allelism

Add DCP support for the FlashAttention (FA3) backend:
- FlashAttentionMetadata: dcp_page_table and dcp_cache_seqlens for local KV shard
- _init_dcp_decode_metadata: vectorized page table construction for DCP interleaving
- Forward decode: DCP AllGather Q, local page table substitution, LSE normalization
- CUDA graph: pre-allocated DCP buffers for graph capture/replay
- deepseek_v2.py: is_base_e LSE detection and ln(2) conversion for AG+RS
Add All-to-All (A2A) as an alternative to AG+RS for DCP communication:
- pynccl.py: all_to_all_single using ncclGroupStart/End for graph capture
- parallel_state.py: GroupCoordinator.all_to_all_single method
- dcp_a2a.py: Triton LSE-weighted combine kernel, fused output+LSE buffer
- server_args.py: --dcp-comm-backend {ag_rs, a2a} flag
- deepseek_v2.py: A2A dispatch in DCP reduce block
- flashattention_backend.py: A2A dispatch, per-bs CUDA graph buffer allocation
- Tests: test_dcp_a2a.py (Triton kernel correctness), test_dcp_accuracy_matrix.py
Replace the SGLANG_DCP environment variable with a proper --dcp-size
argument in ServerArgs for decode context parallelism configuration:
- server_args.py: add dcp_size field and --dcp-size argparse argument
- parallel_state.py: add set_dcp_size(), remove env var read from
  get_dcp_size_from_env(), pass dcp_size through initialize_model_parallel
- model_runner.py: pass server_args.dcp_size to initialize_model_parallel
- scheduler.py: use server_args.dcp_size instead of env var
- engine.py: use server_args.dcp_size instead of env var
- test_dcp_accuracy_matrix.py: use --dcp-size flag instead of env var
CPTritonContext was intended to cache compiled Triton kernels across
calls, but ctx was never persisted — a new instance was created on
every call, so the caching never took effect. Triton already caches
compiled kernels internally based on constexpr values, making this
wrapper redundant. Remove it and call the kernel directly.
…ersion hack

The AG+RS kernel (_correct_attn_cp_out_kernel) was hardcoded to base-2
(exp2/log2), requiring callers to manually convert base-e LSE from FA3
via division by ln(2). This was error-prone — flashattention_backend.py
was missing the conversion entirely for its non-MLA AG+RS path.

Add IS_BASE_E constexpr to the kernel so it handles both base-e (exp/log)
and base-2 (exp2/log2) natively, matching _dcp_lse_combine_kernel's
approach. Propagate is_lse_base_on_e through correct_attn_out and
cp_lse_ag_out_rs, and remove the lse / ln(2) hack from deepseek_v2.py.
Remove TODO comments from the original DCP commit (3c751c6) where the
described functionality has since been implemented:
- flashinfer_mla_backend.py: LSE return for decode (done), kv_indices
  update (filter_seq_indices implemented below the TODO)
- deepseek_v2.py: change local_heads logic (attn_mqa_for_dcp_decode
  created), all_gather q_pe (implemented below), return lse and correct
  attn_output (DCP reduce block implemented)
- model_runner.py: prepare for dcp (DCP prefix preparation implemented)
- allocator.py: triton kernel for filter_local_indices (only used in
  offline save/load paths, PyTorch boolean indexing is sufficient)
Register DCP tests in test/registered/ for CI auto-discovery:
- dcp/test_dcp_accuracy.py: E2E 8-GPU GSM8K accuracy (FlashInfer/FA3 x AG+RS/A2A)
- kernels/test_dcp_lse_combine.py: Triton LSE combine vs CPU reference
- kernels/test_dcp_interleaved.py: KV allocator interleaved storage
- kernels/test_dcp_fa3_standalone.py: FA3 MLA simulated DCP sharding
- unit/server_args/test_dcp_config.py: ServerArgs DCP validation
- unit/layers/test_dcp_cascade_guard.py: cascade attention guard logic
- unit/layers/test_dcp_need_lse.py: need_lse flag logic

Made-with: Cursor
Hao Phan added 6 commits March 30, 2026 13:03
- Remove unused imports: os/time/Optional/Tuple in benchmark_symm_mem.py,
  redundant DllmConfig in scheduler.py
- Add noqa: F401 for re-exported create_chunked_prefix_cache_kv_indices
- Fix import sorting (isort) in deepseek_v2.py, forward_batch_info.py,
  test_dcp_fa3_standalone.py, test_dcp_a2a.py
- Apply black formatting across 14 files
- Rewrite benchmark_symm_mem.py: add argparse, A2A benchmark, clean up
  commented code, update bench_symm.sh with H100 results

Made-with: Cursor
- bench_dcp_serving.sh: automated benchmark across 6 configs
  (TP8 baseline, DCP8 AG+RS, DCP8 A2A) x (FlashInfer, FA3) with
  accuracy (GSM8K) + throughput (bench_serving) at cc1-512
- plot_dcp_bench.py: parse bench_serving outputs, generate comparison
  charts and markdown tables for throughput/TTFT/TPOT/ITL metrics

Made-with: Cursor
The DCP cherry-pick accidentally duplicated all mixin methods inside
deepseek_v2.py (3455 vs 2264 lines). Reset deepseek_v2.py to main and
move DCP-specific code to the proper mixin files:

- forward_mla.py: Q AllGather (decode), KV prefix gather (extend),
  LSE-weighted combine (AG+RS and A2A) in forward_absorb_prepare/core
- forward_mha.py: DCP KV index filtering in _get_mla_kv_buffer, DCP
  prefix KV AllGather in forward_normal_prepare and _chunked_prefix_attn_mha
- deepseek_v2.py: only DCP imports, attn_mqa_for_dcp_decode init,
  and _all_gather_dcp_kv_cache method
- utils.py: add_forward_absorb_core_attention_backend helper

Made-with: Cursor
Allocate A2A send/recv buffers inside SymmetricMemoryContext so they
are graph-capturable with PyNccl. Benchmark shows 3.0-3.6x speedup
for A2A collectives with symmetric memory vs torch eager.

Remove --disable-cuda-graph from A2A test configs since A2A now
works with CUDA graph capture.

Made-with: Cursor
Use explicit `shift 1` to make intent clear -- shift only the first
arg (NGPU), leaving remaining args as OPS. The `2>/dev/null` is stderr
redirection, not a shift count. Address review feedback from PR sgl-project#21637.

Made-with: Cursor
- Fix CPU-device torch.arange in _all_gather_dcp_kv_cache causing
  CPU-GPU sync on every call (add device=kv_a.device)
- Add assertion H % N == 0 in dcp_a2a_lse_reduce to guard against
  silent wrong results with indivisible head counts
- Clarify return statement in dcp_lse_combine_triton with explicit
  parentheses for operator precedence
- Remove wasteful tl.load in Triton kernel store -- use out_ptr.dtype
  instead of loading from recv_output_ptr just to get the dtype

Made-with: Cursor
Replace local_seqlens.max().item() (GPU-to-CPU sync every decode step)
with a derivation from metadata.max_seq_len_k which is already on CPU.
The upper bound ceil((max_seq_len_k - dcp_rank) / N) is equivalent and
avoids the synchronization overhead on the hot decode path.

Made-with: Cursor

nvpohanh commented Apr 9, 2026

Smaller PRs:
