
Fix DP-Attention reduce_scatterv missing guard in MiniMax/Bailing MoE#23431

Closed
Kangyan-Zhou wants to merge 2 commits into sgl-project:main from Kangyan-Zhou:fix/dp-reduce-scatterv-guard-missing-models

Conversation

@Kangyan-Zhou
Collaborator

Summary

  • PR #22642 (Replace all-reduce + dp_scatter with reduce_scatterv for DP attention) switched the DP-Attention + EP post-MoE combine from all-reduce + dp_scatter to a single reduce_scatterv, and added the required `and not should_use_dp_reduce_scatterv()` guard to the per-model MoE all-reduce, but only in qwen2_moe.py.
  • Other MoE models (minimax_m2.py, bailing_moe.py, bailing_moe_linear.py) still call tensor_model_parallel_all_reduce on the MoE output when the new reduce_scatterv path is selected in communicator.py, causing a double reduction. Hidden states are corrupted and propagate NaN/OOB values that either produce garbage outputs or trigger CUDA illegal memory accesses in downstream kernels.
  • Mirror the qwen2_moe.py fix in the three affected models (import + one condition); see the sketch below.
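
For reference, a minimal sketch of the guard pattern being mirrored. The import paths and the surrounding function are illustrative assumptions, not the exact diff:

```python
# Illustrative sketch of the per-model guard mirrored from qwen2_moe.py.
# Import paths are assumptions; the actual modules may differ.
from sglang.srt.distributed import tensor_model_parallel_all_reduce
from sglang.srt.layers.dp_attention import should_use_dp_reduce_scatterv


def combine_moe_output(final_hidden_states, tp_size: int):
    # Skip the model-side all-reduce whenever communicator.py will combine the
    # MoE output via reduce_scatterv, so the tensor is reduced exactly once.
    if tp_size > 1 and not should_use_dp_reduce_scatterv():
        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
    return final_hidden_states
```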

Impact

Surfaced in the nightly suite via the test/registered/8-gpu-models/test_minimax_m25.py variant TP8+DP8+EP8+DPAttn.

Verification

Reproduced locally on H200 4×GPU (scaled-down TP=4 DP=4 EP=4, all should_use_dp_reduce_scatterv() conditions still satisfied):

|        | GSM8K (100 ex) | Latency | Notes |
| ------ | -------------- | ------- | ----- |
| Before | 0.060          | 470 s   | scheduler crashed mid-run with CUDA illegal memory access in fused MoE kernel |
| After  | 0.980          | 84 s    | clean, matches non-DPAttn baseline 0.956 |

Test plan

  • The nightly `nightly-test-general-8-gpu-{h200,b200}` workflows pass the TP8+DP8+EP8+DPAttn variant of test_minimax_m25.py.
  • Bailing MoE / Bailing MoE Linear behavior is unchanged for non-DPAttn configs (the guard is a strictly narrower condition).

🤖 Generated with Claude Code

PR sgl-project#22642 switched the DP-Attention + EP post-MoE combine from
all-reduce + dp_scatter to a single reduce_scatterv. It added the
required `and not should_use_dp_reduce_scatterv()` guard to the
per-model MoE all-reduce only in qwen2_moe.py.

Other MoE models (minimax_m2.py, bailing_moe.py, bailing_moe_linear.py)
still call tensor_model_parallel_all_reduce on the MoE output even when
the new reduce_scatterv path is selected in communicator.py, causing a
double reduction. Hidden states are corrupted and propagate NaN/OOB
values that either produce garbage outputs (GSM8K ~0.006-0.06 vs
baseline 0.80) or trigger CUDA illegal memory accesses in downstream
kernels.

Reproduced on MiniMaxAI/MiniMax-M2.5 with TP=4 DP=4 EP=4 + DP-Attn:
  before: score 0.060, scheduler crash (illegal memory access)
  after:  score 0.980, clean run

Add the same guard to the three affected models.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

Kangyan-Zhou marked this pull request as ready for review April 22, 2026 03:32
@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

Kangyan-Zhou requested a review from Fridge003 April 22, 2026 03:32
Follow-up to the previous commit. Every other MoE model that calls
tensor_model_parallel_all_reduce on the MoE output conditioned on
`not use_reduce_scatter` has the same double-reduction bug when
DP-Attention + EP selects the reduce_scatterv combine path:
LayerCommunicator.should_use_reduce_scatter() only returns True for the
legacy max-len reduce-scatter case, so `use_reduce_scatter` stays False
and the model-level all_reduce fires after communicator.py has already
reduced via reduce_scatterv.

Add `and not should_use_dp_reduce_scatterv()` to each site:
  - deepseek_v2.py       (forward_absorb_fused_mla_rope_appendix_v2,
                          forward_normal)
  - glm4_moe.py          (forward_absorb_fused_mla_rope_appendix_v2,
                          forward_normal)
  - llama4.py            (Llama4MoE.forward)
  - sdar_moe.py          (forward_normal)
  - sarvam_moe.py        (forward_absorb_fused_mla_rope_appendix_v2,
                          forward_normal, and the dense decoder-layer
                          all_reduce at attn_tp_size > 1)
  - llada2.py            (forward_normal)
  - mimo_v2_flash.py     (forward_normal)
  - exaone_moe.py        (ExaoneMoE.forward)
  - step3p5.py           (Step3p5MoE.forward_normal and the decoder
                          layer's combined moe + share_expert all_reduce)

qwen3_moe.py is intentionally left alone: it calls
moe_tensor_model_parallel_all_reduce (moe-tp group), not the full tp
group used by reduce_scatterv. In the ep_size == attn_dp_size
configuration that activates reduce_scatterv, moe_tp_size is typically
1 and the call is a no-op.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
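
To make the interplay described in the commit message above concrete, here is a hedged sketch of the two reduction sites. Function names and structure are illustrative, not the actual communicator.py layout:

```python
# Illustrative only: why the MoE output gets reduced twice without the guard.
# Assumed import paths, as in the earlier sketch; reduce_scatterv_combine is a
# hypothetical stand-in for the communicator's reduce_scatterv combine step.
from sglang.srt.distributed import tensor_model_parallel_all_reduce
from sglang.srt.layers.dp_attention import should_use_dp_reduce_scatterv


def model_side_combine(hidden_states, tp_size: int, use_reduce_scatter: bool):
    # use_reduce_scatter comes from LayerCommunicator.should_use_reduce_scatter(),
    # which only covers the legacy max-len reduce-scatter path, so it stays False
    # when the new reduce_scatterv path is selected.
    if tp_size > 1 and not use_reduce_scatter:
        hidden_states = tensor_model_parallel_all_reduce(hidden_states)  # first reduction
    return hidden_states


def communicator_side_combine(hidden_states, reduce_scatterv_combine):
    # communicator.py then reduces again when the fast path is active, so without
    # the new model-side guard the same tensor is reduced twice.
    if should_use_dp_reduce_scatterv():
        hidden_states = reduce_scatterv_combine(hidden_states)  # second reduction
    return hidden_states
```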
ByronHsu pushed a commit to ByronHsu/sglang that referenced this pull request Apr 25, 2026
Follow-up to sgl-project#23731 (Qwen3 MoE) — PR sgl-project#22642 introduced
should_use_dp_reduce_scatterv() to fuse the post-MoE all-reduce with
dp_scatter into a single reduce_scatterv inside LayerCommunicator, but
only patched qwen2_moe.py to skip the model-side
tensor_model_parallel_all_reduce when the fast path is active. Every
other MoE model that does the same post-experts all-reduce double-
reduces under DP attention + EP, exactly as Qwen3 did. Reported in
sgl-project#23431 with a real GSM8K nightly: 0.951 pre-sgl-project#22642 → 0.002–0.010 post →
0.980 with the guard.

Mirror the guard onto the affected MoE models:

- bailing_moe.py
- bailing_moe_linear.py
- deepseek_v2.py (forward_normal + dual-stream variant; forward_cpu
  intentionally untouched since CPU path doesn't trigger the fast path)
- exaone_moe.py
- glm4_moe.py (both forward_normal and dual-stream)
- hunyuan_v3.py (uses moe_expert_parallel_all_reduce +
  moe_tensor_model_parallel_all_reduce like qwen3_moe; both branches
  must be skipped when the fast path is active)
- llada2.py
- llama4.py
- mimo_v2_flash.py
- minimax_m2.py
- sarvam_moe.py (forward_normal + dual-stream)
- sdar_moe.py
- step3p5.py

Each file gains the same one-line `and not should_use_dp_reduce_scatterv()`
guard alongside the existing `should_use_flashinfer_cutlass_moe_fp4_allgather`
guard (or its equivalent), matching the pattern used in qwen2_moe.py and
qwen3_moe.py. Supersedes sgl-project#23431 (same diff for the 12 files there) and
adds hunyuan_v3.py.

Refs sgl-project#23729 sgl-project#23731 sgl-project#23431
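
For concreteness, the call-site shape that commit message describes looks roughly like the following. This is a sketch under stated assumptions: imports are as in the earlier sketch, guard arguments (if any) are omitted, and each model's actual condition differs in detail:

```python
# Illustrative call-site shape inside a model's MoE forward; not the exact diff.
def _maybe_all_reduce_moe_output(self, final_hidden_states):
    if (
        self.tp_size > 1
        and not should_use_flashinfer_cutlass_moe_fp4_allgather()  # existing guard
        and not should_use_dp_reduce_scatterv()  # guard added by this change
    ):
        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
    return final_hidden_states
```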
@ByronHsu
Collaborator

Thanks for catching this — the same regression also bites Qwen3 MoE (#23729) and hunyuan_v3.py, which I came across independently while bisecting an RL ESS regression on Qwen3-30B-A3B.

I've consolidated this PR's diff (verbatim, all 12 files) plus hunyuan_v3.py into #23732, so it's a strict superset of this one. Qwen3 MoE is split out separately into #23731 since it has a smaller standalone repro. Happy to close mine if you'd rather keep #23431 as the canonical fix and add hunyuan_v3.py here — whichever the maintainers prefer.

cc @YAMY1234

@Kangyan-Zhou
Collaborator Author

> Thanks for catching this — the same regression also bites Qwen3 MoE (#23729) and hunyuan_v3.py, which I came across independently while bisecting an RL ESS regression on Qwen3-30B-A3B.
>
> I've consolidated this PR's diff (verbatim, all 12 files) plus hunyuan_v3.py into #23732, so it's a strict superset of this one. Qwen3 MoE is split out separately into #23731 since it has a smaller standalone repro. Happy to close mine if you'd rather keep #23431 as the canonical fix and add hunyuan_v3.py here — whichever the maintainers prefer.
>
> cc @YAMY1234

cool i'll close mine
