Fix Qwen3 MoE double-reduce when DP attention + EP + reduce_scatterv (#23729) #23731
Conversation
PR sgl-project#22642 introduced `should_use_dp_reduce_scatterv()` to fuse the post-MoE all-reduce with the dp_scatter into a single `reduce_scatterv` inside `LayerCommunicator`. The `qwen2_moe.py` forward path was patched to skip the explicit `tensor_model_parallel_all_reduce` when this fast path is active, but `qwen3_moe.py` was missed.

As a result, Qwen3 MoE models running with DP attention + EP=DP (e.g. `--tp 2 --dp 2 --ep 2 --enable-dp-attention`, no `--moe-a2a-backend`) double-reduce the MoE output: once explicitly via `moe_expert_parallel_all_reduce` in `forward_normal`, then again inside `reduce_scatterv` from the communicator. The output is silently corrupted; the model still produces fluent text, but logprobs differ from a TP-only baseline by 0.5–2 nats.

Repro (see sgl-project#23729): same prompt, temperature 0, two servers, tp=2 vs tp=2 dp=2 ep=2 dp_attention. Pre-fix, the two configurations diverge at the second sampled token with max |Δlogprob| = 2.03; post-fix they agree on 100/100 tokens with max |Δlogprob| = 0.28 (within float drift).

Mirror the `qwen2_moe.py` guard onto both reduce branches in `Qwen3MoeSparseMoeBlock.forward_normal`.

Fixes sgl-project#23729
Code Review
This pull request integrates the should_use_dp_reduce_scatterv check into the qwen3_moe model's forward pass to conditionally skip all-reduce operations. The review feedback suggests also incorporating the use_reduce_scatter flag check within the expert parallel block to ensure consistency with the tensor parallel logic and prevent redundant reductions.
```python
if (
    self.ep_size > 1
    and not should_allreduce_fusion
    and not should_use_dp_reduce_scatterv()
):
```
The expert parallel (EP) all-reduce block should also check the use_reduce_scatter flag, mirroring the logic used in the tensor parallel (TP) block at line 345. This ensures that redundant EP all-reduces are skipped when the communicator is configured to perform a reduce-scatter operation instead of an all-reduce.
Suggested change:

```diff
 if (
     self.ep_size > 1
     and not should_allreduce_fusion
+    and not use_reduce_scatter
     and not should_use_dp_reduce_scatterv()
 ):
```
Follow-up to sgl-project#23731 (Qwen3 MoE). PR sgl-project#22642 introduced `should_use_dp_reduce_scatterv()` to fuse the post-MoE all-reduce with dp_scatter into a single `reduce_scatterv` inside `LayerCommunicator`, but only patched `qwen2_moe.py` to skip the model-side `tensor_model_parallel_all_reduce` when the fast path is active. Every other MoE model that does the same post-experts all-reduce double-reduces under DP attention + EP, exactly as Qwen3 did. Reported in sgl-project#23431 with a real GSM8K nightly: 0.951 pre-sgl-project#22642 → 0.002–0.010 post → 0.980 with the guard.

Mirror the guard onto the affected MoE models:

- bailing_moe.py
- bailing_moe_linear.py
- deepseek_v2.py (forward_normal + dual-stream variant; forward_cpu intentionally untouched since the CPU path doesn't trigger the fast path)
- exaone_moe.py
- glm4_moe.py (both forward_normal and dual-stream)
- hunyuan_v3.py (uses moe_expert_parallel_all_reduce + moe_tensor_model_parallel_all_reduce like qwen3_moe; both branches must be skipped when the fast path is active)
- llada2.py
- llama4.py
- mimo_v2_flash.py
- minimax_m2.py
- sarvam_moe.py (forward_normal + dual-stream)
- sdar_moe.py
- step3p5.py

Each file gains the same one-line `and not should_use_dp_reduce_scatterv()` guard alongside the existing `should_use_flashinfer_cutlass_moe_fp4_allgather` guard (or its equivalent), matching the pattern used in qwen2_moe.py and qwen3_moe.py.

Supersedes sgl-project#23431 (same diff for the 12 files there) and adds hunyuan_v3.py. Refs sgl-project#23729 sgl-project#23731 sgl-project#23431
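For reference, a minimal sketch of what the shared pattern looks like after the change. This is illustrative, not the literal diff of any one file: `self.tp_size`, the surrounding assignment, and the exact set of coexisting conditions (e.g. the flashinfer cutlass allgather guard) vary per model.

```python
# Illustrative only: the TP-side reduce after adding the new guard.
# Names follow the qwen2_moe.py / qwen3_moe.py pattern; per-file conditions differ.
if (
    self.tp_size > 1
    and not should_allreduce_fusion
    and not use_reduce_scatter
    and not should_use_dp_reduce_scatterv()  # new: LayerCommunicator reduces via reduce_scatterv
):
    final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
```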
…gl-project#23729) (sgl-project#23731) Co-authored-by: Byron Hsu <byronhsu@noreply.github.com>
…follow-up to sgl-project#23731) (sgl-project#23734) Co-authored-by: Byron Hsu <byron@periodiclabs.ai> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…llow-up to sgl-project#23731) (sgl-project#23732) Co-authored-by: Byron Hsu <byronhsu@noreply.github.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-authored-by: Kangyan-Zhou <zky314343421@gmail.com>
Motivation
PR #22642 introduced `should_use_dp_reduce_scatterv()`, which fuses the post-MoE all-reduce with the dp_scatter into a single `reduce_scatterv` inside `LayerCommunicator`. The `qwen2_moe.py` forward path was patched to skip the explicit `tensor_model_parallel_all_reduce` when the fast path is active, but `qwen3_moe.py` was missed.

As a result, Qwen3 MoE models running with DP attention + EP (e.g. `--tp 2 --dp 2 --ep 2 --enable-dp-attention`, no `--moe-a2a-backend`) double-reduce the MoE output: once explicitly via `moe_expert_parallel_all_reduce` (and `moe_tensor_model_parallel_all_reduce`) in `forward_normal`, then again inside `reduce_scatterv` from the communicator. The output is silently corrupted; the model still produces fluent text, but logprobs differ from a TP-only baseline by 0.5–2 nats.

Reported as #23729.
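To make the failure mode concrete, a toy illustration (plain Python, not SGLang code) of why reducing twice corrupts the output: once the model has already all-reduced the partial expert outputs, the communicator's `reduce_scatterv` sums the already-complete values again, scaling them by the rank count.

```python
# Two "ranks", each holding a partial MoE output for the same token.
partials = [1.0, 2.0]

correct = sum(partials)                       # fused reduce_scatterv alone -> 3.0
after_model_allreduce = [sum(partials)] * 2   # explicit all-reduce already ran on every rank
double_reduced = sum(after_model_allreduce)   # communicator reduces again -> 6.0 (2x too large)

print(correct, double_reduced)  # 3.0 6.0
```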
Reproduction (no DP attention vs DP+EP, same model, same prompt, temperature 0)
v0.5.9 is also clean (76/100 prefix, max |Δlp|=0.147 — pure float drift), confirming the regression was introduced after that release.
Agreement between configuration A (the TP-only baseline) and configuration D (DP attention + EP, the previously broken setup) is now within normal float drift, matching the level of agreement between any of the (TP-only, DP-attn-only, EP-only) configurations on this model.
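For completeness, a minimal comparison harness in the spirit of the repro above. It is a sketch under stated assumptions: the two servers are assumed to be launched as in the comments, the ports, model name, and prompt are placeholders, and per-token logprobs are assumed to be readable from the OpenAI-compatible completions endpoint with `logprobs` set.

```python
# Hypothetical comparison harness, not part of the PR.
# Baseline server:  python -m sglang.launch_server --model-path Qwen/Qwen3-30B-A3B-Instruct-2507 \
#                       --tp 2 --port 30000
# Affected server:  python -m sglang.launch_server --model-path Qwen/Qwen3-30B-A3B-Instruct-2507 \
#                       --tp 2 --dp 2 --ep 2 --enable-dp-attention --port 30001
import requests

PROMPT = "Write a short proof that sqrt(2) is irrational."  # placeholder prompt


def greedy_logprobs(port: int, n_tokens: int = 100) -> list[float]:
    """Greedy-decode n_tokens on one server and return its per-token logprobs."""
    resp = requests.post(
        f"http://localhost:{port}/v1/completions",
        json={
            "model": "Qwen/Qwen3-30B-A3B-Instruct-2507",  # placeholder name
            "prompt": PROMPT,
            "max_tokens": n_tokens,
            "temperature": 0,
            "logprobs": 1,
        },
        timeout=300,
    ).json()
    return resp["choices"][0]["logprobs"]["token_logprobs"]


baseline = greedy_logprobs(30000)   # tp=2
candidate = greedy_logprobs(30001)  # tp=2 dp=2 ep=2 dp_attention
deltas = [abs(a - b) for a, b in zip(baseline, candidate)]
print(f"max |dlogprob| over {len(deltas)} tokens: {max(deltas):.3f}")
```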
Modifications
In `Qwen3MoeSparseMoeBlock.forward_normal`, mirror the `qwen2_moe.py` guard onto both the EP and TP all-reduce branches: skip them when `should_use_dp_reduce_scatterv()` returns True (the communicator now performs the reduce inside `reduce_scatterv`).
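A simplified sketch of the two guarded branches after the change. Only the collective calls are shown; the if/elif structure and surrounding logic are a simplification, and the final merged code may also include the `use_reduce_scatter` check on the EP branch suggested in the review above.

```python
# Post-fix shape of the reduce branches in Qwen3MoeSparseMoeBlock.forward_normal (simplified).
if (
    self.ep_size > 1
    and not should_allreduce_fusion
    and not should_use_dp_reduce_scatterv()  # EP branch: communicator will reduce_scatterv
):
    final_hidden_states = moe_expert_parallel_all_reduce(final_hidden_states)
elif (
    self.tp_size > 1
    and not should_allreduce_fusion
    and not use_reduce_scatter
    and not should_use_dp_reduce_scatterv()  # TP branch: same guard
):
    final_hidden_states = moe_tensor_model_parallel_all_reduce(final_hidden_states)
```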
Accuracy

The fix only deletes redundant collectives along the affected branch; numerics on every other path are identical. Validated above on Qwen3-30B-A3B-Instruct-2507 in the broken configuration; configurations that didn't trigger `should_use_dp_reduce_scatterv()` (TP-only, DP-attn alone, EP alone) are unaffected.

Checklist
cc @YAMY1234 (PR #22642 author)
Fixes #23729