[Bug] DeepEP + nvfp4 issues with PD disagg or agg #12293

@kaixih

Description

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
  • 5. Please use English, otherwise it will be closed.

Describe the bug

We’re working on PD disaggregation following the instructions, using the main branch (or the sglang:dev container). However, we’ve encountered several issues along the way.

We’d like to use this issue to track the remaining problems on main and clarify the best command setup for PD disagg with NVFP4, since many flags seem to have changed, been deprecated, or become redundant over the past few months.

Plan:

  • PD agg works with DeepEP (BF16 dispatch, so no new DeepEP is required); this is mainly a sanity check (fixed by [NVIDIA] Fix cutedsl backend of MoE #12353)
  • PD disagg works with DeepEP (BF16 dispatch); also a sanity check
  • PD disagg works with DeepEP (NVFP4 dispatch) - Functionality
  • PD disagg works with DeepEP (NVFP4 dispatch) - Perf analysis
  • PD disagg fp4 best practice

PD Agg

In parallel, we also tried reverting to the PD aggregation workloads we originally used for development, but we’re observing other issues such as:

[2025-10-28 21:14:27] DataParallelController hit an exception: Traceback (most recent call last):
  File "/scratch/repo/sglang/python/sglang/srt/managers/data_parallel_controller.py", line 496, in run_data_parallel_controller_process
    controller = DataParallelController(server_args, port_args)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/repo/sglang/python/sglang/srt/managers/data_parallel_controller.py", line 151, in __init__
    self.launch_dp_attention_schedulers(server_args, port_args)
  File "/scratch/repo/sglang/python/sglang/srt/managers/data_parallel_controller.py", line 344, in launch_dp_attention_schedulers
    self.launch_tensor_parallel_group(
  File "/scratch/repo/sglang/python/sglang/srt/managers/data_parallel_controller.py", line 428, in launch_tensor_parallel_group
    scheduler_info.append(scheduler_pipe_readers[i].recv())
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 250, in recv
    buf = self._recv_bytes()
          ^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 430, in _recv_bytes
    buf = self._recv(4)
          ^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 399, in _recv
    raise EOFError
EOFError
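For context, the EOFError above is what multiprocessing raises on the parent end of a Pipe when the child exits without sending anything, i.e. the scheduler subprocess most likely crashed during init before reporting ready, and the real error should be earlier in the child's log. A minimal sketch of the mechanism (plain Python, not sglang code; scheduler_stub is a hypothetical stand-in for the scheduler process):

```python
import multiprocessing as mp

def scheduler_stub(conn):
    # Simulate a scheduler crashing during init: exit without conn.send(...)
    conn.close()

if __name__ == "__main__":
    parent, child = mp.Pipe()
    p = mp.Process(target=scheduler_stub, args=(child,))
    p.start()
    p.join()
    child.close()  # drop the parent's copy of the child end
    try:
        parent.recv()  # mirrors scheduler_pipe_readers[i].recv()
    except EOFError:
        print("EOFError: scheduler exited before reporting ready")
```

So the EOFError itself is only a symptom; the actual failure is whatever killed the scheduler process before it could send its readiness info.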

Repro steps are below.

PD Disagg

We know the DeepEP low-latency NVFP4 dispatch isn’t included in the nightly build yet, so we’ve temporarily fallen back to the BF16 dispatch to test functionality. However, we’re still seeing multiple issues.

The current sbatch commands used: https://gist.github.com/kaixih/32bdc4fec4feabe9305d1acb2e1f96db

The first issue we encountered: on the 2nd prefill node, for example, I can see:

  File "/sgl-workspace/sglang/python/sglang/srt/single_batch_overlap.py", line 85, in execute_sbo
    _compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
  File "/sgl-workspace/sglang/python/sglang/srt/single_batch_overlap.py", line 115, in _compute_overlap_args
    num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: not enough values to unpack (expected 3, got 2)
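The unpack failure suggests hidden_states arrives 2-D while _compute_overlap_args expects a 3-D per-expert layout. A minimal sketch of the mismatch, assuming (unverified) that the BF16 fallback produces a [num_tokens, hidden_dim] tensor while the SBO path expects DeepEP low-latency's [num_local_experts, num_tokens_static, hidden_dim] layout; compute_overlap_args below is a toy stand-in, not the sglang function:

```python
def compute_overlap_args(shape):
    # Mirrors the failing line:
    # num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape
    num_local_experts, num_tokens_static, hidden_dim = shape
    return num_local_experts, num_tokens_static, hidden_dim

# 3-D low-latency layout unpacks fine:
print(compute_overlap_args((8, 256, 7168)))  # (8, 256, 7168)

# 2-D BF16-dispatch layout reproduces the crash:
try:
    compute_overlap_args((1024, 7168))
except ValueError as e:
    print(e)  # not enough values to unpack (expected 3, got 2)
```

If that assumption holds, the SBO path would need to either be disabled under SGLANG_DEEPEP_BF16_DISPATCH=1 or handle the 2-D layout explicitly.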

Reproduction

model_str=/model/nvidia-DeepSeek-R1-0528-FP4
SGLANG_DEEPEP_BF16_DISPATCH=1 \
SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 \
SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH=0 \
python3 -m sglang.launch_server \
  --model-path $model_str \
  --trust-remote-code \
  --disable-radix-cache \
  --max-running-requests 256 \
  --chunked-prefill-size 1024 \
  --mem-fraction-static 0.89 \
  --max-prefill-tokens 16384 \
  --tp 4 \
  --ep 4 \
  --dp 4 \
  --enable-dp-attention \
  --attention-backend trtllm_mla \
  --moe-dense-tp-size 1 \
  --quantization modelopt_fp4 \
  --moe-a2a-backend deepep \
  --deepep-mode low_latency \
  --moe-runner-backend flashinfer_cutedsl

Environment

root@7edb52683dc0:/scratch/lab/flashinfer_moe# python3 -m sglang.check_env
Python: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
CUDA available: True
GPU 0,1,2,3: NVIDIA GB200
GPU 0,1,2,3 Compute Capability: 10.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.9, V12.9.86
CUDA Driver Version: 580.82.07
PyTorch: 2.8.0+cu129
sglang: 0.5.4.post1
sgl_kernel: 0.3.16.post4
flashinfer_python: 0.4.1
triton: 3.4.0
transformers: 4.57.1
torchao: 0.9.0
numpy: 2.3.4
aiohttp: 3.13.1
fastapi: 0.120.0
hf_transfer: 0.1.9
huggingface_hub: 0.36.0
interegular: 0.3.3
modelscope: 1.31.0
orjson: 3.11.4
outlines: 0.1.11
packaging: 25.0
psutil: 7.1.2
pydantic: 2.12.3
python-multipart: 0.0.20
pyzmq: 27.1.0
uvicorn: 0.38.0
uvloop: 0.22.1
vllm: Module Not Found
xgrammar: 0.1.25
openai: 1.99.1
tiktoken: 0.12.0
anthropic: 0.71.0
litellm: Module Not Found
decord2: 2.0.0
NVIDIA Topology:
GPU0 GPU1 GPU2 GPU3 NIC0 NIC1 NIC2 NIC3 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV18 NV18 NV18 NODE NODE SYS SYS 0-71 0 N/A
GPU1 NV18 X NV18 NV18 NODE NODE SYS SYS 0-71 0 N/A
GPU2 NV18 NV18 X NV18 SYS SYS NODE NODE 72-143 1 N/A
GPU3 NV18 NV18 NV18 X SYS SYS NODE NODE 72-143 1 N/A
NIC0 NODE NODE SYS SYS X PIX SYS SYS
NIC1 NODE NODE SYS SYS PIX X SYS SYS
NIC2 SYS SYS NODE NODE SYS SYS X PIX
NIC3 SYS SYS NODE NODE SYS SYS PIX X

Legend:

X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks

NIC Legend:

NIC0: mlx5_2
NIC1: mlx5_3
NIC2: mlx5_6
NIC3: mlx5_7

ulimit soft: 1048576
