Add Gemma 4 (26B-A4B) LLM and VLM bridge #3148

Merged: cuichenx merged 30 commits into main from yuya/gemma4-vlm-bridge on May 6, 2026
yaoyu-33 (Contributor) commented Apr 3, 2026

Summary

  • Adds full bridge, provider, recipes, and VLM model wrapper for Google's Gemma 4 26B-A4B sparse MoE architecture
  • Supports both the text-only LLM (google/gemma-4-26B-A4B) and the vision-language model (google/gemma-4-26B-A4B-it)
  • Extends the Gemma 3 bridge/VL infrastructure with Gemma 4-specific handling: dual local/global RoPE, MoE expert routing, fused router weights, shared expert pre-norm, sliding-window attention, and K=V tying on global attention layers

New Files

| File | Description |
| --- | --- |
| gemma/gemma4_bridge.py | LLM weight mapping: fused router, shared expert pre-norm, QKV/GatedMLP |
| gemma/gemma4_provider.py | Megatron-Core GPT provider: proportional RoPE for global layers, MoE config, logit soft-capping |
| gemma_vl/gemma4_vl_bridge.py | VLM bridge: vision tower + embedder mappings; K=V tying enforcement; LoRA adapter export with ABSENT_PROJECTION sentinel for absent v_proj on global layers |
| gemma_vl/gemma4_vl_provider.py | VLM provider extending Gemma4 LLM provider |
| gemma_vl/modeling_gemma4_vl.py | VLM model: HF vision encoder + Megatron language decoder; bidirectional attention mask for image token regions |
| recipes/gemma4_vl/gemma4_vl.py | SFT and LoRA PEFT recipe configs |
| examples/models/vlm/gemma4_vl/ | Conversion, inference, SFT, and LoRA scripts + Slurm job scripts + README |
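
As a rough illustration of the logit soft-capping listed for the provider, the sketch below uses the scaled-tanh form known from Gemma 2, which keeps every logit strictly inside (-cap, cap). This is an assumption about the general technique, not the provider's code; the function name `soft_cap` and the cap value 30.0 are placeholders and are not taken from this PR.

```python
import math


def soft_cap(logits, cap):
    """Squash each logit into (-cap, cap) with a scaled tanh."""
    return [cap * math.tanh(x / cap) for x in logits]


# Placeholder cap value; the value used for Gemma 4 is not stated in this PR.
capped = soft_cap([-100.0, 0.0, 5.0, 100.0], cap=30.0)
print(capped)
```

Small logits pass through almost unchanged, while very large ones saturate near the cap, which bounds the output range without a hard clip.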

Key Design Decisions

  • Proportional RoPE: Global attention layers use head_dim=512 (not the standard head_dim=128). Handled via a Gemma4RotaryEmbedding override.
  • K=V tying on global layers: HF global attention has no v_proj — K and V share the same projection. At checkpoint import the bridge copies K rows into V. During fine-tuning Gemma4SelfAttention.get_query_key_value_tensors enforces value = key every forward pass. During LoRA adapter export _split_qkv_linear_out_weight returns ABSENT_PROJECTION for v_proj on global layers so the export loop skips it explicitly (a missing key raises KeyError; only an explicit ABSENT_PROJECTION value silently skips).
  • Bidirectional attention for image tokens: _compute_attention_mask builds a causal + bidirectional mask for contiguous image token groups, matching HF behaviour.
  • Vision pipeline: vision_tower.forward() returns last_hidden_state (already pooled + standardized), then embed_vision projects to language hidden dim — matches HF's Gemma4Model.get_image_features.
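
The causal + bidirectional masking described in the list above can be sketched in plain Python. This is a minimal illustration under assumptions, not the code in `_compute_attention_mask`: the helper names `image_group_ids` and `build_attention_mask` are hypothetical, the input is a per-token boolean flag, and the real implementation presumably works on batched tensors.

```python
def image_group_ids(is_image):
    """Label each contiguous run of image tokens with an id; text tokens get -1."""
    ids, group, prev = [], -1, False
    for flag in is_image:
        if flag and not prev:
            group += 1  # a new contiguous image run starts here
        ids.append(group if flag else -1)
        prev = flag
    return ids


def build_attention_mask(is_image):
    """allowed[q][k] is True when query position q may attend to key position k."""
    groups = image_group_ids(is_image)
    n = len(is_image)
    # Start from a plain causal mask.
    allowed = [[k <= q for k in range(n)] for q in range(n)]
    # Open up full attention inside each contiguous image group.
    for q in range(n):
        for k in range(n):
            if groups[q] != -1 and groups[q] == groups[k]:
                allowed[q][k] = True
    return allowed


# Token layout: text, image, image, text, image
mask = build_attention_mask([False, True, True, False, True])
```

In this layout the first image token can see the second one (same group), but tokens in different image groups and all text tokens remain strictly causal.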

Validation Results

| Check | Result |
| --- | --- |
| Text-only cosine similarity (single GPU, bf16) | 0.9998 |
| VLM cosine similarity | 0.9977 |
| Same top-1 token | ✅ |
| Image understanding ("red square") | ✅ |
| HF → Megatron → HF round-trip (max weight diff) | < 1e-6 |
| SFT teacher-forced token accuracy (CORD-v2, 100 iters) | ~98% |
| LoRA merge | ✅ (EXIT=0) |
| LoRA adapter export (180/180 weights, ~3.7 GB) | ✅ (EXIT=0) |
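
For readers unfamiliar with the parity metrics above, here is a minimal sketch of how a cosine similarity and a top-1 agreement check between two logit vectors could be computed. The logit values are made up for illustration and the helper names are hypothetical; the PR's own comparison scripts are not shown here and likely operate on tensors.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def argmax(values):
    """Index of the largest element."""
    return max(range(len(values)), key=values.__getitem__)


# Illustrative logits only, not measurements from this PR.
hf_logits = [2.0, -1.0, 0.5, 4.0]
megatron_logits = [2.1, -0.9, 0.4, 3.9]

sim = cosine_similarity(hf_logits, megatron_logits)
same_top1 = argmax(hf_logits) == argmax(megatron_logits)
print(f"cosine similarity: {sim:.4f}, same top-1 token: {same_top1}")
```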

Expected training dynamics: W&B report

Test Plan

  • Single-GPU logit parity (HF vs Megatron, text-only and VLM)
  • HF → Megatron → HF round-trip weight diff < 1e-6
  • Unit tests: bridge weight mapping, provider config, VLM modeling (tests/unit_tests/models/gemma/, tests/unit_tests/models/gemma_vl/)
  • SFT training convergence on CORD-v2 (teacher-forced accuracy ~98% at 100 steps)
  • LoRA PEFT training, merge, and adapter export
  • Multi-GPU TP/PP/EP CI tests
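
The round-trip check in the plan above amounts to comparing every weight before and after conversion. A minimal sketch, assuming weights are available as a mapping from name to a flat list of floats (the helper `max_abs_diff` and the weight names are hypothetical; a real check would compare tensors loaded from the two checkpoints):

```python
def max_abs_diff(original, round_tripped):
    """Largest elementwise absolute difference across all named weights."""
    mismatched = original.keys() ^ round_tripped.keys()
    if mismatched:
        raise KeyError(f"weights present on only one side: {sorted(mismatched)}")
    worst = 0.0
    for name, values in original.items():
        other = round_tripped[name]
        if len(values) != len(other):
            raise ValueError(f"shape mismatch for {name}")
        for a, b in zip(values, other):
            worst = max(worst, abs(a - b))
    return worst


# Placeholder weights standing in for the original and round-tripped checkpoints.
before = {"layer.0.weight": [0.25, -1.5], "layer.0.bias": [0.0]}
after = {"layer.0.weight": [0.25, -1.5], "layer.0.bias": [0.0]}
diff = max_abs_diff(before, after)
print(f"max weight diff: {diff}")
```

Raising on missing or extra names matters here, because a silently dropped weight would otherwise look like a perfect match.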

Requirements

Requires transformers>=5.5.0:

uv pip install -q --upgrade 'transformers>=5.5.0'

🤖 Generated with Claude Code

Add bridge, provider, and VLM model wrapper for Google's Gemma 4
MoE architecture (gemma-4-26B-A4B).

Key components:
- gemma4_bridge.py: Weight mapping with fused router weights,
  shared expert pre-norm fusion, and QKV/GatedMLP mappings
- gemma4_provider.py: Megatron-Core GPT provider with dual
  local/global RoPE (proportional formula for global layers),
  MoE expert routing, sliding window attention, and logit
  soft-capping
- gemma4_vl_bridge.py: VLM bridge with vision tower + embedder
  weight mappings (model.* prefix for raw safetensors keys)
- gemma4_vl_provider.py: VLM provider extending Gemma4 LLM
- modeling_gemma4_vl.py: VLM model combining HF vision tower
  with Megatron language model

Validation (single-GPU, bf16):
- Text-only cosine similarity: 0.9998
- VLM cosine similarity: 0.9977 (causal-only, apples-to-apples)
- Same top-1 token prediction as HF
- Correct image understanding ("Red square" from test image)

Known limitations:
- Bidirectional attention for image tokens (mm_token_type_ids)
  not yet implemented — uses causal-only masking
- Requires transformers >= 5.6.0.dev0 for Gemma4 support

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

copy-pr-bot (Bot) commented Apr 3, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

yaoyu-33 added the area:model and feature labels on Apr 6, 2026
Introduce new scripts for HuggingFace reference inference and conversion to Megatron for the Gemma 4 VLM model. The added files include:
- `hf_gemma4_reference_vlm.py`: A script for performing inference using HuggingFace's Gemma 4 VLM model, supporting both base and instruction-tuned versions.
- `hf_to_megatron_generate_gemma4.py`: A script for generating outputs from the Gemma 4 model in a Megatron-compatible format.
- `conversion.sh`, `inference.sh`, `peft.sh`, `sft.sh`, `slurm_peft.sh`, and `slurm_sft.sh`: Shell scripts for managing conversion, inference, and training processes, including support for LoRA fine-tuning and distributed training setups.

Signed-off-by: Ao Tang <aot@nvidia.com>
- Updated `hf_to_megatron_generate_gemma4.py` to support instruction-tuned model and improved image input handling.
- Modified `hf_to_megatron_generate_vlm.py` to include new image position IDs and refined vision input processing.
- Enhanced `conversion.sh` and `inference.sh` scripts to accommodate both base and instruction-tuned models for seamless inference.
- Added a comprehensive README for Gemma 4 VLM examples, detailing usage, requirements, and inference procedures.
- Introduced `Gemma4TiedKVMixin` to enforce K=V tying in global attention layers, ensuring compatibility with HuggingFace checkpoints.

Signed-off-by: Ao Tang <aot@nvidia.com>
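
To illustrate the K=V tying that the commit above enforces in global attention layers, here is a minimal sketch of the two steps the PR description names: copying K into V at checkpoint import, and forcing value = key on every forward pass. It is not the `Gemma4TiedKVMixin` implementation; both function names are hypothetical, weights are plain nested lists, and the fused QKV layout used by Megatron is deliberately ignored.

```python
def import_global_attention(q_proj, k_proj):
    """Build Q, K, V weights for a layer whose checkpoint has no v_proj."""
    # Copy the rows so V starts as an independent buffer initialised from K.
    v_proj = [row[:] for row in k_proj]
    return {"q": q_proj, "k": k_proj, "v": v_proj}


def tied_query_key_value(query, key, value):
    """Forward-time tying: discard the computed value and reuse the key."""
    return query, key, key


weights = import_global_attention([[1.0, 2.0]], [[3.0, 4.0]])
q, k, v = tied_query_key_value([0.1], [0.2], [0.9])
```

Re-tying in the forward pass is what keeps the two projections from drifting apart during fine-tuning, since the imported V weights alone would otherwise receive their own gradient updates.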
…ya/gemma4-vlm-bridge

Signed-off-by: Ao Tang <aot@nvidia.com>
Signed-off-by: Ao Tang <aot@nvidia.com>
Signed-off-by: Ao Tang <aot@nvidia.com>

suiyoubi marked this pull request as ready for review on April 24, 2026, 19:19
- Deleted `hf_to_megatron_generate_gemma4.py` as it is replaced by the unified `hf_to_megatron_generate_vlm.py` script.
- Updated `hf_to_megatron_generate_vlm.py` to handle Gemma 4 specific input preprocessing and attention mask adjustments.
- Enhanced `process_image_inputs` to support image position IDs for improved model compatibility.
- Revised inference and training scripts to reflect the new VLM structure and ensure seamless integration.

Signed-off-by: Ao Tang <aot@nvidia.com>
Signed-off-by: Ao Tang <aot@nvidia.com>
- Removed outdated transformer upgrade instructions from `conversion.sh`.
- Updated `README.md` to clarify transformer upgrade steps and changed example image for expected output.
- Enhanced `peft_bridge.py` with a sentinel for absent projections to improve adapter export handling.
- Modified `gemma4_bridge.py` and `gemma4_vl_bridge.py` to correctly handle global vs sliding attention layers during weight splitting.

Signed-off-by: Ao Tang <aot@nvidia.com>
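
The sentinel for absent projections mentioned in the commit above can be illustrated with a small export loop. This is a sketch under assumptions: only the name ABSENT_PROJECTION comes from the PR description, while `export_adapter`, its arguments, and the dictionary layout are hypothetical and do not reproduce `peft_bridge.py`.

```python
# Stand-in sentinel; per the PR the real one is defined in peft_bridge.py.
ABSENT_PROJECTION = object()


def export_adapter(split_weights, expected_names):
    """Collect adapter weights, skipping projections explicitly marked absent."""
    exported = {}
    for name in expected_names:
        value = split_weights[name]  # a forgotten projection raises KeyError
        if value is ABSENT_PROJECTION:
            continue  # intentionally absent (e.g. v_proj on a global layer)
        exported[name] = value
    return exported


# Global-attention layer: v_proj is marked absent rather than simply left out.
global_layer = {"q_proj": [1.0], "k_proj": [2.0], "v_proj": ABSENT_PROJECTION}
exported = export_adapter(global_layer, ["q_proj", "k_proj", "v_proj"])
```

Using a dedicated sentinel keeps "deliberately absent" distinct from "accidentally missing": the first is skipped quietly, the second still fails loudly.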
Signed-off-by: Ao Tang <aot@nvidia.com>
- Bump transformers version to 5.5.0 in pyproject.toml and uv.lock.
- Add mistral-common dependency with version >=1.10.0 in both files.
- Update uv.lock to include mistral-common package details and its dependencies.

Signed-off-by: Ao Tang <aot@nvidia.com>

suiyoubi (Contributor) commented:

/ok to test e726f8b

Signed-off-by: Ao Tang <aot@nvidia.com>

suiyoubi (Contributor) commented:

/ok to test 296d0ed

Signed-off-by: Ao Tang <aot@nvidia.com>

suiyoubi (Contributor) commented:

/ok to test 6cfc877

Signed-off-by: Ao Tang <aot@nvidia.com>

suiyoubi (Contributor) commented:

/ok to test bbeec27

Signed-off-by: Ao Tang <aot@nvidia.com>
Signed-off-by: Ao Tang <aot@nvidia.com>

cuichenx previously approved these changes May 4, 2026

cuichenx (Contributor) left a comment:

LGTM

Review comment on `def _process_gemma4_inputs(processor, image_path, prompt):`

(note to self, not blocking PR merge) these model-specific process functions should be unified

Signed-off-by: Chen Cui <chcui@nvidia.com>
cuichenx added the ready-to-merge label on May 4, 2026

cuichenx (Contributor) commented May 4, 2026

/ok to test 81e1c5d

cuichenx previously approved these changes May 4, 2026

weijiac0619 (Contributor) commented:

@suiyoubi could you fix the failed test? Thank you

cuichenx (Contributor) commented May 5, 2026

error was on main

cuichenx (Contributor) commented May 5, 2026

/ok to test 672675f

Signed-off-by: Ao Tang <aot@nvidia.com>

suiyoubi (Contributor) commented May 5, 2026

/ok to test aa3ec70

Signed-off-by: Ao Tang <aot@nvidia.com>

suiyoubi (Contributor) commented May 5, 2026

/ok to test ed709a0

Signed-off-by: Ao Tang <aot@nvidia.com>

suiyoubi (Contributor) commented May 5, 2026

/ok to test feaac92

cuichenx merged commit 2faedbf into main on May 6, 2026 (89 checks passed)
cuichenx deleted the yuya/gemma4-vlm-bridge branch on May 6, 2026, 00:54
gautham-kollu pushed a commit that referenced this pull request May 12, 2026
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: Ao Tang <aot@nvidia.com>
Signed-off-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Ao Tang <aot@nvidia.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
vasunvidia pushed a commit to vasunvidia/Megatron-Bridge that referenced this pull request Jun 10, 2026
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: Ao Tang <aot@nvidia.com>
Signed-off-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Ao Tang <aot@nvidia.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Labels

  • area:model: Model implementations and HF bridge logic
  • feature: New capabilities, enhancements, or enablement work
  • ready-to-merge: PR is approved, current, and only waiting for CI to pass before merge

4 participants