Skip to content

[Bugfix] Fix missing sequence_lengths in qwen3_omni_moe_thinker#35741

Merged
ywang96 merged 1 commit into
vllm-project:mainfrom
yeqcharlotte:fix/qwen3-omni-moe-sequence-lengths
Mar 2, 2026
Merged

[Bugfix] Fix missing sequence_lengths in qwen3_omni_moe_thinker#35741
ywang96 merged 1 commit into
vllm-project:mainfrom
yeqcharlotte:fix/qwen3-omni-moe-sequence-lengths

Conversation

@yeqcharlotte

@yeqcharlotte yeqcharlotte commented Mar 2, 2026

Copy link
Copy Markdown
Collaborator

Purpose

PR #34580 added a sequence_lengths parameter to
Qwen2_5_VisionAttention.forward() for the FlashInfer cuDNN backend and updated callers in qwen3_vl.py and qwen2_5_vl.py, but missed updating qwen3_omni_moe_thinker.py. This causes a TypeError crash when loading any Qwen3-Omni model:

TypeError: Qwen2_5_VisionAttention.forward() missing 1 required
positional argument: 'sequence_lengths'

Fix:

  • Add sequence_lengths parameter to Qwen3OmniMoeThinkerVisionBlock.forward() and pass it through to self.attn()
  • Compute sequence_lengths in the encoder loop using MMEncoderAttention.maybe_compute_sequence_lengths(), mirroring the pattern in qwen3_vl.py
  • Build cu_seqlens_np from grid_thw directly in numpy to avoid a GPU->CPU sync on the CUDA cu_seqlens tensor

Test Plan

e2e run

Test Result

h100

[2026-03-02 01:44:49,560] [rank 0] [INFO] Evaluation results on task mmmu_thinking_v0: mmmu_accuracy: 0.73
Peak TPGS: 616.0 (adjust if server does not use all allocated GPUs)
Peak Prefill TPGS: 17036.0 (prefill tokens per GPU per second)
Ran 21/21 requests in 294.48s
Success Rate: 100.00%
Peak QPS: 0.27
Server Avg TTFT: 285359ms
Server P50 TTFT: 285359ms
Server P99 TTFT: 285359ms
Server Avg TTIT: 35ms
Server P50 TTIT: 33ms
Server P99 TTIT: 35ms
Server Output Tokens Per Second: 71
Avg Prefill Len: 18979
P50 Prefill Len: 18979
P99 Prefill Len: 19003
Avg Decode Len: 2142
P50 Decode Len: 2142
P99 Decode Len: 2786

gb200

[2026-03-02 02:56:03,531] [rank 0] [INFO] Evaluation results on task mmmu_thinking_v0: mmmu_accuracy: 0.67625
Peak TPGS: 1954.0 (adjust if server does not use all allocated GPUs)
Peak Prefill TPGS: 18314.0 (prefill tokens per GPU per second)
Ran 89/89 requests in 329.84s
Success Rate: 100.00%
Peak QPS: 0.48
Server Avg TTFT: 352ms
Server P50 TTFT: 337ms
Server P99 TTFT: 537ms
Server Avg TTIT: 16ms
Server P50 TTIT: 15ms
Server P99 TTIT: 18ms
Server Output Tokens Per Second: 963
Avg Prefill Len: 18902
P50 Prefill Len: 18902
P99 Prefill Len: 18997
Avg Decode Len: 4129
P50 Decode Len: 2961
P99 Decode Len: 7310

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

PR vllm-project#34580 added a `sequence_lengths` parameter to
`Qwen2_5_VisionAttention.forward()` for the FlashInfer cuDNN backend
and updated callers in `qwen3_vl.py` and `qwen2_5_vl.py`, but missed
updating `qwen3_omni_moe_thinker.py`. This causes a TypeError crash
when loading any Qwen3-Omni model:

```
TypeError: Qwen2_5_VisionAttention.forward() missing 1 required
positional argument: 'sequence_lengths'
```

Fix:
- Add `sequence_lengths` parameter to
  `Qwen3OmniMoeThinkerVisionBlock.forward()` and pass it through to
  `self.attn()`
