Describe the bug
When running a bf16 MoE model (Qwen3-30B-A3B-Instruct-2507) on 2 GPUs with the combination
--tp-size 2 --dp-size 2 --ep-size 2 --enable-dp-attention
(no --moe-a2a-backend, so the standard fused-MoE path is used), the per-token output logprobs differ dramatically from a TP-only (--tp-size 2) baseline on the same model, same prompt, temperature=0. The logprob error exceeds 2 nats by the second sampled token, and the two configurations produce completely different generations.
tp=2 + dp=2 + dp_attention (no EP) and tp=2 + ep=2 (no DP-attention) each match the tp=2-only baseline closely. Only the simultaneous combination of DP-attention + EP>1 (with moe_a2a_backend="none") produces the wrong logits.
With the same combination plus --moe-a2a-backend deepep, the run instead crashes at cuda-graph capture with AssertionError: forward_deepgemm_masked is deprecated from sglang/srt/layers/moe/ep_moe/layer.py:248, because that DeepEP-LL kernel only supports fp8/w4afp8 quantized weights. There is therefore no working path for bf16 + DP-attention + EP>1 today.
Reproduction
Four servers, same model, same prompt:
A — tp2_only (TP only)
python -m sglang.launch_server \
--model-path Qwen/Qwen3-30B-A3B-Instruct-2507 \
--tp-size 2 --trust-remote-code \
--mem-fraction-static 0.8 --watchdog-timeout 900.0 \
--host 0.0.0.0 --port 30000
B — dp2_only (TP + DP-attention, no EP)
python -m sglang.launch_server \
--model-path Qwen/Qwen3-30B-A3B-Instruct-2507 \
--tp-size 2 --dp-size 2 \
--trust-remote-code --enable-dp-attention \
--mem-fraction-static 0.8 --chunked-prefill-size 16384 \
--cuda-graph-max-bs 128 --watchdog-timeout 900.0 \
--host 0.0.0.0 --port 30000
C — ep2_only (TP + EP, no DP-attention)
python -m sglang.launch_server \
--model-path Qwen/Qwen3-30B-A3B-Instruct-2507 \
--tp-size 2 --ep-size 2 \
--trust-remote-code \
--mem-fraction-static 0.8 --watchdog-timeout 900.0 \
--host 0.0.0.0 --port 30000
D — agg_ep2_2 (TP + DP-attention + EP — the broken combination)
python -m sglang.launch_server \
--model-path Qwen/Qwen3-30B-A3B-Instruct-2507 \
--tp-size 2 --dp-size 2 --ep-size 2 \
--trust-remote-code --enable-dp-attention \
--mem-fraction-static 0.8 --chunked-prefill-size 16384 \
--cuda-graph-max-bs 128 --watchdog-timeout 900.0 \
--host 0.0.0.0 --port 30000
Same prompt, deterministic decode, request logprobs:
curl -sS -X POST http://127.0.0.1:30000/generate \
-H 'Content-Type: application/json' \
-d '{"text":"Write a science fiction for me please.","sampling_params":{"temperature":0,"max_new_tokens":100,"top_p":1.0},"return_logprob":true,"logprob_start_len":0}'
Observed: 4-recipe diff (sglang 0.5.11.dev704+g0c826374a)
Greedy-token agreement vs A=tp2_only baseline, plus mean / max per-token logprob delta:
| Recipe | Common prefix vs A | Mean \|Δlp\| (nats) | Max \|Δlp\| (nats) | Result |
|--------|--------------------|---------------------|---------------------|--------|
| A tp2_only | 100/100 | 0.0000 | 0.0000 | baseline |
| B dp2_only | 13/100 | 0.0195 | 0.0553 | matches A (numerics drift only) |
| C ep2_only | 30/100 | 0.0250 | 0.0838 | matches A (numerics drift only) |
| D agg_ep2_2 | 2/100 | 1.2910 | 2.0324 | broken |
A/B/C agree to within ~0.1 nats; their argmax divergences after long matching prefixes are normal float-precision drift. D shows logprob errors >2 nats by the second token.
First 8 tokens, side-by-side (token_id logprob):
idx | tp2_only | dp2_only | ep2_only | agg_ep2_2
-----+----------------+----------------+----------------+----------------
0 | 576 -1.7573 | 576 -1.7191 | 576 -1.6951 | 576 -2.3070
1 | 3364 -0.4295 | 3364 -0.4637 | 3364 -0.4584 | 3364 -2.4619
2 | 1265 -0.4827 | 1265 -0.5017 | 1265 -0.5247 | 374 -1.1477
3 | 387 -0.3660 | 387 -0.3912 | 387 -0.3691 | 738 -2.2023
4 | 911 -1.1981 | 911 -1.2022 | 911 -1.2135 | 304 -0.6975
5 | 264 -0.3998 | 264 -0.3982 | 264 -0.3897 | 279 -1.3938
6 | 883 -2.3949 | 883 -2.4022 | 883 -2.4065 | 1042 -1.3395
7 | 6941 -0.7540 | 6941 -0.7547 | 6941 -0.7464 | 220 -0.0274
Even on the two tokens where D picks the same greedy argmax as A/B/C (tokens 0 and 1), D's logprobs differ by 0.55 and 2.03 nats (probability ratios of ~1.7× and ~7.6×). From token 2 onward the argmax diverges entirely.
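Those ratios are just the exponentials of the logprob deltas from the table above:

```python
import math
print(math.exp(2.3070 - 1.7573))  # token 0: ~1.73x
print(math.exp(2.4619 - 0.4295))  # token 1: ~7.63x
```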
Generated text (first 120 chars):
A_tp2_only: ' The story should be about a man named David who is a scientist working on a time travel project. He is trying to go bac'
B_dp2_only: ' The story should be about a man named David who is a scientist who discovers a way to travel through time. He is workin'
C_ep2_only: ' The story should be about a man named David who is a scientist working on a time travel project. He is trying to go bac'
D_agg_ep2_2: ' The story is set in the year 2098. The story is about a man named David, a 38-year-old man who is a former deep-sea exp'
A/B/C produce essentially the same story; D produces a qualitatively different one.
Hypothesis
In LayerCommunicator, when moe_a2a_backend == "none" and the layer is sparse, _compute_mlp_mode returns ScatterMode.FULL, so an all-gather is supposed to bring the DP-sliced tokens up to the full per-rank set before the MoE; qwen3_moe.forward_normal then does an EP all-reduce after self.experts(...). With attn_tp_size=1 (forced by tp/dp = 2/2) and ep_size=2, the gather/scatter pair around that EP all-reduce appears to misalign the per-rank token sets, so the all-reduce sums partial expert outputs that do not correspond to the same logical tokens. That would explain why each individual axis (TP, DP-attn, EP) works correctly in isolation and only the combination breaks.
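If that is what is happening, the arithmetic of the failure is easy to see in a toy, single-process sketch (pure numpy, no real collectives; partial here is a stand-in for each EP rank's local experts, not sglang code):

```python
# Toy illustration of the suspected failure mode: two EP ranks each compute
# a partial expert output for every token; an all-reduce must sum the two
# partials for the *same* logical token at every position.
import numpy as np

tokens = np.arange(4)                      # logical token ids 0..3
def partial(rank, toks):                   # stand-in for rank's local experts
    return (rank + 1) * 10.0 + toks        # distinct, token-dependent values

# Correct: both ranks gather the DP-sliced tokens in the same global order.
order = np.array([0, 1, 2, 3])
good = partial(0, tokens[order]) + partial(1, tokens[order])

# Suspected bug: the gather/scatter pair leaves rank 1 with a different
# token order, so the all-reduce sums partials of *different* tokens.
order_r1 = np.array([2, 3, 0, 1])
bad = partial(0, tokens[order]) + partial(1, tokens[order_r1])

print(good)  # [30. 32. 34. 36.] -- each entry is f0(t) + f1(t)
print(bad)   # [32. 34. 32. 34.] -- position 0 now mixes token 0 and token 2
```

Every entry of the misaligned sum is still a plausible-looking activation, which matches the observed symptom: coherent but wrong logits rather than NaNs or garbage.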
Environment
Python: 3.12.3
CUDA available: True
GPU 0–7: NVIDIA H200, Compute Capability: 9.0
NVCC: 12.9, V12.9.86
CUDA Driver Version: 580.126.09
PyTorch: 2.9.1+cu130
sglang: 0.5.11.dev704+g0c826374a (latest main as of 2026-04-25)
sglang-kernel: 0.4.1+cu130
flashinfer_python: 0.6.8.post1
flashinfer_cubin: 0.6.8.post1
triton: 3.5.1
transformers: 5.5.4
torchao: 0.17.0
Update: regression bisect
This is a regression. v0.5.9 is clean; current main is broken.
| sglang | A vs D common prefix | Mean \|Δlp\| (nats) | Max \|Δlp\| (nats) | A and D text |
|--------|----------------------|---------------------|---------------------|--------------|
| v0.5.9 (bbe9c7eeb) | 76/100 | 0.026 | 0.147 | identical |
| 0.5.11.dev704+g0c826374a | 2/100 | 1.29 | 2.03 | diverge |
On v0.5.9 the per-token logprob delta between A (tp2_only) and D (agg_ep2_2) is within normal float drift, and both produce the exact same first-120-char generation. On current main the delta blows up to >2 nats by token 1 and the generations are qualitatively different.
So the bug was introduced somewhere between v0.5.9 and the current top of main. Likely candidates: changes to LayerCommunicator, qwen3_moe.forward_normal, EPMoE.forward_impl, or parallel_state group computation involving attn_tp_size / moe_ep_size / moe_tp_size.
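A bisect between those two points should pinpoint the commit. A hypothetical driver for git bisect run could compare saved /generate responses from A and D and flag a commit as bad once the max delta leaves the normal drift band (the file layout and threshold below are assumptions, not part of the repro above):

```python
# bisect_check.py (hypothetical): given two saved /generate response JSON
# files (A = tp2_only, D = agg_ep2_2), exit nonzero when the max per-token
# logprob delta exceeds the drift threshold, so `git bisect run` can
# classify the commit as bad.
import json, sys

THRESH = 0.2  # well above the ~0.08 nats of float drift seen for B/C vs A

a, d = (json.load(open(p))["meta_info"]["output_token_logprobs"]
        for p in sys.argv[1:3])
deltas = [abs(x[0] - y[0]) for x, y in zip(a, d)]
print(f"mean |dlp| = {sum(deltas)/len(deltas):.4f}, max = {max(deltas):.4f}")
sys.exit(1 if max(deltas) > THRESH else 0)
```

With servers A and D relaunched at each step, git bisect good v0.5.9 / git bisect bad on main plus this check should land on the offending commit in a handful of steps.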