Add expert parallelism (EP) config support for Qwen3 MoE #45436

Merged
ArthurZucker merged 11 commits into huggingface:main from AmineDiro:qwen3-moe-ep-v2
Apr 22, 2026

Conversation

@AmineDiro

@AmineDiro AmineDiro commented Apr 14, 2026

Member

Summary

  • Added base_model_ep_plan to Qwen3MoeConfig enabling expert parallelism via DistributedConfig(enable_expert_parallel=True)

Depends on #45473

Test plan

Tested on 8×H100 with torchrun --nproc_per_node=8 using Qwen/Qwen3-30B-A3B (128 experts, 4 KV heads):

  • TP=2, DP=4: forward ✓, backward ✓, expert sharding 128→64 ✓
  • TP=4, DP=2: forward ✓, expert sharding 128→32 ✓
  • TP=2, CP=2, DP=2: forward ✓, backward ✓, seq split 128→64 ✓
  • Logits consistent across TP ranks
  • Gradients flow through expert weights
# Example test command
torchrun --nproc_per_node=8 scripts/test_qwen3_moe_tp_ep.py \
--model_name_or_path Qwen/Qwen3-30B-A3B --tp_size 2 --cp_size 2 --seq_len 128

test file: test_qwen3_moe_tp_ep.py
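
The sharding figures in the test plan follow from simple mesh arithmetic. Below is a minimal sketch, assuming experts are split evenly across the TP dimension (as the 128→64 and 128→32 results suggest) and that DP is whatever remains of the world size. `mesh_layout` is a hypothetical helper for illustration, not part of transformers.

```python
def mesh_layout(world_size: int, tp: int, cp: int = 1, num_experts: int = 128) -> dict:
    """Derive DP size and experts per rank for a TP x CP x DP mesh (illustrative only)."""
    if world_size % (tp * cp) != 0:
        raise ValueError(f"world_size={world_size} is not divisible by tp*cp={tp * cp}")
    if num_experts % tp != 0:
        raise ValueError(f"num_experts={num_experts} is not divisible by tp={tp}")
    return {
        "dp": world_size // (tp * cp),
        "experts_per_rank": num_experts // tp,
    }


print(mesh_layout(8, tp=2))        # {'dp': 4, 'experts_per_rank': 64}
print(mesh_layout(8, tp=4))        # {'dp': 2, 'experts_per_rank': 32}
print(mesh_layout(8, tp=2, cp=2))  # {'dp': 2, 'experts_per_rank': 64}
```

The three calls mirror the three configurations listed in the test plan.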

Before submitting

  • I confirm that this is not a pure code agent PR.
  • Did you read the contributor guideline, Pull Request section?

Who can review?

@3outeille @ArthurZucker (distributed / TP / EP implementation)

@HuggingFaceDocBuilderDev


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@AmineDiro

Member Author

Currently, model.tp_plan returns _ep_plan instead of _tp_plan when enable_expert_parallel=True (see modeling_utils.py). This means the EP plan fully replaces the TP plan.
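
The replacement behaviour described here can be sketched as follows. This is a simplified stand-in, not the actual modeling_utils.py code; `select_plan` and the two example plans are hypothetical.

```python
def select_plan(tp_plan: dict, ep_plan: dict, enable_expert_parallel: bool) -> dict:
    """Return the plan that ends up driving sharding (simplified illustration)."""
    # The EP plan is returned *instead of* the TP plan; the two are not merged.
    return ep_plan if enable_expert_parallel else tp_plan


tp_plan = {"layers.*.self_attn.q_proj": "colwise"}
ep_plan = {"layers.*.mlp.experts.gate_up_proj": "grouped_gemm"}

active = select_plan(tp_plan, ep_plan, enable_expert_parallel=True)
# With an expert-only EP plan, the attention entry is no longer present.
print("layers.*.self_attn.q_proj" in active)  # False
```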

This creates a tension between attention-layer TP and EP:

If we include attention entries (like Llama4 does):

base_model_ep_plan = {
  "layers.*.self_attn.q_proj": "colwise",  # ← attention TP
  ...
  "layers.*.mlp.experts.gate_up_proj": "grouped_gemm",  # ← expert EP
}
  • TP+EP on a 2D mesh works (our test: TP=2, DP=4 ✓)
  • But EP size is constrained to divide num_kv_heads, if I understand correctly (Qwen3Attention.forward infers the head count from the shape (*input_shape, -1, self.head_dim)). For Qwen3-30B-A3B (num_kv_heads=4) this means the max EP is 4. With EP=64 on a 1D mesh, the attention forward would probably crash because the q_proj output can't be reshaped into whole heads.
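
The reshape constraint can be checked with a little arithmetic. A minimal sketch, assuming a column-wise split divides the projection's output features evenly across ranks; `whole_heads_per_rank` is a hypothetical helper, and head_dim=128 is an assumed value for Qwen3-30B-A3B that is not stated in this thread (num_kv_heads=4 is).

```python
def whole_heads_per_rank(num_heads: int, head_dim: int, shard_size: int):
    """Heads each rank holds after a column-wise split, or None if a head would be cut."""
    total_features = num_heads * head_dim
    if total_features % shard_size != 0:
        return None
    local_features = total_features // shard_size
    # The (*input_shape, -1, head_dim) reshape needs a multiple of head_dim.
    if local_features % head_dim != 0:
        return None
    return local_features // head_dim


HEAD_DIM = 128    # assumed value, not stated in this thread
NUM_KV_HEADS = 4  # from the thread

for shard_size in (2, 4, 8, 64):
    print(shard_size, whole_heads_per_rank(NUM_KV_HEADS, HEAD_DIM, shard_size))
# 2 2
# 4 1
# 8 None
# 64 None
```

Under these assumptions, any shard size above 4 leaves a rank with a fraction of a KV head, which is where the reshape would fail.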

If we exclude attention entries (like gpt_oss does): pure EP works at any scale (EP=16, 32, 64), but TP+EP on a 2D mesh loses attention sharding :/

What's the preferred approach? Should we go expert-only in the EP plan for maximum flexibility, or include attention for combined TP+EP at the cost of constraining EP size?

@3outeille

3outeille commented Apr 15, 2026

Member

I think for now we should create a new mesh dim, ep, that uses the base_model_ep_plan without any TP, and then use the base_model_tp_plan for attention layers only, on the tp mesh dim.
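
One way to picture this proposal is as a mapping from parameter patterns to mesh dimensions. The sketch below is an illustration only: `assign_mesh_dims` is a hypothetical helper, the plan entries are examples, and nothing here reflects how transformers actually wires plans to a device mesh.

```python
def assign_mesh_dims(tp_plan: dict, ep_plan: dict) -> dict:
    """Map each parameter pattern to (mesh_dim, sharding_style). Illustrative only."""
    assignment = {}
    for pattern, style in tp_plan.items():
        # Only attention layers take part in TP under this proposal.
        if ".self_attn." in pattern:
            assignment[pattern] = ("tp", style)
    for pattern, style in ep_plan.items():
        # Expert weights are sharded on the separate ep mesh dim.
        assignment[pattern] = ("ep", style)
    return assignment


# Example entries, chosen for illustration.
tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.experts.gate_up_proj": "colwise",  # ignored: not an attention layer
}
ep_plan = {"layers.*.mlp.experts.gate_up_proj": "grouped_gemm"}

assignment = assign_mesh_dims(tp_plan, ep_plan)
for pattern, target in assignment.items():
    print(pattern, target)
```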

@AmineDiro AmineDiro changed the title Add expert parallelism (EP) support for Qwen3 MoE + fix GroupedGemmParallel for 2D meshes Add expert parallelism (EP) support for Qwen3 MoE Apr 20, 2026
@AmineDiro AmineDiro changed the title Add expert parallelism (EP) support for Qwen3 MoE Add expert parallelism (EP) config support for Qwen3 MoE Apr 20, 2026

@ArthurZucker ArthurZucker left a comment

Collaborator

LGTM!

@ArthurZucker

Collaborator

But EP size is constrained to divide num_kv_heads

It would be nice if we add a check on this in some way or document this on ep plan / tp plan combination.

Recommended approach is to go with the best defaults / what makes most sense. pure EP would probably be "faster" as less coms, EP + FSDP is easier to get and probably makes more sense? but we want to allow people to have EP + TP if they want -> error early if simple!

@AmineDiro

AmineDiro commented Apr 21, 2026

Member Author

It would be nice if we add a check on this in some way or document this on ep plan / tp plan combination.

Great idea! For now, I removed the attention layers from base_model_ep_plan because they limited the sharding scheme. I'll see where we can assert this post-init, so that if attention layers are in the base_model_ep_plan we assert that num_kv_heads % ep_size == 0.
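
A minimal sketch of what such an early check could look like. `check_ep_plan` is a hypothetical function, and detecting attention entries via the ".self_attn." substring is an assumption for illustration; the actual hook and its location were not decided in this thread.

```python
def check_ep_plan(ep_plan: dict, num_kv_heads: int, ep_size: int) -> None:
    """Fail early if attention is sharded by an EP size that would split a KV head."""
    has_attention = any(".self_attn." in pattern for pattern in ep_plan)
    if has_attention and num_kv_heads % ep_size != 0:
        raise ValueError(
            f"EP plan shards attention, but num_kv_heads={num_kv_heads} "
            f"is not divisible by ep_size={ep_size}"
        )


expert_only = {"layers.*.mlp.experts.gate_up_proj": "grouped_gemm"}
with_attention = {**expert_only, "layers.*.self_attn.q_proj": "colwise"}

check_ep_plan(expert_only, num_kv_heads=4, ep_size=64)    # passes: no attention entries
check_ep_plan(with_attention, num_kv_heads=4, ep_size=4)  # passes: 4 % 4 == 0
try:
    check_ep_plan(with_attention, num_kv_heads=4, ep_size=64)
except ValueError as err:
    print(err)
```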

pure EP would probably be "faster" as less coms, EP + FSDP is easier to get and probably makes more sense

Here are some numbers for Qwen3-30B-A3B (MoE, 128 experts, 8 active):

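| Context | Nodes | GPUs | DP | TP | CP | EP | MFU | TPS | TPS/GPU | Peak GPU Mem |
|---|---|---|---|---|---|---|---|---|---|---|
| 16k | 2 | 16 | 16 | 1 | 1 | 16 | 22.5% | 60,600 | 3,788 | 71.1 GB (89%) |
| 16k | 4 | 32 | 32 | 1 | 1 | 32 | 22.3% | 120,000 | 3,750 | 55.1 GB (69%) |
| 16k | 8 | 64 | 64 | 1 | 1 | 64 | 21.4% | 230,500 | 3,602 | 51.1 GB (64%) |
| 32k | 2 | 16 | 8 | 1 | 2 | 16 | 13.2% | 42,970 | 2,686 | 71.5 GB (89%) |
| 32k | 4 | 32 | 16 | 1 | 2 | 32 | 13.2% | 85,510 | 2,672 | 55.2 GB (69%) |
| 32k | 8 | 64 | 32 | 1 | 2 | 64 | 12.9% | 167,900 | 2,623 | 51.5 GB (64%) |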

For comparison FSDP2 without EP:

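| Context | Nodes | GPUs | DP | TP | CP | EP | MFU | TPS | TPS/GPU |
|---|---|---|---|---|---|---|---|---|---|
| 16k | 2 | 16 | 16 | 1 | 1 | 1 | 23.1% | 62,210 | 3,888 |
| 16k | 4 | 32 | 32 | 1 | 1 | 1 | 22.9% | 123,500 | 3,859 |
| 32k | 2 | 16 | 8 | 1 | 2 | 1 | 13.6% | 44,110 | 2,757 |
| 32k | 4 | 32 | 16 | 1 | 2 | 1 | 13.5% | 87,600 | 2,738 |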
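
To put the two tables side by side, the per-GPU throughput gap can be computed directly. The numbers below are copied from the tables above; `relative_gap` is just an illustrative helper.

```python
# TPS/GPU copied from the two tables above, keyed by (context, GPUs).
ep_tps_per_gpu = {("16k", 16): 3788, ("16k", 32): 3750, ("32k", 16): 2686, ("32k", 32): 2672}
fsdp_tps_per_gpu = {("16k", 16): 3888, ("16k", 32): 3859, ("32k", 16): 2757, ("32k", 32): 2738}


def relative_gap(ep: float, baseline: float) -> float:
    """Percentage by which EP trails (negative) or leads (positive) the baseline."""
    return (ep - baseline) / baseline * 100


gaps = {}
for key in sorted(fsdp_tps_per_gpu):
    gaps[key] = relative_gap(ep_tps_per_gpu[key], fsdp_tps_per_gpu[key])
    print(f"{key[0]} / {key[1]} GPUs: {gaps[key]:+.1f}%")
# 16k / 16 GPUs: -2.6%
# 16k / 32 GPUs: -2.8%
# 32k / 16 GPUs: -2.6%
# 32k / 32 GPUs: -2.4%
```

On the configurations where both runs exist, EP trails FSDP2 by roughly 2 to 3% in TPS/GPU.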

I also ran EP=8 on a 2D mesh vs EP=16 on a flat mesh on 2 nodes, just to see if the layout matters. EP=8 on 2 nodes means the experts are sharded across the 8 GPUs within a node, and that sharding is replicated across the 2 nodes.

| Context | EP | EP mesh | MFU | TPS/GPU | Peak GPU Mem |
|---|---|---|---|---|---|
| 32k | 16 | flat (16) | 1.9% | 1,515 | 49.3 GB |
| 32k | 8 | 2D (2,8) | 1.9% | 1,542 | 49.3 GB |
| 64k | 16 | flat (16) | 2.9% | 1,312 | 58.7 GB |
| 64k | 8 | 2D (2,8) | 2.9% | 1,324 | 58.7 GB |

Flat and 2D meshes give near-identical performance, I think because the current EP implementation uses all-reduce (96% inter-node bandwidth), not all-to-all.

So EP seems to be slower than pure FSDP2 or FSDP2+CP for long ctx. 🤔

AmineDiro and others added 4 commits April 21, 2026 08:14
Add base_model_ep_plan to Qwen3VLMoeTextConfig
Defines sharding strategy for MoE experts without affecting attention
layers, allowing EP to scale beyond num_kv_heads constraints.
Remove duplicate base_model_ep_plan with attention entries from qwen3_moe
and update qwen3_vl_moe to use the expert-only EP plan. Attention is left
unsharded — FSDP2 handles attention weight distribution.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ArthurZucker

Collaborator

Yeah, this could just be because we send 100% of the hidden states to all experts instead of sending just the ones allocated to each expert, and the same when we reduce?

In any case that's good to have we can investigate later on perf issues / bottlenecks !
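
The hypothesis in this comment can be illustrated by counting how many tokens each EP rank would need to receive. This is a toy sketch only: it assumes uniformly random top-k routing as a stand-in for the real router, contiguous expert placement across ranks, and a hypothetical helper `tokens_per_rank`. It says nothing about how the current implementation actually communicates.

```python
import random


def tokens_per_rank(num_tokens: int, num_experts: int, top_k: int, ep_size: int, seed: int = 0) -> list:
    """Count tokens each EP rank must receive if only routed tokens are sent."""
    rng = random.Random(seed)
    experts_per_rank = num_experts // ep_size
    counts = [0] * ep_size
    for _ in range(num_tokens):
        chosen = rng.sample(range(num_experts), top_k)  # stand-in for the router
        # A token is sent to a rank once, even if several of its experts live there.
        for rank in {expert // experts_per_rank for expert in chosen}:
            counts[rank] += 1
    return counts


NUM_TOKENS = 1024
routed = tokens_per_rank(NUM_TOKENS, num_experts=128, top_k=8, ep_size=16)
replicated = [NUM_TOKENS] * 16  # every rank sees every hidden state

print("routed: max tokens on a rank    =", max(routed))
print("replicated: tokens on each rank =", replicated[0])
```

With 8 active experts out of 128, a rank that only receives its routed tokens handles a fraction of the batch, which is the gap an all-to-all dispatch would aim to exploit.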

@ArthurZucker

Collaborator

run-slow: qwen3_moe, qwen3_omni_moe, qwen3_vl_moe

@ArthurZucker

Collaborator
16k 8 64 64 1 1 64 21.4% 230,500 3,602 51.1 GB (64%)

is the fastest no?

@github-actions

Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/qwen3_moe", "models/qwen3_omni_moe", "models/qwen3_vl_moe"]
quantizations: []

@github-actions

Contributor

CI Results

Workflow Run ⚙️

Commit Info

| Context | Commit | Description |
|---|---|---|
| RUN | 214a3da8 | workflow commit (merge commit) |
| PR | 38181267 | branch commit (from PR) |
| main | f048e845 | base commit (on main) |

✅ No failing test specific to this PR 🎉 👏 !

@AmineDiro

AmineDiro commented Apr 22, 2026

Member Author

16k 8 64 64 1 1 64 21.4% 230,500 3,602 51.1 GB (64%)
is the fastest no?

For MFU and TPS/GPU, we lose perf each time we double the EP size :/

In any case that's good to have we can investigate later on perf issues / bottlenecks !

Yes I profiled the Qwen3 30B MoE briefly to get a sense of what was slow. It seems to be the all-reduce in MoE layers 🤔

image

It is all_reduce_forward, and it takes a huge slice of the forward comm time.
image

@ArthurZucker

Collaborator

happy to merge once green!

@ArthurZucker

Collaborator

@bot /style

@github-actions

github-actions Bot commented Apr 22, 2026

Contributor

Style fix bot fixed some files and pushed the changes.

@github-actions

Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: qwen3_moe, qwen3_omni_moe, qwen3_vl_moe

@AmineDiro

Member Author

Finally CI is 🟢 !

@ArthurZucker ArthurZucker added this pull request to the merge queue Apr 22, 2026
Merged via the queue into huggingface:main with commit 7cf4241 Apr 22, 2026
22 checks passed
4 participants