[Bugfix] Fix missing sequence_lengths in qwen3_omni_moe_thinker#35741
Conversation
PR vllm-project#34580 added a `sequence_lengths` parameter to `Qwen2_5_VisionAttention.forward()` for the FlashInfer cuDNN backend and updated callers in `qwen3_vl.py` and `qwen2_5_vl.py`, but missed updating `qwen3_omni_moe_thinker.py`. This causes a TypeError crash when loading any Qwen3-Omni model: ``` TypeError: Qwen2_5_VisionAttention.forward() missing 1 required positional argument: 'sequence_lengths' ``` Fix: - Add `sequence_lengths` parameter to `Qwen3OmniMoeThinkerVisionBlock.forward()` and pass it through to `self.attn()` - Compute `sequence_lengths` in the encoder loop using `MMEncoderAttention.maybe_compute_sequence_lengths()`, mirroring the pattern in `qwen3_vl.py` - Build `cu_seqlens_np` from `grid_thw` directly in numpy to avoid a GPU->CPU sync on the CUDA `cu_seqlens` tensor Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
There was a problem hiding this comment.
Code Review
This pull request fixes a TypeError in qwen3_omni_moe_thinker by adding the missing sequence_lengths parameter. The changes correctly propagate this parameter to the attention layer. However, I've found a potential critical issue in the logic used to compute sequence_lengths. The calculation seems to be on a per-frame basis instead of per-image, which contradicts how cu_seqlens is handled elsewhere in the model and could lead to incorrect attention results. I've provided a detailed comment with a suggested fix.
| cu_seqlens_np = np.repeat( | ||
| grid_thw_np[:, 1] * grid_thw_np[:, 2], grid_thw_np[:, 0] | ||
| ).cumsum(axis=0, dtype=np.int32) | ||
| cu_seqlens_np = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens_np]) |
There was a problem hiding this comment.
The logic to reconstruct cu_seqlens_np appears to be incorrect. While the PR description mentions this pattern is mirrored from qwen3_vl.py, it seems to calculate cumulative sequence lengths on a per-frame basis, rather than per-image, which is inconsistent with how cu_seqlens is handled elsewhere in this model.
Based on the logic in Qwen3OmniMoeThinkerVisionTower.forward (line 813), the sequence length for each image should be num_frames * height * width, which corresponds to grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2].
Using np.repeat as in the current implementation will result in incorrect sequence_lengths being passed to the attention layer, which expects per-sequence (image) lengths. This could lead to incorrect attention computations, especially for multi-frame inputs.
The suggested change corrects this by calculating sequence lengths per image before taking the cumulative sum, aligning with how cu_seqlens is typically constructed for vision models in vLLM.
seq_lens_np = grid_thw_np[:, 0] * grid_thw_np[:, 1] * grid_thw_np[:, 2]
cu_seqlens_np = np.concatenate(
[np.zeros(1, dtype=np.int32), seq_lens_np.cumsum(dtype=np.int32)])There was a problem hiding this comment.
Hmm I think this suggestion is wrong. Qwen2_5_VisionAttention is per frame
…-project#35741) Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
…-project#35741) Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
…-project#35741) Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
…-project#35741) Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
…-project#35741) Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
…-project#35741) Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
…-project#35741) Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
PR vllm-project#42787 made the Qwen2.5-VL vision backbone pass `sequence_lengths` (FlashInfer CuDNN metadata) to every vision block, but the EXAONE-4.5 overrides of the vision block and attention kept the pre-vllm-project#42787 signature. Since EXAONE-4.5 inherits `Qwen2_5_VisionTransformer.forward`, any multimodal request now fails with: TypeError: Exaone4_5_VisionBlock.forward() got an unexpected keyword argument 'sequence_lengths' Thread `sequence_lengths` through `Exaone4_5_VisionBlock` and `EXAONE4_5_VisionAttention` into `MMEncoderAttention`, and register it in the block's `dynamic_arg_dims` for torch.compile, mirroring the equivalent fix for qwen3_omni_moe_thinker in vllm-project#35741. Co-authored-by: Claude <noreply@anthropic.com> Signed-off-by: Jongsu Liam Kim <jongsukim8@gmail.com>
Purpose
PR #34580 added a
sequence_lengthsparameter toQwen2_5_VisionAttention.forward()for the FlashInfer cuDNN backend and updated callers inqwen3_vl.pyandqwen2_5_vl.py, but missed updatingqwen3_omni_moe_thinker.py. This causes a TypeError crash when loading any Qwen3-Omni model:Fix:
sequence_lengthsparameter toQwen3OmniMoeThinkerVisionBlock.forward()and pass it through toself.attn()sequence_lengthsin the encoder loop usingMMEncoderAttention.maybe_compute_sequence_lengths(), mirroring the pattern inqwen3_vl.pycu_seqlens_npfromgrid_thwdirectly in numpy to avoid a GPU->CPU sync on the CUDAcu_seqlenstensorTest Plan
e2e run
Test Result
h100
gb200
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.