Fix EP: RouterParallel shape, tp_plan property, grouped_mm sentinels#45473
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
ArthurZucker
left a comment
There was a problem hiding this comment.
A nice follow up to #43730!
Ty for working on this!
Add `attribute_map` to GptOssConfig to map `num_experts` to `num_local_experts`, and fix GroupedGemmParallel to use `empty_param.shape[0]` instead of `module.num_experts`. Remove RouterParallel tests and add new expert parallel forward/backward tests to TensorParallelTesterMixin.
| """Load EP model and non-EP reference model for comparison.""" | ||
| model_ep = model_class.from_pretrained( | ||
| model_path, | ||
| distributed_config=DistributedConfig(enable_expert_parallel=True), |
There was a problem hiding this comment.
| distributed_config=DistributedConfig(enable_expert_parallel=True), | |
| ep_plan = "auto", |
might also be "more intuitive"?
There was a problem hiding this comment.
(a note but we can fix this if it makes sense for you!) 🤗
There was a problem hiding this comment.
From what I can understand tp_plan can be "auto" OR a dict for config ? There's no user-supplied dict for EP, the plan always comes from the config so it's not consistent with the other params.
I also saw the comment in src/transformers/distributed/configuration_utils.py
@dataclass
class DistributedConfig:
enable_expert_parallel: bool = False
# TODO: add tp_plan, pp_plan, device_mesh etc..which makes me think thats the forward direction of the api? not sure really.
There was a problem hiding this comment.
Yep, I was wondering what was more intuitive for you!
|
@ArthurZucker I see tests failing in CI.
I can fix (one-liner, same pattern applies to class_module = sys.modules.get(cls.__module__)
if class_module is None or not hasattr(class_module, "__file__"):
return FalseI fold the None case into the existing branch that already returns False when !! tradeoff: if the MoE model is being instantiated after something has evicted its entry from |
…ction Use sys.modules.get(cls.__module__) and treat a missing entry the same as the existing Jupyter/REPL case (no __file__) -> return False. Without this, PreTrainedModel.__init__ crashes with KeyError whenever another test (e.g. tests/utils/test_auto_docstring._clear) has evicted transformers.models.* entries from sys.modules while their class objects are still live.
|
[For maintainers] Suggested jobs to run (before merge) run-slow: gpt_oss |
ArthurZucker
left a comment
There was a problem hiding this comment.
sounds good! merging
…uggingface#45473) * Fix EP: RouterParallel shape, tp_plan property, grouped_mm sentinels * Fix GroupedGemmParallel.shard_tensor: use self.rank for expert sharding * Fix expert parallel attribute mapping and update tests Add `attribute_map` to GptOssConfig to map `num_experts` to `num_local_experts`, and fix GroupedGemmParallel to use `empty_param.shape[0]` instead of `module.num_experts`. Remove RouterParallel tests and add new expert parallel forward/backward tests to TensorParallelTesterMixin. * Harden _can_set_{attn,experts}_implementation against sys.modules eviction Use sys.modules.get(cls.__module__) and treat a missing entry the same as the existing Jupyter/REPL case (no __file__) -> return False. Without this, PreTrainedModel.__init__ crashes with KeyError whenever another test (e.g. tests/utils/test_auto_docstring._clear) has evicted transformers.models.* entries from sys.modules while their class objects are still live.
While benchmarking Qwen3-30B-A3B SFT training with Expert Parallelism (EP) using TRL, I found three bugs that combine to produce silently wrong results or NaN loss. Every existing test uses
tp_plan="auto"which bypassesRouterParalleland EP sentinel handling entirely if I understand correctly.These fixes are related to #45436 (will update with the expert-only EP config once this lands 👍🏼 )
Bug 1:
RouterParallelbreaks score/index pairing (tensor_parallel.py)The router produces three element-wise paired tensors:
router_logits (seq, num_experts),router_scores (seq, top_k),router_indices (seq, top_k)._prepare_output_fnis supposed to remap these to local expert space per EP rank.The old code did
scatterscores into a full(seq, num_experts)matrix then sliced to(seq, num_local_experts), changing the shape to e.g.(3, 64)while indices stay(3, 8). All expert forward implementations (grouped_mm,batched_mm, eager) flatten both withreshape(-1)and rely on element-wise pairing.This means we got mismatched shapes, the flattened tensors have different lengths, and every routing weight gets paired with the wrong expert.
Fix: Replace
scatter+slicewithmasked_fill. We zero out non-local scores in-place, preserving the(seq, top_k)shape. I verified using EP=1,2,4 logits match non-EP ground truth within bf16 precision.Bug 2: Weight loading uses wrong plan (
modeling_utils.py:4281)convert_and_load_state_dict_in_modelbuilds a regex fromtp_plan.keys()to decide which params to shard, then looks upmodel.tp_plan[pattern]to decide how. The call site passedmodel._tp_plan(raw TP plan), but the lookup uses thetp_planproperty which returns_ep_planwhen EP is enabled. When using an expert-only EP plan (no attention entries), the regex matches attention params from_tp_planbut the lookup in_ep_planraisesKeyError.Fix:
model._tp_plan→model.tp_planso both regex and lookup use the same plan.Bug 3:
grouped_mm_experts_forwarddoesn't handle EP sentinels (moe.py)After Bug 1 fix,
RouterParallelcorrectly sets non-local expert indices to a sentinel value (num_local_experts).batched_mmalready handles this (clamp + masked_fill). Butgrouped_mmpasses sentinel IDs straight totorch.histcwhich doesn't count them (out of bin range) →grouped_mmnever writes those output row. So if we have got uninitialized memory NaN propagating through residual connections to every token.Fix: Same pattern as
batched_mm, clamp sentinel IDs before sort,masked_fillafter weighted multiply.Tests added
test_router_parallel_preserves_shape: regression test for Bug 1test_router_parallel_score_index_pairing: verifies element-wise score/index pairing across all EP rankstest_grouped_gemm_shard_tensor_uses_rank_not_device: verifies shard uses mesh-local ranktest_grouped_mm_experts_sentinel_handling: verifies no NaN from sentinel expert IDsEnd-to-end validation
Tested on 8×H100 with Qwen3-30B-A3B, input
[[1, 2, 3]], comparing EP against non-EP ground truth.Before submitting
Who can review?
@3outeille @ArthurZucker