[Model] Refactor Gemma4 vision tower with vLLM-native modules#43440

Closed
linitra24 wants to merge 11 commits into vllm-project:main from linitra24:gemma4-vision-tower-refactor
linitra24 commented May 22, 2026

Contributor

This PR refactors the Gemma4 multimodal vision tower to use vLLM-native modules instead of directly relying on the Transformers vision tower implementation.

The main goal is to make Gemma4 vision tower layers visible to vLLM’s module system, which is a prerequisite for future vision tower LoRA support. This work is related to #42662.

Notes

The implementation does not use MMEncoderAttention yet. Gemma4 image inputs are currently represented as padded patch sequences with a per-token bool padding mask from pixel_position_ids == -1. MMEncoderAttention is better suited for dense or packed vision sequences using cu_seqlens, so using it correctly would require an additional pack/scatter path for valid patch tokens. For this refactor, the attention path keeps the padded layout and uses SDPA with the existing bool mask.
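
As a rough illustration of the padded layout described in the note above, the sketch below derives a per-token padding mask from position ids and applies it as a key mask inside a softmax. It uses plain Python lists rather than torch tensors. The helper names (`padding_mask`, `masked_softmax`) and the sample values are hypothetical, and it assumes a patch counts as padding when all of its coordinates equal -1.

```python
import math

PAD = -1  # sentinel the PR text gives for padding entries in pixel_position_ids


def padding_mask(pixel_position_ids):
    """True where a patch is padding (assumed: every coordinate equals PAD)."""
    return [all(c == PAD for c in pos) for pos in pixel_position_ids]


def masked_softmax(scores, key_is_padding):
    """Softmax over keys that gives zero weight to padded keys.

    Assumes at least one key is not padding; an all-padding row would
    need separate handling.
    """
    masked = [-math.inf if pad else s for s, pad in zip(scores, key_is_padding)]
    peak = max(masked)
    exps = [math.exp(s - peak) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]


# Two real patches followed by two padding patches (hypothetical sample).
positions = [(0, 0), (1, 0), (-1, -1), (-1, -1)]
is_pad = padding_mask(positions)
weights = masked_softmax([0.5, 1.5, 9.0, 9.0], is_pad)
```

With torch's `scaled_dot_product_attention`, a boolean `attn_mask` marks the positions that take part in attention, so a padding mask like `is_pad` would have to be inverted before being passed in.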


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request replaces the transformers-based vision tower in the Gemma 4 multimodal model with a native vLLM implementation to support LoRA visibility and optimized execution. The new implementation introduces custom modules for patch embedding, multidimensional rotary embeddings, and a vision pooler with position-based average pooling. Review feedback identifies a potential indexing error in the pooling logic when image dimensions are not multiples of the kernel size and suggests refining the pooler mask to correctly handle padding patches that might otherwise be incorrectly marked as valid tokens.

Comment on lines +976 to +978
max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1
kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor")
kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1]

Contributor

high

The calculation of kernel_idxs assumes that the image width (in patches) is a multiple of the pooling kernel size k. While Gemma4ProcessingInfo enforces this for standard inputs, Gemma4VisionPooler should be robust to cases where max_x is not a multiple of k. If max_x % k != 0, kernel_idxs can exceed length - 1, causing F.one_hot to crash or produce incorrect results. A safer approach is to use the actual number of horizontal blocks (max_x + k - 1) // k for indexing.

Suggested change
- max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1
- kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor")
- kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1]
+ max_x = clamped_positions[..., 0].max(dim=-1, keepdim=True)[0] + 1
+ num_blocks_x = (max_x + k - 1) // k
+ kernel_idxs = torch.div(clamped_positions, k, rounding_mode="floor")
+ kernel_idxs = kernel_idxs[..., 0] + num_blocks_x * kernel_idxs[..., 1]
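
To make the arithmetic in this suggestion concrete, here is a small plain-Python sketch that compares the floor-based row stride (`max_x // k`) with the ceiling-based one (`(max_x + k - 1) // k`) on a hypothetical 5x6 patch grid with `k = 3`. The function names and grid size are illustrative assumptions. The sketch only checks whether distinct pooling blocks receive distinct flat indices; whether those indices fit inside `length` depends on how `length` is computed, which the thread does not show.

```python
def kernel_index(x, y, k, stride):
    """Flat pooling-block index for patch (x, y), mirroring the indexing line."""
    return x // k + stride * (y // k)


def block_indices(width, height, k, stride):
    """Map each (x // k, y // k) block to the flat index its patches receive."""
    mapping = {}
    for y in range(height):
        for x in range(width):
            mapping[(x // k, y // k)] = kernel_index(x, y, k, stride)
    return mapping


width, height, k = 5, 6, 3            # hypothetical grid; width is not a multiple of k
floor_stride = width // k             # 1, as in the original line
ceil_stride = (width + k - 1) // k    # 2, as in the suggested change

floor_map = block_indices(width, height, k, floor_stride)
ceil_map = block_indices(width, height, k, ceil_stride)

# With the floor stride, blocks (1, 0) and (0, 1) collide on the same index.
collision = floor_map[(1, 0)] == floor_map[(0, 1)]
```

In this example the grid has four blocks, the floor stride produces only three distinct indices, and the ceiling stride produces four.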

kernel_idxs = kernel_idxs[..., 0] + (max_x // k) * kernel_idxs[..., 1]
weights = F.one_hot(kernel_idxs.long(), length).float() / k_squared
output = weights.transpose(1, 2) @ hidden_states.float()
mask = torch.logical_not((weights == 0).all(dim=1))

Contributor

high

The pooler_mask calculation using (weights == 0).all(dim=1) will incorrectly include the first pooled token (index 0) as valid even if it only contains padding patches. This is because all padding patches (with pixel_position_ids == -1) are clamped to (0, 0) and thus mapped to kernel_idxs == 0. Since hidden_states are masked to 0.0 for padding positions, the sum is correct, but the mask should ideally only consider tokens that contain at least one real patch to avoid returning vestigial tokens in edge cases.

        mask = torch.logical_not((weights == 0).all(dim=1))
        # Ensure index 0 is only valid if it contains at least one real patch
        real_patches_mask = torch.logical_not(padding_positions)
        mask[:, 0] = (weights[:, real_patches_mask, 0] > 0).any(dim=1) if real_patches_mask.any() else False
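
The edge case raised in this comment can be reproduced with a plain-Python sketch of the two mask definitions. It is not the PR's code: `pooled_masks` is a hypothetical helper, the stride follows the `max_x // k` form quoted above, and a patch is assumed to be padding when both coordinates are -1. The "naive" mask marks a pooled token valid if any patch maps to it, while the "strict" mask requires at least one real patch.

```python
PAD = -1  # padding sentinel from the PR text


def pooled_masks(positions, k, length):
    """Return (naive_mask, strict_mask) over `length` pooled tokens."""
    is_pad = [x == PAD and y == PAD for x, y in positions]
    clamped = [(max(x, 0), max(y, 0)) for x, y in positions]  # padding -> (0, 0)
    max_x = max(x for x, _ in clamped) + 1
    stride = max_x // k
    naive = [False] * length
    strict = [False] * length
    for (x, y), pad in zip(clamped, is_pad):
        idx = x // k + stride * (y // k)
        naive[idx] = True       # any patch counts, padding included
        if not pad:
            strict[idx] = True  # only real patches count
    return naive, strict


# A 2x2 image padded to 8 patches, pooled with k = 2 into 2 output slots.
real_then_pad = [(0, 0), (1, 0), (0, 1), (1, 1)] + [(PAD, PAD)] * 4
mixed_naive, mixed_strict = pooled_masks(real_then_pad, k=2, length=2)

# A sample that is padding only: the two masks disagree on slot 0.
all_pad_naive, all_pad_strict = pooled_masks([(PAD, PAD)] * 8, k=2, length=2)
```

For a sample that contains a real patch at (0, 0) the two definitions agree, so in this sketch the difference only shows up for the padding-only sample.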

Comment thread vllm/model_executor/models/gemma4_mm.py Outdated
if self.config.standardize:
    hidden_states = (hidden_states - self.std_bias) * self.std_scale

return BaseModelOutputWithPast(last_hidden_state=hidden_states)

Member

Suggested change
- return BaseModelOutputWithPast(last_hidden_state=hidden_states)
+ return hidden_states

Contributor Author

I've addressed this.

Comment on lines +1251 to +1253
attn_output = _gemma4_vision_attention_forward(
    self, query_states, key_states, value_states, attention_mask
)

Member

Use MMEncoderAttention interface instead.

Contributor Author

For MMEncoderAttention, I looked into it, but Gemma4 vision currently relies on a per-sample bool padding mask derived from pixel_position_ids == -1. MMEncoderAttention is built around packed/dense encoder attention with cu_seqlens and does not directly support this arbitrary padding mask. Converting it correctly would require an additional pack/scatter path for valid patch tokens, so I kept the SDPA path in this PR and would prefer to handle the MMEncoderAttention conversion as a follow-up!
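
The pack/scatter path mentioned in this reply could look roughly like the following plain-Python sketch. It is only an illustration of the idea, not vLLM code: `pack` and `scatter` are hypothetical helpers, tokens are plain floats instead of tensors, and `cu_seqlens` is taken to mean cumulative valid-token counts with a leading zero.

```python
from itertools import accumulate


def pack(batch, is_padding):
    """Concatenate the valid tokens of every sample; return (packed, cu_seqlens)."""
    packed = []
    lengths = []
    for tokens, pads in zip(batch, is_padding):
        valid = [t for t, p in zip(tokens, pads) if not p]
        packed.extend(valid)
        lengths.append(len(valid))
    cu_seqlens = [0, *accumulate(lengths)]
    return packed, cu_seqlens


def scatter(packed, is_padding, fill=0.0):
    """Write packed tokens back into the padded layout, filling padding slots."""
    it = iter(packed)
    return [[fill if p else next(it) for p in pads] for pads in is_padding]


# Two samples padded to length 3 (hypothetical values).
batch = [[1.0, 2.0, 0.0], [3.0, 0.0, 0.0]]
is_padding = [[False, False, True], [False, True, True]]

packed, cu_seqlens = pack(batch, is_padding)
restored = scatter(packed, is_padding)
```

Sample `i` occupies `packed[cu_seqlens[i]:cu_seqlens[i + 1]]`, and `scatter` restores the original padded shape so downstream code that expects the padded layout is unaffected.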

Comment on lines +1111 to +1140
def _apply_multidimensional_rope(
    hidden_states: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: torch.Tensor,
    unsqueeze_dim: int = 2,
) -> torch.Tensor:
    ndim = position_ids.shape[-1]
    num_input_channels = hidden_states.shape[-1]
    num_rotated_channels_per_dim = 2 * (num_input_channels // (2 * ndim))
    if num_rotated_channels_per_dim <= 0:
        raise ValueError(
            "Invalid configuration: num_rotated_channels_per_dim must be > 0, "
            f"got {num_rotated_channels_per_dim}."
        )

    split_sizes = [num_rotated_channels_per_dim] * ndim
    hidden_parts = torch.split(hidden_states, split_sizes, dim=-1)
    cos_parts = torch.split(cos, split_sizes, dim=-1)
    sin_parts = torch.split(sin, split_sizes, dim=-1)
    output_parts = [
        _apply_rotary_pos_emb(
            hidden_states=hidden_parts[k],
            cos=cos_parts[k],
            sin=sin_parts[k],
            unsqueeze_dim=unsqueeze_dim,
        )
        for k in range(ndim)
    ]
    return torch.cat(output_parts, dim=-1)
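
The channel arithmetic in this helper can be checked in isolation. The sketch below reproduces only the split-size computation in plain Python; `rope_split_sizes` and `covers_all_channels` are hypothetical names, and the channel counts are example values rather than Gemma4's actual configuration.

```python
def rope_split_sizes(num_input_channels, ndim):
    """Per-dimension channel counts, following the formula in the helper above."""
    per_dim = 2 * (num_input_channels // (2 * ndim))
    if per_dim <= 0:
        raise ValueError(
            f"num_rotated_channels_per_dim must be > 0, got {per_dim}."
        )
    return [per_dim] * ndim


def covers_all_channels(num_input_channels, ndim):
    """True when the split sizes add up to the full channel dimension."""
    return sum(rope_split_sizes(num_input_channels, ndim)) == num_input_channels


even_split = rope_split_sizes(64, 2)    # channels divide evenly by 2 * ndim
uneven_split = rope_split_sizes(10, 2)  # 2 channels are left over
```

Because `torch.split` with a list of sizes requires the sizes to sum to the dimension length, the helper as written effectively needs the channel count to be a multiple of `2 * ndim`; the uneven case above would fail at the split rather than at the explicit `ValueError` check.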

Member

I think we can use ApplyRotaryEmb:

# --8<-- [start:apply_rotary_emb]
@CustomOp.register("apply_rotary_emb")
class ApplyRotaryEmb(CustomOp):
    # --8<-- [end:apply_rotary_emb]
    def __init__(

Contributor Author

Thanks for the suggestion. I looked into using ApplyRotaryEmb here, but kept the HF-style multidimensional RoPE helper for now.

This is similar to the MMEncoderAttention case: Gemma4 vision currently uses batch-specific 2D pixel_position_ids with padded image patches, while the ApplyRotaryEmb path expects a more standard/packed layout. Switching this correctly would require additional layout changes, which I think is better handled as a follow-up.

For this PR, I kept the RoPE behavior aligned with Transformers and focused on moving the vision tower to vLLM-native modules for future LoRA support.

Comment thread vllm/model_executor/models/gemma4_mm.py Outdated
position_ids=pixel_position_ids,
)

return BaseModelOutputWithPast(last_hidden_state=hidden_states)

Member

Suggested change
- return BaseModelOutputWithPast(last_hidden_state=hidden_states)
+ return last_hidden_states

Contributor Author

I've addressed this.

2 participants