Summary
Implement Distributed Weight Data Parallelism (DWDP) in SGLang — a parallelism strategy that distributes MoE expert weights across GPUs within a node while keeping attention weights fully replicated. Instead of using collective synchronization (AllReduce/AllGather), DWDP uses asynchronous peer-to-peer prefetches to pull remote expert weights before they are needed, eliminating synchronization barriers from the critical path.
Inspired by NVIDIA's TensorRT-LLM DWDP implementation (blog, PR #12136), which demonstrated an 8.8% output TPS/GPU improvement at comparable TPS/user across the 20–100 TPS/user serving range (8k/1k ISL/OSL), and a 14.26% iteration latency reduction, on DeepSeek-R1 with GB200 NVL72.
Motivation
The Problem with Tensor Parallelism for MoE
Current MoE inference in SGLang uses Tensor Parallelism (TP) or Expert Parallelism (EP), both requiring collective communication (e.g., AllReduce, AllGather) at synchronization barriers. As model sizes grow and batch sizes vary, these synchronization points become the bottleneck:
- Small batches: Compute finishes quickly but all ranks wait at the barrier.
- Imbalanced expert routing: Some ranks finish early but stall waiting for stragglers.
- NVLink bandwidth underutilized: High-bandwidth peer links (NVL72: 1.8 TB/s bisection) are used only in bursts during collectives rather than continuously.
The DWDP Opportunity
DWDP replaces blocking collectives with asynchronous weight prefetches via copy engine:
- Each rank executes its expert shards locally (no waiting for other ranks).
- Missing remote expert weights are fetched asynchronously from peers before they are needed (prefetch-ahead).
- Prefetches use the copy engine (independent of SM compute), allowing overlap with GEMM computation.
- Result: ranks proceed independently, NVLink bandwidth is used continuously, and the critical path has no synchronization cost.
This is especially impactful for MoE models (DeepSeek-R1/V3, Mixtral, Qwen-MoE) on multi-GPU nodes with NVLink (e.g., DGX H100, GB200 NVL72).
Proposed Design
Parallelism Layout
```
DWDP Group (e.g., 4 GPUs)
┌──────────────────┬──────────────────┬──────────────────┬──────────────────┐
│      GPU 0       │      GPU 1       │      GPU 2       │      GPU 3       │
├──────────────────┴──────────────────┴──────────────────┴──────────────────┤
│                 Attention: fully replicated on every GPU                  │
├──────────────────┬──────────────────┬──────────────────┬──────────────────┤
│   MoE shard 0    │   MoE shard 1    │   MoE shard 2    │   MoE shard 3    │
│ experts [0,N/4)  │ experts[N/4,N/2) │ experts[N/2,3N/4)│ experts [3N/4,N) │
├──────────────────┴──────────────────┴──────────────────┴──────────────────┤
│  Each GPU runs full data-parallel forward; missing expert weights are     │
│  prefetched asynchronously from peers via P2P copy engine                 │
└───────────────────────────────────────────────────────────────────────────┘
```
- Attention layers: fully replicated (same as DP), no communication needed.
- MoE expert weights: partitioned across DWDP group GPUs (each GPU holds 1/N of experts).
- Expert execution: each token is routed to its top-K experts; if an expert lives on a remote GPU, its weights are fetched via peer copy before the GEMM.
- No collective ops in the forward critical path.
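With the contiguous layout above, the owner of any expert is a pure function of its id, so routing needs no communication. A minimal sketch (the function name and the even-divisibility assumption are mine, not from the proposal):

```python
def expert_owner(expert_id: int, num_experts: int, group_size: int) -> int:
    """Rank that holds expert_id when experts are split contiguously:
    GPU r owns experts [r * num_experts/group_size, (r+1) * num_experts/group_size).
    Assumes num_experts is evenly divisible by group_size."""
    num_local = num_experts // group_size
    return expert_id // num_local

# Example: 256 experts over a 4-GPU DWDP group (64 experts per GPU).
assert expert_owner(0, 256, 4) == 0     # experts [0, 64)   live on GPU 0
assert expert_owner(100, 256, 4) == 1   # experts [64, 128) live on GPU 1
assert expert_owner(255, 256, 4) == 3   # experts [192, 256) live on GPU 3
```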
Key Components to Implement
1. DWDPConfig
Configuration for DWDP parallelism:
```python
from dataclasses import dataclass

@dataclass
class DWDPConfig:
    group_size: int            # Number of GPUs in a DWDP group
    num_local_experts: int     # Experts per GPU = total_experts / group_size
    prefetch_depth: int = 2    # How many layers ahead to prefetch
```
2. Expert Weight Partitioning
- Partition MoE expert weight tensors across GPUs at load time.
- Register CUDA IPC handles for peer access within the DWDP group.
- Each GPU maintains a mapping: `expert_id → (peer_rank, local_offset)`.
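Under contiguous partitioning this mapping can be built once at load time; a sketch (function name hypothetical, assumes experts divide evenly across the group):

```python
from typing import Dict, Tuple

def build_expert_map(num_experts: int, group_size: int) -> Dict[int, Tuple[int, int]]:
    """Map each global expert id to (peer_rank, local_offset), where
    local_offset indexes into that peer's local expert shard."""
    num_local = num_experts // group_size
    return {e: (e // num_local, e % num_local) for e in range(num_experts)}

# Example: 8 experts over a 4-GPU group, 2 experts per GPU.
emap = build_expert_map(8, 4)
assert emap[0] == (0, 0)   # expert 0: GPU 0, slot 0
assert emap[5] == (2, 1)   # expert 5: GPU 2, slot 1
```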
3. DWDPPrefetchBuffer
Double-buffered staging buffers for remote expert weights:
- Ping-pong buffers: while layer N computes with buffer A, layer N+1's weights are fetched into buffer B.
- Uses dedicated CUDA streams for copy engine transfers (separate from compute stream).
- Stream synchronization ensures weights arrive before GEMM launch.
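The ping-pong slot discipline can be sketched as pure bookkeeping, without CUDA (class and method names are hypothetical; real code would launch `cudaMemcpyPeerAsync` on a dedicated copy stream and make the compute stream wait on a copy-done event):

```python
class PingPongBuffers:
    """Slot selection for double-buffered prefetch: layer N stages into
    slot N % 2 while the previous layer computes out of the other slot."""

    def __init__(self, num_slots: int = 2):
        self.num_slots = num_slots
        self.in_flight = {}  # slot -> layer whose weights are being fetched

    def slot_for_layer(self, layer: int) -> int:
        return layer % self.num_slots

    def begin_prefetch(self, layer: int) -> int:
        slot = self.slot_for_layer(layer)
        # Real code: enqueue peer copies for `layer` on the copy stream here.
        self.in_flight[slot] = layer
        return slot

    def wait_ready(self, layer: int) -> int:
        slot = self.slot_for_layer(layer)
        assert self.in_flight.get(slot) == layer, "prefetch was not issued"
        # Real code: compute stream waits on the copy stream's event here.
        del self.in_flight[slot]
        return slot

# Steady state: fetch layer N+1 into the other slot while layer N computes.
buf = PingPongBuffers()
assert buf.begin_prefetch(0) == 0
assert buf.begin_prefetch(1) == 1
assert buf.wait_ready(0) == 0     # layer 0 weights ready; slot 0 freed
assert buf.begin_prefetch(2) == 0 # slot 0 reused for layer 2
```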
4. MoE Forward Pass Integration
Modify the MoE forward to:
- Trigger async prefetches ahead of time for top-K expert weights that are not resident in local memory (when exactly to trigger the prefetch is a design choice).
- Synchronize prefetch stream before executing GroupGemm in MoE.
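Given `prefetch_depth`, the look-ahead set at each layer is simple to state; a sketch of that scheduling decision (function name mine, not an SGLang API):

```python
def layers_to_prefetch(current_layer: int, num_layers: int, depth: int) -> list:
    """Layers whose remote expert weights should be in flight while
    current_layer computes, with a `depth`-layer look-ahead window."""
    return list(range(current_layer + 1,
                      min(current_layer + 1 + depth, num_layers)))

# With prefetch_depth=2 in a 61-layer model:
assert layers_to_prefetch(0, 61, 2) == [1, 2]
assert layers_to_prefetch(59, 61, 2) == [60]   # window clipped at the end
assert layers_to_prefetch(60, 61, 2) == []     # last layer: nothing left
```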
5. Grouped GEMM with Split Weights
Extend the grouped GEMM kernel to accept a TensorList (list of weight tensor pointers) rather than a single contiguous weight tensor. This allows direct consumption of prefetched buffers without an extra device-to-device copy to merge them.
Note: this requires changes to the FlashInfer CuteDSL MoE kernels.
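To illustrate the interface change only (this is a naive Python stand-in, not the FlashInfer kernel): the grouped GEMM indexes a list of per-expert weight tensors, some of which may point into prefetch staging buffers, instead of slicing one contiguous allocation.

```python
def grouped_gemm(tokens_by_expert: dict, weight_list: list) -> dict:
    """Grouped GEMM where each expert's weight is a separate list entry
    (possibly a prefetched peer buffer) rather than a slice of one
    contiguous weight tensor. x is (m, k); weights are (k, n)."""
    out = {}
    for expert_id, x in tokens_by_expert.items():
        w = weight_list[expert_id]  # may point into a prefetch buffer
        out[expert_id] = [
            [sum(x[i][t] * w[t][j] for t in range(len(w)))
             for j in range(len(w[0]))]
            for i in range(len(x))
        ]
    return out

# One token (1x2) routed to expert 0 with a 2x1 weight: 1*3 + 2*4 = 11.
assert grouped_gemm({0: [[1, 2]]}, [[[3], [4]]]) == {0: [[11]]}
```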
Communication Contention Mitigation
- Use round-robin scheduling of fixed-size transfer slices across destination ranks to avoid hot-spot contention on the source copy engine.
- Expected gain: ~8% additional throughput when compute windows are short (small batch).
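The round-robin discipline amounts to interleaving slice transfers across destinations instead of draining one destination at a time; a sketch of the emission order (function name mine):

```python
def round_robin_slices(dst_ranks: list, num_slices: int) -> list:
    """Emit (dst_rank, slice_idx) transfer pairs in round-robin order
    across destinations, so no single destination monopolizes the
    source GPU's copy engine."""
    order = []
    for s in range(num_slices):
        for dst in dst_ranks:
            order.append((dst, s))
    return order

# Two destinations, two slices each: alternate destinations per slice.
assert round_robin_slices([1, 2], 2) == [(1, 0), (2, 0), (1, 1), (2, 1)]
```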
Performance Targets
Based on TRT-LLM results on GB200 NVL72 with DeepSeek-R1 (8K input, 1K output):
| Metric | Baseline (EP/TP) | With DWDP | Delta |
|---|---|---|---|
| Output TPS/GPU (TPS/user 20–100) | 1.00x | 1.088x | +8.8% |
| Iteration latency (context-only) | 1.00x | 0.857x | −14.3% |
Scope and Constraints
Initial Scope (v1)
- Target hardware: NVLink-connected multi-GPU nodes (B200, GB200, B300, GB300).
- Target models: MoE models with high expert count (DeepSeek-R1 nvfp4 first).
- Quantization: FP8 / NVFP4 (weight-only quant friendly; avoids large FP16 prefetch volumes).
- Requires `TP=1` within each DWDP group (data parallel across groups).
- CUDA IPC peer access (single-node only in v1).
Out of Scope (v1)
- Cross-node DWDP (requires RDMA/network transport instead of CUDA IPC).
- Compatibility with the overlap scheduler — TRT-LLM lists this as incompatible; needs investigation for SGLang's implementation.
- Dynamic expert load balancing (EPLB) within DWDP group.
Future Extensions
- Cross-node DWDP: replace CUDA IPC with UCX/NCCL P2P for multi-node.
- Virtual memory weight assembly: map expert shards into a contiguous virtual address range to eliminate kernel specialization for split-weight GEMMs.
- DWDP + EP hybrid: use EP across nodes, DWDP within a node.
- Integration with EPLB: allow dynamic expert migration within DWDP group.
- Compare performance vs. CPU weight offloading in SGLang.