
Apply should_use_dp_reduce_scatterv guard to remaining MoE models (follow-up to #23731) #23732

Merged

ByronHsu merged 2 commits into sgl-project:main from ByronHsu:fix/moe-dp-reduce-scatterv-guard-rest on Apr 26, 2026

Conversation

ByronHsu (Collaborator) commented Apr 25, 2026

Motivation

Follow-up to #23731 (Qwen3 MoE). Supersedes #23431 (same diff for the 12 files there) and additionally fixes hunyuan_v3.py.

PR #22642 introduced should_use_dp_reduce_scatterv(), which fuses the post-MoE all-reduce with dp_scatter into a single reduce_scatterv call inside LayerCommunicator. To avoid a double-reduce, the model-side tensor_model_parallel_all_reduce (or moe_*_all_reduce) on final_hidden_states must be skipped when this fast path is active.

That PR added the guard only to qwen2_moe.py. #23731 fixed qwen3_moe.py. Every other MoE model that does the same post-experts all-reduce silently double-reduces when running with DP attention + EP + moe_a2a_backend="none" — same regression pattern as #23729.
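
The fusion rests on a standard collective identity: an all-reduce followed by each DP rank keeping its own slice (what dp_scatter does) is equivalent to a single reduce-scatter, so running both the model-side all-reduce and the fused collective sums the expert outputs twice. A minimal torch.distributed sketch of the identity (conceptual only; SGLang's actual helper is a variable-size "v" variant, and none of these names come from LayerCommunicator itself):

```python
# Conceptual sketch: why all_reduce + per-rank slicing == reduce_scatter,
# and why keeping the model-side all_reduce double-counts. Assumes an
# initialized process group and x.shape[0] divisible by the world size.
import torch
import torch.distributed as dist

def fused_path(x: torch.Tensor) -> torch.Tensor:
    # One collective: every rank receives its already-summed slice.
    out = x.new_empty(x.shape[0] // dist.get_world_size(), *x.shape[1:])
    dist.reduce_scatter_tensor(out, x)
    return out

def unfused_path(x: torch.Tensor) -> torch.Tensor:
    # Two steps: sum everywhere, then keep only this rank's slice.
    dist.all_reduce(x)
    rank, world = dist.get_rank(), dist.get_world_size()
    chunk = x.shape[0] // world
    return x[rank * chunk : (rank + 1) * chunk]

# The regression: a model that still calls all_reduce before the fused
# path effectively stacks a second reduction on top, so expert outputs
# are summed twice (hence the GSM8K collapse cited in the next paragraph).
```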

#23431 surfaced this in the nightly suite via the test_minimax_m25.py TP8+DP8+EP8+DPAttn variant: GSM8K fell from 0.951 before #22642 to 0.002–0.010 after, and recovered to 0.980 with the guard.

Modifications

13 files (17 reduce sites): the 12 files from #23431 plus hunyuan_v3.py.

| File | Reduce sites | Notes |
| --- | --- | --- |
| bailing_moe.py | 1 | standard pattern |
| bailing_moe_linear.py | 1 | standard pattern |
| deepseek_v2.py | 2 | forward_normal + dual-stream; forward_cpu intentionally untouched (CPU path doesn't trigger the fast path) |
| exaone_moe.py | 1 | standard pattern |
| glm4_moe.py | 2 | forward_normal + dual-stream |
| hunyuan_v3.py | 2 | uses moe_expert_parallel_all_reduce + moe_tensor_model_parallel_all_reduce (same shape as qwen3_moe.py); both branches gated |
| llada2.py | 1 | standard pattern |
| llama4.py | 1 | standard pattern |
| mimo_v2_flash.py | 1 | standard pattern |
| minimax_m2.py | 1 | standard pattern |
| sarvam_moe.py | 2 | forward_normal + dual-stream |
| sdar_moe.py | 1 | standard pattern |
| step3p5.py | 1 | standard pattern |

Each `and not should_use_flashinfer_cutlass_moe_fp4_allgather()` guard gets a sibling `and not should_use_dp_reduce_scatterv()` line, matching the pattern from qwen2_moe.py and qwen3_moe.py. hunyuan_v3.py does not have the fp4 guard, so a local `skip_post_reduce = should_use_dp_reduce_scatterv()` short-circuits both reduces. A sketch of both shapes follows.
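
Excerpt-style sketch of the two guard shapes (illustrative, not a verbatim diff; the surrounding `forward()` context is elided and the helper's exact import location is assumed):

```python
# Sketch only -- condensed from the description above, not copied from the diff.

# 1) Standard pattern (12 of the 13 files): the existing fp4-allgather
#    guard gains a sibling clause so the model-side all-reduce is skipped
#    whenever LayerCommunicator's fused reduce_scatterv path is active.
if (
    self.tp_size > 1
    and not should_use_flashinfer_cutlass_moe_fp4_allgather()
    and not should_use_dp_reduce_scatterv()  # fast path reduces for us
):
    final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

# 2) hunyuan_v3.py shape: there is no fp4 guard to extend, so one local
#    flag gates both the EP and the TP reduce branches.
skip_post_reduce = should_use_dp_reduce_scatterv()
if self.ep_size > 1 and not skip_post_reduce:
    final_hidden_states = moe_expert_parallel_all_reduce(final_hidden_states)
elif self.tp_size > 1 and not skip_post_reduce:
    final_hidden_states = moe_tensor_model_parallel_all_reduce(final_hidden_states)
```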


cc @YAMY1234 (PR #22642 author)

Refs #23729 #23731 #23431

gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces the should_use_dp_reduce_scatterv() guard across several Mixture-of-Experts (MoE) model implementations, including Bailing, DeepSeek-V2, GLM-4, Hunyuan-V3, MIMO-V2, MiniMax-M2, Sarvam, SDAR, and Step3.5. This guard is integrated into the forward pass logic to conditionally skip the final tensor model parallel or expert parallel all-reduce operations when a fused reduction is expected to be handled by an external communicator. I have no feedback to provide.

Kangyan-Zhou (Collaborator) commented Apr 25, 2026

/tag-and-rerun-ci again

Commit message:

Follow-up to sgl-project#23731 (Qwen3 MoE) — PR sgl-project#22642 introduced
should_use_dp_reduce_scatterv() to fuse the post-MoE all-reduce with
dp_scatter into a single reduce_scatterv inside LayerCommunicator, but
only patched qwen2_moe.py to skip the model-side
tensor_model_parallel_all_reduce when the fast path is active. Every
other MoE model that does the same post-experts all-reduce double-
reduces under DP attention + EP, exactly as Qwen3 did. Reported in
sgl-project#23431 with a real GSM8K nightly: 0.951 pre-sgl-project#22642 → 0.002–0.010 post →
0.980 with the guard.

Mirror the guard onto the affected MoE models:

- bailing_moe.py
- bailing_moe_linear.py
- deepseek_v2.py (forward_normal + dual-stream variant; forward_cpu
  intentionally untouched since CPU path doesn't trigger the fast path)
- exaone_moe.py
- glm4_moe.py (both forward_normal and dual-stream)
- hunyuan_v3.py (uses moe_expert_parallel_all_reduce +
  moe_tensor_model_parallel_all_reduce like qwen3_moe; both branches
  must be skipped when the fast path is active)
- llada2.py
- llama4.py
- mimo_v2_flash.py
- minimax_m2.py
- sarvam_moe.py (forward_normal + dual-stream)
- sdar_moe.py
- step3p5.py

Each file gains the same one-line `and not should_use_dp_reduce_scatterv()`
guard alongside the existing `should_use_flashinfer_cutlass_moe_fp4_allgather`
guard (or its equivalent), matching the pattern used in qwen2_moe.py and
qwen3_moe.py. Supersedes sgl-project#23431 (same diff for the 12 files there) and
adds hunyuan_v3.py.

Refs sgl-project#23729 sgl-project#23731 sgl-project#23431

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@ByronHsu ByronHsu force-pushed the fix/moe-dp-reduce-scatterv-guard-rest branch from 4a1fbbb to 9cbd0f8 Compare April 25, 2026 22:50
@ByronHsu ByronHsu merged commit ba4e9d2 into sgl-project:main Apr 26, 2026
195 of 212 checks passed
@hnyls2002 hnyls2002 mentioned this pull request Apr 29, 2026
vguduruTT pushed a commit to vguduruTT/sglang that referenced this pull request May 2, 2026
…llow-up to sgl-project#23731) (sgl-project#23732)

Co-authored-by: Byron Hsu <byronhsu@noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Kangyan-Zhou <zky314343421@gmail.com>