Skip to content

Fix Qwen3.5 MoE KeyError with pipeline parallelism#21217

Open
he-yufeng wants to merge 1 commit intosgl-project:mainfrom
he-yufeng:fix/qwen3.5-pp-load-weights
Open

Fix Qwen3.5 MoE KeyError with pipeline parallelism#21217
he-yufeng wants to merge 1 commit intosgl-project:mainfrom
he-yufeng:fix/qwen3.5-pp-load-weights

Conversation

@he-yufeng
Copy link
Copy Markdown
Contributor

Motivation

Running Qwen3.5-122B-A10B with --pipeline-parallel-size 8 crashes during weight loading:

KeyError: 'model.layers.4.mlp.experts.w13_weight'

The fused expert path (load_fused_expert_weights) was already fixed in #21070 to handle missing params under PP, but the non-fused expert else branch was missed — it still does a bare params_dict[name_mapped] without checking if the key exists.

With pipeline parallelism, layers assigned to other ranks don't have their parameters in the local params_dict, so any access without a guard will KeyError.

Modifications

Added if name_mapped not in params_dict: continue before the dict access in the non-fused expert path, matching the pattern already used in:

Applied to both Qwen3_5MoeForCausalLM and Qwen3_5MoeForConditionalGeneration.

Fixes #21184

…rallelism

When running Qwen3.5-122B with pp>1, the non-fused expert weight loading
path in load_weights accesses params_dict[name_mapped] without checking
if the key exists. With pipeline parallelism, layers assigned to other
ranks won't have their parameters in the local params_dict, causing a
KeyError (e.g., 'model.layers.4.mlp.experts.w13_weight').

The fused expert path (load_fused_expert_weights) was already fixed in
sgl-project#21070 but the else branch for non-fused experts was missed. This adds
the same guard to both Qwen3_5MoeForCausalLM and
Qwen3_5MoeForConditionalGeneration.

Fixes sgl-project#21184
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical issue preventing the use of pipeline parallelism with Qwen3.5 MoE models by resolving a KeyError during weight loading. The change ensures that the system correctly handles distributed parameters by adding a necessary existence check, thereby enabling robust operation of these large models in a parallelized environment.

Highlights

  • Fix KeyError in Qwen3.5 MoE with Pipeline Parallelism: Resolved a KeyError: 'model.layers.X.mlp.experts.w13_weight' that occurred when running Qwen3.5-122B-A10B with --pipeline-parallel-size 8 during weight loading.
  • Guard for Non-Fused Expert Loading: Introduced a check if name_mapped not in params_dict: continue in the non-fused expert loading path within load_fused_expert_weights. This prevents attempts to access parameters not present on the local rank when pipeline parallelism is active.
  • Consistency with Existing Guards: The added guard aligns the non-fused expert loading logic with the already guarded fused expert weights loading (fixed in [Qwen3.5] Fix broken pipeline parallelism layer splitting #21070) and other Qwen3.5 MoE related files (qwen3_5_mtp.py, qwen3_omni_moe.py).
  • Affected Models: The fix was applied to both Qwen3_5MoeForCausalLM and Qwen3_5MoeForConditionalGeneration within the qwen3_5.py file.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds checks to handle cases where name_mapped might not be present in params_dict during parallel processing, preventing KeyError. The review comments suggest that these new checks introduce redundancy with existing conditional continue statements, and recommend refactoring to improve code clarity by removing the now-superfluous preceding checks.

Comment on lines +1252 to +1253
if name_mapped not in params_dict:
continue
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While this check correctly fixes the KeyError, it also makes the preceding check on lines 1246-1250 redundant. The logic if (A and B) continue; followed by if (B) continue; is equivalent to just if (B) continue;.

To improve code clarity, please remove the redundant check on lines 1246-1250 and keep only this new one.

Comment on lines +1596 to +1597
if name_mapped not in params_dict:
continue
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the change above, this check makes the preceding check on lines 1587-1591 redundant. The logic if (A and B) continue; followed by if (B) continue; is equivalent to just if (B) continue;.

To improve code clarity, please remove the redundant check on lines 1587-1591 and keep only this new one.

@he-yufeng
Copy link
Copy Markdown
Contributor Author

ping — Qwen3.5 MoE + PP KeyError, non-fused expert path missing PP guard.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] fix pp for qwen3_5 (KeyError when reading params)

1 participant