[Feature] DCP: Decode Context Parallelism with A2A and FA3 Backend Support #21637

Open

thanhhao98 wants to merge 16 commits into sgl-project:main
Conversation
Contributor
Code Review
This pull request introduces Decode Context Parallelism (DCP) to the SGLang runtime, enabling distributed attention computation and interleaved KV cache storage across multiple GPUs to handle long-context scenarios more efficiently. The changes include a new All-to-All communication backend with optimized Triton kernels for LSE-weighted merging, integration across FlashAttention, FlashInfer, and FlashMLA backends, and extensive updates to the DeepSeek-V2 model implementation. Feedback identifies a design concern regarding the use of mutable global state for managing attention backends and a bug in the argument parsing logic of the symmetric-memory benchmark script.
Implement Decode Context Parallelism (DCP) to support longer context windows with TP 8 under 8xH20.

DCP splits KV cache across DCP ranks using interleaved storage, allowing virtual capacity expansion of real_kv_size * dcp_world_size.

Key changes:
- Add DCP distributed group infrastructure and parallel state management
- Implement DcpTokenToKVPoolAllocator with interleaved KV cache storage
- Modify FlashInfer MLA and FlashMLA attention backends for DCP support
- Add Triton kernel for attention output correction using LSE
- Support all-gather for absorbed queries and reduce-scatter for outputs
- Support chunked-prefill, decode cuda graph, and prefix cache with DCP
- Add symmetric memory support for DCP communication operations
- Add comprehensive unit tests for interleaved storage logic

Compatible with: flashinfer attention backend, chunked-prefill, decode cuda graph, prefix cache, full cuda graph, PP.
Not yet supported: radix-cache (partial), PD disaggregation, MTP.

Squashed from 21 commits on yjh/dcp-dev-main branch. Reference: sgl-project#14194
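For intuition, the position-interleaved mapping can be sketched as follows (a toy model; the real DcpTokenToKVPoolAllocator logic may differ in detail):

```python
# Toy sketch of position-interleaved KV sharding (illustrative only; the
# actual DcpTokenToKVPoolAllocator in this PR may differ in detail).
def owner_rank(pos: int, dcp_world_size: int) -> int:
    # Token at sequence position `pos` lives on this DCP rank ...
    return pos % dcp_world_size

def local_slot(pos: int, dcp_world_size: int) -> int:
    # ... at this index within that rank's local KV pool.
    return pos // dcp_world_size

# Each rank stores ~1/dcp_world_size of the tokens, so the effective KV
# capacity grows to real_kv_size * dcp_world_size.
assert [owner_rank(p, 4) for p in range(8)] == [0, 1, 2, 3, 0, 1, 2, 3]
```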
Add DCP (Decode Context Parallelism) support for the FlashInfer backend:
- allocator.py: DCP interleaved token allocation with rank-based filtering
- forward_batch_info.py: missing import for create_chunked_prefix_cache_kv_indices
- deepseek_v2.py: FORWARD_ABSORB_CORE_ATTENTION_BACKENDS, merge_state_v2 import
- scheduler.py: DCP rank in process title

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…allelism

Add DCP support for the FlashAttention (FA3) backend:
- FlashAttentionMetadata: dcp_page_table and dcp_cache_seqlens for local KV shard
- _init_dcp_decode_metadata: vectorized page table construction for DCP interleaving
- Forward decode: DCP AllGather Q, local page table substitution, LSE normalization
- CUDA graph: pre-allocated DCP buffers for graph capture/replay
- deepseek_v2.py: is_base_e LSE detection and ln(2) conversion for AG+RS
Add All-to-All (A2A) as an alternative to AG+RS for DCP communication:
- pynccl.py: all_to_all_single using ncclGroupStart/End for graph capture
- parallel_state.py: GroupCoordinator.all_to_all_single method
- dcp_a2a.py: Triton LSE-weighted combine kernel, fused output+LSE buffer
- server_args.py: --dcp-comm-backend {ag_rs, a2a} flag
- deepseek_v2.py: A2A dispatch in DCP reduce block
- flashattention_backend.py: A2A dispatch, per-bs CUDA graph buffer allocation
- Tests: test_dcp_a2a.py (Triton kernel correctness), test_dcp_accuracy_matrix.py
Replace the SGLANG_DCP environment variable with a proper --dcp-size argument in ServerArgs for decode context parallelism configuration:
- server_args.py: add dcp_size field and --dcp-size argparse argument
- parallel_state.py: add set_dcp_size(), remove env var read from get_dcp_size_from_env(), pass dcp_size through initialize_model_parallel
- model_runner.py: pass server_args.dcp_size to initialize_model_parallel
- scheduler.py: use server_args.dcp_size instead of env var
- engine.py: use server_args.dcp_size instead of env var
- test_dcp_accuracy_matrix.py: use --dcp-size flag instead of env var
CPTritonContext was intended to cache compiled Triton kernels across calls, but ctx was never persisted — a new instance was created on every call, so the caching never took effect. Triton already caches compiled kernels internally based on constexpr values, making this wrapper redundant. Remove it and call the kernel directly.
…ersion hack

The AG+RS kernel (_correct_attn_cp_out_kernel) was hardcoded to base-2 (exp2/log2), requiring callers to manually convert base-e LSE from FA3 via division by ln(2). This was error-prone — flashattention_backend.py was missing the conversion entirely for its non-MLA AG+RS path.

Add IS_BASE_E constexpr to the kernel so it handles both base-e (exp/log) and base-2 (exp2/log2) natively, matching _dcp_lse_combine_kernel's approach. Propagate is_lse_base_on_e through correct_attn_out and cp_lse_ag_out_rs, and remove the lse / ln(2) hack from deepseek_v2.py.
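The equivalence the fix relies on, as a quick PyTorch check:

```python
import math
import torch

lse_e = torch.randn(4, 8)        # base-e LSE, as FA3 produces it
lse_2 = lse_e / math.log(2.0)    # base-2 LSE, the FlashInfer convention

# exp2(x / ln 2) == exp(x): both conventions encode the same softmax
# denominator. An IS_BASE_E constexpr lets the kernel pick exp/log vs
# exp2/log2 itself instead of callers dividing by ln(2) at every site.
assert torch.allclose(torch.exp2(lse_2), torch.exp(lse_e), atol=1e-5)
```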
Remove TODO comments from the original DCP commit (3c751c6) where the described functionality has since been implemented:
- flashinfer_mla_backend.py: LSE return for decode (done), kv_indices update (filter_seq_indices implemented below the TODO)
- deepseek_v2.py: change local_heads logic (attn_mqa_for_dcp_decode created), all_gather q_pe (implemented below), return lse and correct attn_output (DCP reduce block implemented)
- model_runner.py: prepare for dcp (DCP prefix preparation implemented)
- allocator.py: triton kernel for filter_local_indices (only used in offline save/load paths, PyTorch boolean indexing is sufficient)
Register DCP tests in test/registered/ for CI auto-discovery:
- dcp/test_dcp_accuracy.py: E2E 8-GPU GSM8K accuracy (FlashInfer/FA3 x AG+RS/A2A)
- kernels/test_dcp_lse_combine.py: Triton LSE combine vs CPU reference
- kernels/test_dcp_interleaved.py: KV allocator interleaved storage
- kernels/test_dcp_fa3_standalone.py: FA3 MLA simulated DCP sharding
- unit/server_args/test_dcp_config.py: ServerArgs DCP validation
- unit/layers/test_dcp_cascade_guard.py: cascade attention guard logic
- unit/layers/test_dcp_need_lse.py: need_lse flag logic

Made-with: Cursor
added 6 commits on March 30, 2026 at 13:03
- Remove unused imports: os/time/Optional/Tuple in benchmark_symm_mem.py, redundant DllmConfig in scheduler.py
- Add noqa: F401 for re-exported create_chunked_prefix_cache_kv_indices
- Fix import sorting (isort) in deepseek_v2.py, forward_batch_info.py, test_dcp_fa3_standalone.py, test_dcp_a2a.py
- Apply black formatting across 14 files
- Rewrite benchmark_symm_mem.py: add argparse, A2A benchmark, clean up commented code, update bench_symm.sh with H100 results

Made-with: Cursor
- bench_dcp_serving.sh: automated benchmark across 6 configs (TP8 baseline, DCP8 AG+RS, DCP8 A2A) x (FlashInfer, FA3) with accuracy (GSM8K) + throughput (bench_serving) at cc1-512
- plot_dcp_bench.py: parse bench_serving outputs, generate comparison charts and markdown tables for throughput/TTFT/TPOT/ITL metrics

Made-with: Cursor
The DCP cherry-pick accidentally duplicated all mixin methods inside deepseek_v2.py (3455 vs 2264 lines). Reset deepseek_v2.py to main and move DCP-specific code to the proper mixin files:
- forward_mla.py: Q AllGather (decode), KV prefix gather (extend), LSE-weighted combine (AG+RS and A2A) in forward_absorb_prepare/core
- forward_mha.py: DCP KV index filtering in _get_mla_kv_buffer, DCP prefix KV AllGather in forward_normal_prepare and _chunked_prefix_attn_mha
- deepseek_v2.py: only DCP imports, attn_mqa_for_dcp_decode init, and _all_gather_dcp_kv_cache method
- utils.py: add_forward_absorb_core_attention_backend helper

Made-with: Cursor
Allocate A2A send/recv buffers inside SymmetricMemoryContext so they are graph-capturable with PyNccl. Benchmark shows 3.0-3.6x speedup for A2A collectives with symmetric memory vs torch eager. Remove --disable-cuda-graph from A2A test configs since A2A now works with CUDA graph capture. Made-with: Cursor
Use explicit `shift 1` to make intent clear -- shift only the first arg (NGPU), leaving remaining args as OPS. The `2>/dev/null` is stderr redirection, not a shift count. Address review feedback from PR sgl-project#21637. Made-with: Cursor
- Fix CPU-device torch.arange in _all_gather_dcp_kv_cache causing CPU-GPU sync on every call (add device=kv_a.device)
- Add assertion H % N == 0 in dcp_a2a_lse_reduce to guard against silent wrong results with indivisible head counts
- Clarify return statement in dcp_lse_combine_triton with explicit parentheses for operator precedence
- Remove wasteful tl.load in Triton kernel store -- use out_ptr.dtype instead of loading from recv_output_ptr just to get the dtype

Made-with: Cursor
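The first fix in miniature (a simplified sketch; kv_a here is a stand-in tensor):

```python
import torch

kv_a = torch.empty(16, 128, device="cuda")  # stand-in for the gathered KV tensor

# Before (shape of the bug, simplified): index tensor born on CPU, so every
# use against the CUDA tensor forces a host-device transfer.
idx = torch.arange(kv_a.shape[0])

# After: allocate the indices directly on the KV tensor's device.
idx = torch.arange(kv_a.shape[0], device=kv_a.device)
```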
Replace local_seqlens.max().item() (GPU-to-CPU sync every decode step) with a derivation from metadata.max_seq_len_k which is already on CPU. The upper bound ceil((max_seq_len_k - dcp_rank) / N) is equivalent and avoids the synchronization overhead on the hot decode path. Made-with: Cursor
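Why the bound holds, in a small sketch (names illustrative):

```python
import math

def local_shard_len(seq_len: int, dcp_rank: int, dcp_world_size: int) -> int:
    # Rank r holds positions r, r + N, r + 2N, ... below seq_len,
    # which is ceil((seq_len - r) / N) tokens (0 when seq_len <= r).
    return max(0, math.ceil((seq_len - dcp_rank) / dcp_world_size))

# local_shard_len is nondecreasing in seq_len, so its maximum over a batch
# is attained at max_seq_len_k, which is already on CPU -- no .item() sync:
#   local_max = local_shard_len(metadata.max_seq_len_k, dcp_rank, N)
```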
Collaborator
Smaller PRs:
Motivation
This PR extends the DCP (Decode Context Parallelism) feature from #14194 by @staugust with two major additions: an All-to-All (A2A) communication backend and FlashAttention-3 (FA3) attention backend support.
These extensions make DCP production-ready with multiple communication strategies and attention backends, giving users flexibility to choose the best combination for their hardware.
Background: What is DCP?
DCP splits the KV cache across ranks within a TP group during decode. Each rank stores only 1/dcp_size of the KV cache tokens (interleaved by position), computes partial attention over its local shard, then combines results using LSE-weighted merging to produce the correct full-attention output.

This allows serving much longer contexts (e.g., 256K-1M tokens) on the same hardware by distributing KV cache memory across GPUs, at the cost of additional communication during decode.
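The merge rule in a minimal PyTorch sketch (shapes are illustrative; the PR's Triton kernels and the _lse_weighted_combine_cpu reference implement the same idea):

```python
import torch

def lse_weighted_combine(outs: torch.Tensor, lses: torch.Tensor) -> torch.Tensor:
    """Combine per-rank partial attention outputs into the full-attention result.

    outs: [N, B, H, D] partial attention outputs from N DCP ranks
    lses: [N, B, H]    log-sum-exp of each rank's local softmax (base e)
    """
    # Global LSE over all ranks, then per-rank softmax weights.
    lse_total = torch.logsumexp(lses, dim=0)            # [B, H]
    weights = torch.exp(lses - lse_total.unsqueeze(0))  # [N, B, H]
    # The weighted sum of partials reproduces full softmax attention, since
    # each weight is that rank's share of the global softmax denominator.
    return (weights.unsqueeze(-1) * outs).sum(dim=0)    # [B, H, D]
```

For two ranks this reduces to a softmax-weighted average of the two partial outputs.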
The base DCP implementation in #14194 supports FlashInfer with AllGather+ReduceScatter (AG+RS) communication. This PR adds:
- An All-to-All (A2A) communication backend (--dcp-comm-backend a2a) that halves collective calls per MLA layer
- FlashAttention-3 (FA3) backend support for DCP decode and extend paths
Modifications
1. A2A Communication Backend (dcp_a2a.py)

New file: python/sglang/srt/layers/attention/dcp_a2a.py

Instead of AllGather(Q) -> Attention -> AllGather(output) -> LSE-correct -> ReduceScatter(output), the A2A backend:
- Packs each rank's partial attention output and LSE into one fused buffer (shape [N, B, H_per_rank, D + lse_pack_dim])
- Performs a single all_to_all_single to exchange head partials between ranks

This halves NCCL collective calls per MLA layer (from 2 to 1), reducing communication overhead.
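Roughly, the exchange looks like this (a simplified sketch with illustrative names; the real dcp_a2a_lse_reduce fuses the combine step and manages CUDA graph buffers):

```python
import torch
import torch.distributed as dist

def a2a_partial_exchange(out: torch.Tensor, lse: torch.Tensor) -> torch.Tensor:
    """Sketch: one A2A instead of AllGather + ReduceScatter.

    out: [N, B, H_per_rank, D] partial outputs, laid out per destination rank
    lse: [N, B, H_per_rank]    matching log-sum-exp values
    """
    # Pack output and LSE into one fused buffer so a single collective
    # moves both; the last dim becomes D + 1 here (the PR uses lse_pack_dim,
    # and the cast to out.dtype is a simplification).
    send = torch.cat([out, lse.unsqueeze(-1).to(out.dtype)], dim=-1)
    recv = torch.empty_like(send)
    # One all_to_all_single per MLA layer replaces the two collectives of
    # the AG+RS path; received partials are then LSE-combined locally.
    dist.all_to_all_single(recv, send)
    return recv
```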
Key components:
- dcp_lse_combine_triton: Triton kernel for LSE-weighted output combination (supports both base-e FA3 and base-2 FlashInfer LSE conventions)
- dcp_a2a_lse_reduce: fused A2A exchange + local combine with CUDA graph buffer support
- _lse_weighted_combine_cpu: CPU reference implementation for testing
- PyNcclCommunicator.all_to_all_single: NCCL-based A2A using ncclGroupStart/End for graph-capturability

2. FA3 Backend Support (flashattention_backend.py)

Extended the FlashAttention backend to handle DCP decode and extend paths:
- FlashAttentionMetadata gains dcp_page_table and dcp_cache_seqlens for the local KV shard
- _init_dcp_decode_metadata builds the interleaved page table with vectorized ops
- Decode: AllGather Q, substitute the local page table, normalize LSE for the combine
- CUDA graph: pre-allocated DCP buffers for capture and replay
3. Server Args and Configuration
- --dcp-size N: DCP world size (replaces the SGLANG_DCP env var)
- --dcp-comm-backend {ag_rs, a2a}: communication backend choice
- Validation: when dcp_size > 1, tp_size must be divisible by dcp_size
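The constraint stated as code (a hypothetical helper mirroring the ServerArgs validation described above):

```python
# Sketch of the validation rule; function name is illustrative.
def validate_dcp(tp_size: int, dcp_size: int) -> None:
    if dcp_size > 1 and tp_size % dcp_size != 0:
        raise ValueError(
            f"tp_size={tp_size} must be divisible by dcp_size={dcp_size}"
        )
```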
4. Symmetric Memory Support

- Extended pynccl_allocator.py to support multiple symmetric memory groups (TP + DCP)
- Added SGLANG_DCP_SYMM_ONLY env var to enable symmetric memory exclusively for the DCP group

5. CI-Registered Tests
Added 7 test files under test/registered/ for CI auto-discovery:

| Test file | CI stage |
| --- | --- |
| dcp/test_dcp_accuracy.py | stage-c-test-8-gpu-h200 |
| kernels/test_dcp_lse_combine.py | stage-b-test-1-gpu-large |
| kernels/test_dcp_interleaved.py | stage-b-test-1-gpu-small |
| kernels/test_dcp_fa3_standalone.py | stage-b-test-1-gpu-large |
| unit/server_args/test_dcp_config.py | stage-a-test-cpu |
| unit/layers/test_dcp_cascade_guard.py | stage-a-test-cpu |
| unit/layers/test_dcp_need_lse.py | stage-a-test-cpu |

6. Symmetric Memory Benchmark

benchmark/kernels/all_reduce/benchmark_symm_mem.py: benchmarks AllGather, ReduceScatter, and All-to-All collectives, comparing torch eager vs PyNccl symmetric-memory CUDA graph.

Benchmarking and Profiling
Serving Performance: DCP vs TP8 Baseline
Benchmarked with DeepSeek-V2 on 8x H100, using bench_serving with a random dataset (input ~4000 tokens, output ~1500 tokens) across concurrency levels 1-512.

[Benchmarking comparison chart]
Key finding: DCP enables 2.2x higher throughput at high concurrency by distributing KV cache across GPUs, allowing the system to serve more concurrent requests before running out of memory.
Output Token Throughput (tok/s)
Mean TTFT (ms) -- Time to First Token
Per-Token Latency (TPOT/ITL)
At low concurrency (cc1-cc32), TP8 has ~15% lower per-token latency since DCP adds communication overhead per decode step. However, TP8 cannot sustain higher concurrency at all -- it queues requests instead, making the latency comparison moot above cc64.
Accuracy Tests
GSM8K few-shot accuracy with DeepSeek-V2 on 8x H100, TP8, DCP8 (200 questions):
All DCP configurations match baseline TP8 accuracy within noise margin.
Symmetric Memory Benchmark (H100 8-GPU)
Symmetric memory CUDA graph speedup: AG 2.6-5.4x, RS 4.0-6.2x, A2A 3.0-3.6x.
Usage
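A representative launch command (the model path and sizes are placeholders; --dcp-size and --dcp-comm-backend are the flags added by this PR):

```
python -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V2 \
  --tp-size 8 \
  --dcp-size 8 \
  --dcp-comm-backend a2a \
  --attention-backend fa3
```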
Acknowledgment
This work builds on the DCP implementation by @staugust in #14194, which introduced the core DCP infrastructure: distributed group setup, interleaved KV cache storage, FlashInfer MLA backend integration, and the LSE correction kernel. Our contributions extend this foundation with A2A communication, FA3 backend support, symmetric memory optimization, proper CLI args, and comprehensive CI-registered tests.
Checklist