Fix EP: RouterParallel shape, tp_plan property, grouped_mm sentinels#45473

Merged
ArthurZucker merged 7 commits into
huggingface:mainfrom
AmineDiro:fixed-ep-model-parallel
Apr 21, 2026

Conversation

@AmineDiro AmineDiro commented Apr 16, 2026

Member

While benchmarking Qwen3-30B-A3B SFT training with Expert Parallelism (EP) using TRL, I found three bugs that combine to produce silently wrong results or NaN loss. Every existing test uses tp_plan="auto" which bypasses RouterParallel and EP sentinel handling entirely if I understand correctly.

These fixes are related to #45436 (will update with the expert-only EP config once this lands 👍🏼 )

Bug 1: RouterParallel breaks score/index pairing (tensor_parallel.py)

The router produces three element-wise paired tensors: router_logits (seq, num_experts), router_scores (seq, top_k), router_indices (seq, top_k). _prepare_output_fn is supposed to remap these to local expert space per EP rank.

The old code scattered scores into a full (seq, num_experts) matrix and then sliced it to (seq, num_local_experts), changing the shape to e.g. (3, 64) while indices stay (3, 8). All expert forward implementations (grouped_mm, batched_mm, eager) flatten both with reshape(-1) and rely on element-wise pairing.
With mismatched shapes, the flattened tensors have different lengths, and every routing weight gets paired with the wrong expert.

Fix: Replace scatter+slice with masked_fill. We zero out non-local scores in-place, preserving the (seq, top_k) shape. I verified using EP=1,2,4 logits match non-EP ground truth within bf16 precision.
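
To make the shape argument concrete, here is a minimal sketch of the masking idea. It is not the actual `_prepare_output_fn`: plain Python lists stand in for tensors so it runs without PyTorch, and the sizes, the `remap_to_local` helper, and the contiguous expert-to-rank layout are assumptions made for illustration.

```python
# Illustrative sizes only; plain lists stand in for (seq, top_k) tensors.
NUM_EXPERTS = 8
EP_SIZE = 2
NUM_LOCAL = NUM_EXPERTS // EP_SIZE  # experts per rank; also used as the sentinel id


def remap_to_local(scores, indices, rank):
    """Zero non-local scores and remap indices, keeping the (seq, top_k) shape."""
    start = rank * NUM_LOCAL  # assumes experts are split contiguously across ranks
    end = start + NUM_LOCAL
    local_scores, local_indices = [], []
    for row_scores, row_indices in zip(scores, indices):
        new_scores, new_indices = [], []
        for score, idx in zip(row_scores, row_indices):
            if start <= idx < end:
                new_scores.append(score)
                new_indices.append(idx - start)  # local expert id
            else:
                new_scores.append(0.0)           # masked, not removed
                new_indices.append(NUM_LOCAL)    # sentinel for "not on this rank"
        local_scores.append(new_scores)
        local_indices.append(new_indices)
    return local_scores, local_indices


def flatten(rows):
    """Equivalent of reshape(-1) for a list of rows."""
    return [x for row in rows for x in row]


router_scores = [[0.6, 0.4], [0.7, 0.3], [0.5, 0.5]]   # (seq=3, top_k=2)
router_indices = [[1, 6], [5, 2], [0, 3]]              # (seq=3, top_k=2)

for rank in range(EP_SIZE):
    s, i = remap_to_local(router_scores, router_indices, rank)
    # Both flatten to seq * top_k entries, so element k of one pairs with element k of the other.
    assert len(flatten(s)) == len(flatten(i)) == 6

# The scatter+slice approach would instead yield seq * NUM_LOCAL scores (12 here)
# against seq * top_k indices (6 here), so the flattened pairing breaks.
```

Each rank sees the same number of (score, index) pairs; non-local pairs carry a zero score and the sentinel index rather than disappearing.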

Bug 2: Weight loading uses wrong plan (modeling_utils.py:4281)

convert_and_load_state_dict_in_model builds a regex from tp_plan.keys() to decide which params to shard, then looks up model.tp_plan[pattern] to decide how. The call site passed model._tp_plan (raw TP plan), but the lookup uses the tp_plan property which returns _ep_plan when EP is enabled. When using an expert-only EP plan (no attention entries), the regex matches attention params from _tp_plan but the lookup in _ep_plan raises KeyError.

Fix: model._tp_plan → model.tp_plan, so both the regex and the lookup use the same plan.
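
The failure mode can be reproduced in isolation. The sketch below is not the real `convert_and_load_state_dict_in_model`: the plan keys, the style strings, and both helper functions are hypothetical stand-ins that only mirror the "match against one plan, look up in another" structure described above.

```python
import re

# Hypothetical plans; keys and style names are placeholders, not the real model plans.
raw_tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
}
expert_only_ep_plan = {
    "layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
}


def matched_pattern(param_name, plan):
    """Return the first plan key whose pattern matches the parameter name, else None."""
    for pattern in plan:
        # Turn the '*' wildcard into "one or more digits" (a layer index).
        regex = re.escape(pattern).replace(r"\*", r"\d+")
        if re.search(regex, param_name):
            return pattern
    return None


def sharding_style(param_name, pattern_plan, lookup_plan):
    """Decide *whether* to shard from pattern_plan and *how* from lookup_plan."""
    pattern = matched_pattern(param_name, pattern_plan)
    if pattern is None:
        return None  # parameter is left unsharded
    return lookup_plan[pattern]


attn_param = "model.layers.0.self_attn.q_proj.weight"

# Mismatched plans: the attention param matches in one plan but is missing from the other.
try:
    sharding_style(attn_param, raw_tp_plan, expert_only_ep_plan)
    mismatch_raised = False
except KeyError:
    mismatch_raised = True

# Same plan on both sides: the attention param is simply not sharded.
consistent_result = sharding_style(attn_param, expert_only_ep_plan, expert_only_ep_plan)
```

With a single plan on both sides, a parameter that the plan does not mention falls through to "unsharded" instead of raising.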

Bug 3: grouped_mm_experts_forward doesn't handle EP sentinels (moe.py)

After the Bug 1 fix, RouterParallel correctly sets non-local expert indices to a sentinel value (num_local_experts). batched_mm already handles this (clamp + masked_fill). But grouped_mm passes sentinel IDs straight to torch.histc, which doesn't count them (out of bin range), so grouped_mm never writes those output rows. Those rows are left as uninitialized memory, and the resulting NaNs propagate through residual connections to every token.

Fix: Same pattern as batched_mm, clamp sentinel IDs before sort, masked_fill after weighted multiply.
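
A toy model of the mechanism, assuming nothing about the real `grouped_mm_experts_forward` beyond what is described above: a histogram that ignores out-of-range ids, an output buffer whose unwritten rows are represented by NaN, and a clamp-then-mask fix. All names and the "expert" arithmetic are made up for illustration.

```python
import math

NUM_LOCAL = 2        # local experts on this rank; also the sentinel id
NAN = float("nan")   # stand-in for uninitialized output memory


def histogram(expert_ids, num_bins):
    """Count ids in [0, num_bins); out-of-range ids are ignored."""
    counts = [0] * num_bins
    for e in expert_ids:
        if 0 <= e < num_bins:
            counts[e] += 1
    return counts


def grouped_forward(expert_ids, weights, clamp_sentinels):
    is_sentinel = [e >= NUM_LOCAL for e in expert_ids]
    if clamp_sentinels:
        ids = [min(e, NUM_LOCAL - 1) for e in expert_ids]  # clamp before sorting
    else:
        ids = list(expert_ids)
    order = sorted(range(len(ids)), key=lambda k: ids[k])  # stable sort by expert id
    counts = histogram([ids[k] for k in order], NUM_LOCAL)

    out_sorted = [NAN] * len(ids)  # only rows covered by the counts get written
    row = 0
    for expert, count in enumerate(counts):
        for _ in range(count):
            token = order[row]
            out_sorted[row] = (expert + 1) * weights[token]  # toy expert computation
            row += 1

    out = [NAN] * len(ids)
    for pos, token in enumerate(order):  # undo the sort
        out[token] = out_sorted[pos]
    if clamp_sentinels:
        out = [0.0 if s else v for s, v in zip(is_sentinel, out)]  # mask after the multiply
    return out


expert_ids = [0, 2, 1, 2]            # 2 is the sentinel (== NUM_LOCAL)
weights = [0.5, 0.0, 0.25, 0.0]      # sentinel slots already carry a zero score

broken = grouped_forward(expert_ids, weights, clamp_sentinels=False)
fixed = grouped_forward(expert_ids, weights, clamp_sentinels=True)

assert any(math.isnan(v) for v in broken)      # sentinel rows were never written
assert not any(math.isnan(v) for v in fixed)   # every row written, sentinel rows zeroed
```

The clamp makes the counts sum to the number of rows, so every row is written; the mask then guarantees that whatever the clamped expert computed for a sentinel row contributes nothing.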

Tests added

  • test_router_parallel_preserves_shape: regression test for Bug 1
  • test_router_parallel_score_index_pairing : verifies element-wise score/index pairing across all EP ranks
  • test_grouped_gemm_shard_tensor_uses_rank_not_device : verifies shard uses mesh-local rank
  • test_grouped_mm_experts_sentinel_handling: verifies no NaN from sentinel expert IDs

End-to-end validation

Tested on 8×H100 with Qwen3-30B-A3B, input [[1, 2, 3]], comparing EP against non-EP ground truth.

Before submitting

  • This is not a pure code-agent PR: I found these bugs while benchmarking and wrote the fixes.
  • Read the contributor guideline.

Who can review?

@3outeille @ArthurZucker

@HuggingFaceDocBuilderDev


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker ArthurZucker left a comment

Collaborator

A nice follow up to #43730!
Ty for working on this!

Comment thread src/transformers/modeling_utils.py
Comment thread tests/tensor_parallel/test_tensor_parallel.py
AmineDiro and others added 2 commits April 20, 2026 11:16
Add `attribute_map` to GptOssConfig to map `num_experts` to `num_local_experts`,
and fix GroupedGemmParallel to use `empty_param.shape[0]` instead of
`module.num_experts`. Remove RouterParallel tests and add new expert parallel
forward/backward tests to TensorParallelTesterMixin.
"""Load EP model and non-EP reference model for comparison."""
model_ep = model_class.from_pretrained(
model_path,
distributed_config=DistributedConfig(enable_expert_parallel=True),

Collaborator

Suggested change
distributed_config=DistributedConfig(enable_expert_parallel=True),
ep_plan = "auto",

might also be "more intuitive"?

Collaborator

(a note but we can fix this if it makes sense for you!) 🤗

@AmineDiro AmineDiro Apr 20, 2026

Member Author

From what I understand, tp_plan can be "auto" or a dict from the config? There's no user-supplied dict for EP; the plan always comes from the config, so it's not consistent with the other params.

I also saw the comment in src/transformers/distributed/configuration_utils.py

@dataclass
class DistributedConfig:
    enable_expert_parallel: bool = False
    # TODO: add tp_plan, pp_plan, device_mesh etc..

which makes me think that's the forward direction of the API? Not sure, really.

Collaborator

Yep, I was wondering what was more intuitive for you!

Collaborator

perf!

AmineDiro commented Apr 20, 2026

Member Author

@ArthurZucker I see tests failing in CI.

  • All 5 failures: PreTrainedModel.__init__ → _can_set_experts_implementation. It calls sys.modules[cls.__module__], which raises KeyError (modeling_utils.py:1971, from "batched and grouped experts implementations" #42697).
  • I ran Claude through it, and I think tests/utils/test_auto_docstring.py benchmarks import cost and deletes sys.modules[k] for every transformers.models.* without restoring them. The classes stay live (other namespaces hold refs), but sys.modules loses the entry.
  • On main, these two test files land in different CircleCI shards.

I can fix (one-liner, same pattern applies to _can_set_attn_implementation (1952)): 47497ce

  class_module = sys.modules.get(cls.__module__)
  if class_module is None or not hasattr(class_module, "__file__"):
      return False

I fold the None case into the existing branch that already returns False when __file__ is missing. Semantically, False means "we can't verify this class supports the new experts API, so assume it doesn't", which is exactly what the fallback does.

!! Tradeoff: if a MoE model is instantiated after something has evicted its entry from sys.modules (today, only test_auto_docstring._clear does this), the model silently falls back to "eager" experts instead of grouped_mm. That's a perf regression, never a correctness issue. It only affects MoE models; non-MoE models (Bert, Bart, Mistral, etc., i.e. the 5 tests failing here) don't use experts at all, so "eager" is a no-op for them. If a user explicitly asked for experts_implementation="grouped_mm", they still get a clear ValueError rather than a silent downgrade.
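
For anyone who wants to see the eviction scenario without running the test suite, here is a small standalone reproduction. The module name `fake_models.demo`, the `DemoModel` class, and the simplified return value are hypothetical; only the `sys.modules.get` guard mirrors the snippet above.

```python
import sys
import types

# Hypothetical module and class, created only for this demonstration.
module = types.ModuleType("fake_models.demo")
module.__file__ = "/tmp/fake_models/demo.py"
sys.modules["fake_models.demo"] = module


class DemoModel:
    pass


DemoModel.__module__ = "fake_models.demo"  # pretend the class was defined in that module


def can_set_experts_implementation(cls):
    """Guarded lookup: a missing sys.modules entry is treated like a missing __file__."""
    class_module = sys.modules.get(cls.__module__)
    if class_module is None or not hasattr(class_module, "__file__"):
        return False
    return True  # the real check continues from here; this sketch stops at the guard


before_eviction = can_set_experts_implementation(DemoModel)

# Simulate another test evicting the module while the class object stays alive.
del sys.modules["fake_models.demo"]

after_eviction = can_set_experts_implementation(DemoModel)

# The unguarded lookup is what used to crash.
try:
    sys.modules[DemoModel.__module__]
    unguarded_raised = False
except KeyError:
    unguarded_raised = True
```

The class remains usable after eviction because `DemoModel` still holds its own reference; only the registry entry is gone, which is why the guard has to tolerate `None`.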

…ction

Use sys.modules.get(cls.__module__) and treat a missing entry the same as
the existing Jupyter/REPL case (no __file__) -> return False.

Without this, PreTrainedModel.__init__ crashes with KeyError whenever
another test (e.g. tests/utils/test_auto_docstring._clear) has evicted
transformers.models.* entries from sys.modules while their class objects
are still live.
@github-actions

Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: gpt_oss

@ArthurZucker ArthurZucker left a comment

Collaborator

sounds good! merging

@ArthurZucker ArthurZucker added this pull request to the merge queue Apr 21, 2026
Merged via the queue into huggingface:main with commit 9dff7ca Apr 21, 2026
28 checks passed
Comment thread src/transformers/models/gpt_oss/configuration_gpt_oss.py
artem-spector pushed a commit to artem-spector/transformers that referenced this pull request Apr 21, 2026
…uggingface#45473)

* Fix EP: RouterParallel shape, tp_plan property, grouped_mm sentinels

* Fix GroupedGemmParallel.shard_tensor: use self.rank for expert sharding

* Fix expert parallel attribute mapping and update tests

Add `attribute_map` to GptOssConfig to map `num_experts` to `num_local_experts`,
and fix GroupedGemmParallel to use `empty_param.shape[0]` instead of
`module.num_experts`. Remove RouterParallel tests and add new expert parallel
forward/backward tests to TensorParallelTesterMixin.

* Harden _can_set_{attn,experts}_implementation against sys.modules eviction

Use sys.modules.get(cls.__module__) and treat a missing entry the same as
the existing Jupyter/REPL case (no __file__) -> return False.

Without this, PreTrainedModel.__init__ crashes with KeyError whenever
another test (e.g. tests/utils/test_auto_docstring._clear) has evicted
transformers.models.* entries from sys.modules while their class objects
are still live.
6 participants