[Model] Refactor Gemma4 vision tower with vLLM-native modules#43440
[Model] Refactor Gemma4 vision tower with vLLM-native modules#43440linitra24 wants to merge 11 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request replaces the transformers-based vision tower in the Gemma 4 multimodal model with a native vLLM implementation to support LoRA visibility and optimized execution. The new implementation introduces custom modules for patch embedding, multidimensional rotary embeddings, and a vision pooler with position-based average pooling. Review feedback identifies a potential indexing error in the pooling logic when image dimensions are not multiples of the kernel size and suggests refining the pooler mask to correctly handle padding patches that might otherwise be incorrectly marked as valid tokens.
| max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1 | ||
| kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor") | ||
| kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1] |
There was a problem hiding this comment.
The calculation of kernel_idxs assumes that the image width (in patches) is a multiple of the pooling kernel size k. While Gemma4ProcessingInfo enforces this for standard inputs, Gemma4VisionPooler should be robust to cases where max_x is not a multiple of k. If max_x % k != 0, kernel_idxs can exceed length - 1, causing F.one_hot to crash or produce incorrect results. A safer approach is to use the actual number of horizontal blocks (max_x + k - 1) // k for indexing.
| max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1 | |
| kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor") | |
| kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1] | |
| max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1 | |
| num_blocks_x = (max_x + k - 1) // k | |
| kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor") | |
| kernel_idxs = kernel_idxs[..., 0] + num_blocks_x * kernel_idxs[..., 1] |
| kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1] | ||
| weights = F.one_hot(kernel_idxs.long(), length).float() / k_squared | ||
| output = weights.transpose(1, 2) @ hidden_states.float() | ||
| mask = torch.logical_not((weights == 0).all(dim=1)) |
There was a problem hiding this comment.
The pooler_mask calculation using (weights == 0).all(dim=1) will incorrectly include the first pooled token (index 0) as valid even if it only contains padding patches. This is because all padding patches (with pixel_position_ids == -1) are clamped to (0, 0) and thus mapped to kernel_idxs == 0. Since hidden_states are masked to 0.0 for padding positions, the sum is correct, but the mask should ideally only consider tokens that contain at least one real patch to avoid returning vestigial tokens in edge cases.
mask = torch.logical_not((weights == 0).all(dim=1))
# Ensure index 0 is only valid if it contains at least one real patch
real_patches_mask = torch.logical_not(padding_positions)
mask[:, 0] = (weights[:, real_patches_mask, 0] > 0).any(dim=1) if real_patches_mask.any() else False| if self.config.standardize: | ||
| hidden_states = (hidden_states - self.std_bias) * self.std_scale | ||
|
|
||
| return BaseModelOutputWithPast(last_hidden_state=hidden_states) |
There was a problem hiding this comment.
| return BaseModelOutputWithPast(last_hidden_state=hidden_states) | |
| return hidden_states |
There was a problem hiding this comment.
I've addressed this.
| attn_output = _gemma4_vision_attention_forward( | ||
| self, query_states, key_states, value_states, attention_mask | ||
| ) |
There was a problem hiding this comment.
Use MMEncoderAttention interface instead.
There was a problem hiding this comment.
For MMEncoderAttention, I looked into it, but Gemma4 vision currently relies on a per-sample bool padding mask derived from pixel_position_ids == -1. MMEncoderAttention is built around packed/dense encoder attention with cu_seqlens and does not directly support this arbitrary padding mask. Converting it correctly would require an additional pack/scatter path for valid patch tokens, so I kept the SDPA path in this PR and would prefer to handle the MMEncoderAttention conversion as a follow-up!
| def _apply_multidimensional_rope( | ||
| hidden_states: torch.Tensor, | ||
| cos: torch.Tensor, | ||
| sin: torch.Tensor, | ||
| position_ids: torch.Tensor, | ||
| unsqueeze_dim: int = 2, | ||
| ) -> torch.Tensor: | ||
| ndim = position_ids.shape[-1] | ||
| num_input_channels = hidden_states.shape[-1] | ||
| num_rotated_channels_per_dim = 2 * (num_input_channels // (2 * ndim)) | ||
| if num_rotated_channels_per_dim <= 0: | ||
| raise ValueError( | ||
| "Invalid configuration: num_rotated_channels_per_dim must be > 0, " | ||
| f"got {num_rotated_channels_per_dim}." | ||
| ) | ||
|
|
||
| split_sizes = [num_rotated_channels_per_dim] * ndim | ||
| hidden_parts = torch.split(hidden_states, split_sizes, dim=-1) | ||
| cos_parts = torch.split(cos, split_sizes, dim=-1) | ||
| sin_parts = torch.split(sin, split_sizes, dim=-1) | ||
| output_parts = [ | ||
| _apply_rotary_pos_emb( | ||
| hidden_states=hidden_parts[k], | ||
| cos=cos_parts[k], | ||
| sin=sin_parts[k], | ||
| unsqueeze_dim=unsqueeze_dim, | ||
| ) | ||
| for k in range(ndim) | ||
| ] | ||
| return torch.cat(output_parts, dim=-1) |
There was a problem hiding this comment.
I think we can use ApplyRotaryEmb:
vllm/vllm/model_executor/layers/rotary_embedding/common.py
Lines 122 to 127 in 5bb8d27
There was a problem hiding this comment.
Thanks for the suggestion. I looked into using ApplyRotaryEmb here, but kept the HF-style multidimensional RoPE helper for now.
This is similar to the MMEncoderAttention case: Gemma4 vision currently uses batch-specific 2D pixel_position_ids with padded image patches, while the ApplyRotaryEmb path expects a more standard/packed layout. Switching this correctly would require additional layout changes, which I think is better handled as a follow-up.
For this PR, I kept the RoPE behavior aligned with Transformers and focused on moving the vision tower to vLLM-native modules for future LoRA support.
| position_ids=pixel_position_ids, | ||
| ) | ||
|
|
||
| return BaseModelOutputWithPast(last_hidden_state=hidden_states) |
There was a problem hiding this comment.
| return BaseModelOutputWithPast(last_hidden_state=hidden_states) | |
| return last_hidden_states |
There was a problem hiding this comment.
I've addressed this.
This PR refactors the Gemma4 multimodal vision tower to use vLLM-native modules instead of directly relying on the Transformers vision tower implementation.
The main goal is to make Gemma4 vision tower layers visible to vLLM’s module system, which is a prerequisite for future vision tower LoRA support. This work is related to #42662.
Notes
The implementation does not use
MMEncoderAttentionyet. Gemma4 image inputs are currently represented as padded patch sequences with a per-token bool padding mask frompixel_position_ids == -1.MMEncoderAttentionis better suited for dense or packed vision sequences usingcu_seqlens, so using it correctly would require an additional pack/scatter path for valid patch tokens. For this refactor, the attention path keeps the padded layout and uses SDPA with the existing bool mask.Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.