
[bug fix][pp] fix weight load for qwen2.5-vl #15138

Merged
ShangmingCai merged 1 commit into sgl-project:main from openanolis:Xuchun/pp-vl-qwen
Dec 17, 2025

Conversation

@XucSh
Collaborator

@XucSh XucSh commented Dec 15, 2025

Motivation

Currently there are two bugs for Qwen2.5-VL:

  1. When loading weights for Qwen2.5-VL, an error occurs (introduced by the EPD feature):
    File "/root/leipi/sglang/python/sglang/srt/managers/scheduler.py", line 2762, in run_scheduler_process
    scheduler = Scheduler(
    ^^^^^^^^^^
    File "/root/leipi/sglang/python/sglang/srt/managers/scheduler.py", line 323, in init
    self.tp_worker = TpModelWorker(
    ^^^^^^^^^^^^^^
    File "/root/leipi/sglang/python/sglang/srt/managers/tp_worker.py", line 248, in init
    self._model_runner = ModelRunner(
    ^^^^^^^^^^^^
    File "/root/leipi/sglang/python/sglang/srt/model_executor/model_runner.py", line 356, in init
    self.initialize(min_per_gpu_memory)
    File "/root/leipi/sglang/python/sglang/srt/model_executor/model_runner.py", line 432, in initialize
    self.load_model()
    File "/root/leipi/sglang/python/sglang/srt/model_executor/model_runner.py", line 798, in load_model
    self.model = get_model(
    ^^^^^^^^^^
    File "/root/leipi/sglang/python/sglang/srt/model_loader/init.py", line 28, in get_model
    return loader.load_model(
    ^^^^^^^^^^^^^^^^^^
    File "/root/leipi/sglang/python/sglang/srt/model_loader/loader.py", line 604, in load_model
    self.load_weights_and_postprocess(
    File "/root/leipi/sglang/python/sglang/srt/model_loader/loader.py", line 612, in load_weights_and_postprocess
    model.load_weights(weights)
    File "/root/leipi/sglang/python/sglang/srt/models/qwen2_5_vl.py", line 795, in load_weights
    raise ValueError(f"Weight {name} not found in params_dict")
    ValueError: Weight model.layers.13.mlp.down_proj.weight not found in params_dict

  2. In PP mode, the LM head is not loaded correctly, which leads to garbage output (a minimal sketch of the expected handling follows this list).
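
To illustrate the second point: in pipeline parallelism only the last rank produces logits, so only that rank should own an LM head, and when the checkpoint ties the head to the input embeddings the loader should reuse embed_tokens rather than expect a separate lm_head.weight tensor. The helper below is a minimal, hypothetical sketch of that pattern; the function name, the is_last_pp_rank flag, and the tie_word_embeddings handling are assumptions for illustration, not the exact code changed in this PR.

from torch import nn


def build_lm_head(model: nn.Module, config, is_last_pp_rank: bool) -> nn.Module:
    """Hypothetical sketch: create the LM head only where it is actually used."""
    if not is_last_pp_rank:
        # Earlier pipeline ranks never compute logits, so a placeholder is enough
        # and lm_head.weight must not be expected during checkpoint loading.
        return nn.Identity()
    if getattr(config, "tie_word_embeddings", False):
        # Tied checkpoint: reuse the input embedding weights for the head.
        return model.embed_tokens
    # Untied checkpoint: a separate projection whose weight is loaded normally.
    return nn.Linear(config.hidden_size, config.vocab_size, bias=False)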

Checklist

Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello @XucSh, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses two critical bugs affecting the loading and functionality of Qwen2.5-vl models. It fixes an error that previously halted the weight loading process when specific parameters were absent and ensures the correct initialization of the language model head when operating in pipeline parallelism mode, thereby improving model stability and output quality.

Highlights

  • Qwen2.5-vl Weight Loading Fix: Resolved an issue where loading weights for Qwen2.5-vl models would fail with a ValueError if certain weights were not found in the params_dict. The loader now gracefully skips these missing weights (a sketch of this filtering follows the list).
  • LM Head Loading in Pipeline Parallelism: Corrected the logic for loading the language model head in pipeline parallelism (PP) mode, ensuring it is properly initialized on the last rank to prevent garbage output.
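
As background for the first highlight: under pipeline parallelism each rank only instantiates the transformer layers in its [start_layer, end_layer) partition, so params_dict is missing every layer owned by another stage even though the checkpoint iterator still yields those tensors. The following is a minimal, hypothetical sketch of tolerant loading under that partitioning; the start_layer/end_layer names follow the convention the reviewer cites from Qwen2ForCausalLM, and this is not the PR's exact diff.

import re
from typing import Dict, Iterable, Tuple

import torch

_LAYER_RE = re.compile(r"layers\.(\d+)\.")


def load_partition_weights(
    params_dict: Dict[str, torch.nn.Parameter],
    weights: Iterable[Tuple[str, torch.Tensor]],
    start_layer: int,
    end_layer: int,
) -> None:
    """Hypothetical sketch: copy only the weights this pipeline rank owns."""
    for name, loaded_weight in weights:
        match = _LAYER_RE.search(name)
        if match is not None:
            # Transformer layer owned by another pipeline stage: skip it.
            layer_id = int(match.group(1))
            if not (start_layer <= layer_id < end_layer):
                continue
        if name not in params_dict:
            # e.g. lm_head.weight on a rank that has no LM head: skip, don't raise.
            continue
        params_dict[name].data.copy_(loaded_weight)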

@gemini-code-assist gemini-code-assist Bot (Contributor) left a comment

Code Review

This pull request introduces two fixes for weight loading in the Qwen2.5-VL model. The first fix correctly handles the loading of tied lm_head weights in a pipeline parallel setup. The second fix prevents crashes by ignoring weights from the checkpoint that are not present in the current model partition, which is a common scenario in pipeline parallelism.

My review includes one suggestion for improving the code structure. While the current fix for ignoring missing weights is effective, a more robust solution for layer-specific weights would be to refactor the pipeline parallelism layer check to a more appropriate location in the load_weights method. This would make the code cleaner and more aligned with similar model implementations in the repository.

Comment on lines 799 to +800

            else:
                if get_global_server_args().encoder_only:
                    continue
                else:
-                   raise ValueError(f"Weight {name} not found in params_dict")
+                   continue
gemini-code-assist Bot (Contributor) left an inline review comment (severity: medium):

This change correctly ignores weights that are not present in the current model partition, which is necessary for pipeline parallelism. This fixes the crash for weights like lm_head.weight on ranks where it doesn't exist.

However, for layer-specific weights (e.g., model.layers.13.mlp.down_proj.weight), the root cause of the ValueError is that the pipeline parallelism layer check is misplaced. It's currently inside the stacked_params_mapping loop (lines 764-774), so it doesn't apply to weights that are not in stacked_params_mapping.

For a more robust fix and to improve code structure, consider moving the layer_id check to the beginning of the main weight loading loop (right after line 742), similar to how it's done in Qwen2ForCausalLM.load_weights. This would ensure all layer-specific weights are correctly filtered based on the pipeline rank.

Example of the improved structure:

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    # ...
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    for name, loaded_weight in weights:
        layer_id = get_layer_id(name)
        if (
            layer_id is not None
            and hasattr(self, "model")
            and hasattr(self.model, "start_layer")
            and (
                layer_id < self.model.start_layer
                or layer_id >= self.model.end_layer
            )
        ):
            continue

        if "rotary_emb.inv_freq" in name:
            continue
        # ... rest of the loop

With that change, you would also need to remove the original layer_id check from inside the stacked_params_mapping loop. This change would make the code cleaner and more correct.

@XucSh
Collaborator Author

XucSh commented Dec 15, 2025

/tag-and-rerun-ci

@XucSh XucSh changed the title [bug fix] fix weight load for qwen2.5-vl [bug fix][pp] fix weight load for qwen2.5-vl Dec 15, 2025
Collaborator

@ShangmingCai ShangmingCai left a comment


LGTM

@XucSh
Collaborator Author

XucSh commented Dec 15, 2025

/rerun-failed-ci

1 similar comment
@XucSh
Collaborator Author

XucSh commented Dec 17, 2025

/rerun-failed-ci

@ShangmingCai
Collaborator

@gty111 Can you double-check on this PR?

@gty111
Contributor

gty111 commented Dec 17, 2025

> @gty111 Can you double-check on this PR?

I’ve verified this PR locally, and in my tests it works well with both colocate and EPD. Thanks for the fix.

@ShangmingCai ShangmingCai merged commit 0071fe9 into sgl-project:main Dec 17, 2025
423 of 452 checks passed
Prozac614 pushed a commit to Prozac614/sglang that referenced this pull request Dec 23, 2025
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
jiaming1130 pushed a commit to zhuyijie88/sglang that referenced this pull request Dec 25, 2025
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
YChange01 pushed a commit to YChange01/sglang that referenced this pull request Jan 13, 2026
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