- Compute `sequence_lengths` in the encoder loop using
  `MMEncoderAttention.maybe_compute_sequence_lengths()`, mirroring the
  pattern in `qwen3_vl.py`
- Build `cu_seqlens_np` from `grid_thw` directly in numpy to avoid a
  GPU->CPU sync on the CUDA `cu_seqlens` tensor

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
@yeqcharlotte yeqcharlotte requested a review from ywang96 March 2, 2026 12:16
@yeqcharlotte yeqcharlotte requested a review from sighingnow as a code owner March 2, 2026 12:16
@mergify mergify Bot added qwen Related to Qwen models bug Something isn't working labels Mar 2, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request fixes a TypeError in qwen3_omni_moe_thinker by adding the missing sequence_lengths parameter. The changes correctly propagate this parameter to the attention layer. However, I've found a potential critical issue in the logic used to compute sequence_lengths. The calculation seems to be on a per-frame basis instead of per-image, which contradicts how cu_seqlens is handled elsewhere in the model and could lead to incorrect attention results. I've provided a detailed comment with a suggested fix.

Comment on lines +982 to +985
cu_seqlens_np = np.repeat(
grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0]
).cumsum(axis=0, dtype=np.int32)
cu_seqlens_np = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens_np])

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The logic to reconstruct cu_seqlens_np appears to be incorrect. While the PR description mentions this pattern is mirrored from qwen3_vl.py, it seems to calculate cumulative sequence lengths on a per-frame basis, rather than per-image, which is inconsistent with how cu_seqlens is handled elsewhere in this model.

Based on the logic in Qwen3OmniMoeThinkerVisionTower.forward (line 813), the sequence length for each image should be num_frames * height * width, which corresponds to grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2].

Using np.repeat as in the current implementation will result in incorrect sequence_lengths being passed to the attention layer, which expects per-sequence (image) lengths. This could lead to incorrect attention computations, especially for multi-frame inputs.

The suggested change corrects this by calculating sequence lengths per image before taking the cumulative sum, aligning with how cu_seqlens is typically constructed for vision models in vLLM.

seq_lens_np = grid_thw_np[:, 0] * grid_thw_np[:, 1] * grid_thw_np[:, 2]
cu_seqlens_np = np.concatenate(
    [np.zeros(1, dtype=np.int32), seq_lens_np.cumsum(dtype=np.int32)])

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I think this suggestion is wrong. Qwen2_5_VisionAttention is per frame

@yeqcharlotte yeqcharlotte added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 2, 2026

@ywang96 ywang96 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix

@ywang96 ywang96 enabled auto-merge (squash) March 2, 2026 21:05
@ywang96 ywang96 merged commit fa6a6be into vllm-project:main Mar 2, 2026
61 checks passed
Copilot AI pushed a commit to machov/vllm that referenced this pull request Mar 10, 2026
avinashsingh77 pushed a commit to avinashsingh77/vllm that referenced this pull request Mar 12, 2026
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
0826joyce pushed a commit to 0826joyce/vllm-serving-optimization that referenced this pull request May 19, 2026
appleparan added a commit to appleparan/vllm that referenced this pull request Jun 10, 2026
PR vllm-project#42787 made the Qwen2.5-VL vision backbone pass `sequence_lengths`
(FlashInfer CuDNN metadata) to every vision block, but the EXAONE-4.5
overrides of the vision block and attention kept the pre-vllm-project#42787
signature. Since EXAONE-4.5 inherits `Qwen2_5_VisionTransformer.forward`,
any multimodal request now fails with:

    TypeError: Exaone4_5_VisionBlock.forward() got an unexpected
    keyword argument 'sequence_lengths'

Thread `sequence_lengths` through `Exaone4_5_VisionBlock` and
`EXAONE4_5_VisionAttention` into `MMEncoderAttention`, and register it
in the block's `dynamic_arg_dims` for torch.compile, mirroring the
equivalent fix for qwen3_omni_moe_thinker in vllm-project#35741.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Jongsu Liam Kim <jongsukim8@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants