Add Gemma 4 (26B-A4B) LLM and VLM bridge #3148
Merged
Add bridge, provider, and VLM model wrapper for Google's Gemma 4
MoE architecture (gemma-4-26B-A4B).
Key components:
- gemma4_bridge.py: Weight mapping with fused router weights,
shared expert pre-norm fusion, and QKV/GatedMLP mappings
- gemma4_provider.py: Megatron-Core GPT provider with dual
local/global RoPE (proportional formula for global layers),
MoE expert routing, sliding window attention, and logit
soft-capping
- gemma4_vl_bridge.py: VLM bridge with vision tower + embedder
weight mappings (model.* prefix for raw safetensors keys)
- gemma4_vl_provider.py: VLM provider extending Gemma4 LLM
- modeling_gemma4_vl.py: VLM model combining HF vision tower
with Megatron language model
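Logit soft-capping, listed for the provider above, is commonly written as a scaled tanh. The snippet below is a minimal sketch, assuming the `cap * tanh(x / cap)` form used by earlier Gemma releases; the function name `soft_cap` and the cap value of 30.0 are illustrative and not taken from this PR.

```python
import math


def soft_cap(logits, cap=30.0):
    """Squash every logit into the open interval (-cap, cap) with a scaled tanh.

    `cap` is a hypothetical value chosen for illustration only.
    """
    return [cap * math.tanh(x / cap) for x in logits]


# Small logits pass through almost unchanged; large ones saturate near the cap.
capped = soft_cap([0.5, 25.0, 400.0])
```

Because tanh is monotonic, the ranking of logits (and therefore the top-1 token) is preserved; only the magnitudes are bounded.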
Validation (single-GPU, bf16):
- Text-only cosine similarity: 0.9998
- VLM cosine similarity: 0.9977 (causal-only, apples-to-apples)
- Same top-1 token prediction as HF
- Correct image understanding ("Red square" from test image)
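The cosine-similarity and top-1 checks reported above can be sketched in a few lines. This is an illustrative sketch only: the logit values are made up, and the helper names `cosine_similarity` and `top1` are hypothetical rather than the comparison code used in this PR.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length logit vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top1(logits):
    """Index of the largest logit, i.e. the greedy next-token prediction."""
    return max(range(len(logits)), key=logits.__getitem__)


# Hypothetical last-position logits from the HF reference and the converted model.
hf_logits = [2.0, -1.0, 0.5, 7.25]
mb_logits = [2.01, -0.98, 0.52, 7.20]

similarity = cosine_similarity(hf_logits, mb_logits)
same_top1 = top1(hf_logits) == top1(mb_logits)
```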
Known limitations:
- Bidirectional attention for image tokens (mm_token_type_ids)
not yet implemented — uses causal-only masking
- Requires transformers >= 5.6.0.dev0 for Gemma4 support
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Introduce new scripts for HuggingFace reference inference and conversion to Megatron for the Gemma 4 VLM model. The added files include:
- `hf_gemma4_reference_vlm.py`: A script for performing inference using HuggingFace's Gemma 4 VLM model, supporting both base and instruction-tuned versions.
- `hf_to_megatron_generate_gemma4.py`: A script for generating outputs from the Gemma 4 model in a Megatron-compatible format.
- `conversion.sh`, `inference.sh`, `peft.sh`, `sft.sh`, `slurm_peft.sh`, and `slurm_sft.sh`: Shell scripts for managing conversion, inference, and training processes, including support for LoRA fine-tuning and distributed training setups.
Signed-off-by: Ao Tang <aot@nvidia.com>
- Updated `hf_to_megatron_generate_gemma4.py` to support instruction-tuned model and improved image input handling.
- Modified `hf_to_megatron_generate_vlm.py` to include new image position IDs and refined vision input processing.
- Enhanced `conversion.sh` and `inference.sh` scripts to accommodate both base and instruction-tuned models for seamless inference.
- Added a comprehensive README for Gemma 4 VLM examples, detailing usage, requirements, and inference procedures.
- Introduced `Gemma4TiedKVMixin` to enforce K=V tying in global attention layers, ensuring compatibility with HuggingFace checkpoints.
Signed-off-by: Ao Tang <aot@nvidia.com>
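The K=V tying mentioned above can be illustrated with a toy import step. This is a sketch under stated assumptions: the dictionary keys, the nested-list weights, and the helper `tie_value_to_key` are hypothetical simplifications, not the layout or API used by `Gemma4TiedKVMixin` or the bridge.

```python
def tie_value_to_key(layer_weights):
    """Return a copy of `layer_weights` in which a missing v_proj reuses the k_proj rows.

    Layers that ship their own v_proj (e.g. sliding-window layers) are left untouched.
    """
    tied = dict(layer_weights)
    if "v_proj" not in tied:
        # Copy row by row so later in-place edits to V cannot alias K.
        tied["v_proj"] = [row[:] for row in tied["k_proj"]]
    return tied


# Hypothetical global-attention layer: the checkpoint has no v_proj entry.
global_layer = {"q_proj": [[1.0, 0.0], [0.0, 1.0]], "k_proj": [[0.5, 0.5]]}
# Hypothetical sliding-window layer: v_proj is present and must be preserved.
sliding_layer = {"q_proj": [[1.0, 0.0]], "k_proj": [[0.5, 0.5]], "v_proj": [[0.1, 0.9]]}

tied_global = tie_value_to_key(global_layer)
tied_sliding = tie_value_to_key(sliding_layer)
```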
…ya/gemma4-vlm-bridge Signed-off-by: Ao Tang <aot@nvidia.com>
Signed-off-by: Ao Tang <aot@nvidia.com>
Signed-off-by: Ao Tang <aot@nvidia.com>
Signed-off-by: Ao Tang <aot@nvidia.com>
- Deleted `hf_to_megatron_generate_gemma4.py` as it is replaced by the unified `hf_to_megatron_generate_vlm.py` script.
- Updated `hf_to_megatron_generate_vlm.py` to handle Gemma 4 specific input preprocessing and attention mask adjustments.
- Enhanced `process_image_inputs` to support image position IDs for improved model compatibility.
- Revised inference and training scripts to reflect the new VLM structure and ensure seamless integration.
Signed-off-by: Ao Tang <aot@nvidia.com>
…ya/gemma4-vlm-bridge
Signed-off-by: Ao Tang <aot@nvidia.com>
- Removed outdated `transformers` upgrade instructions from `conversion.sh`.
- Updated `README.md` to clarify `transformers` upgrade steps and changed example image for expected output.
- Enhanced `peft_bridge.py` with a sentinel for absent projections to improve adapter export handling.
- Modified `gemma4_bridge.py` and `gemma4_vl_bridge.py` to correctly handle global vs sliding attention layers during weight splitting.
Signed-off-by: Ao Tang <aot@nvidia.com>
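The sentinel approach for absent projections can be sketched as follows. The PR description states that a missing key raises `KeyError` while an explicit `ABSENT_PROJECTION` value is skipped; everything else here (the `split_qkv` and `export_adapter` helpers, the string placeholders for weights) is a hypothetical illustration, not the code in `peft_bridge.py`.

```python
# Stand-in sentinel for illustration; the real object is defined in the PR.
ABSENT_PROJECTION = object()


def split_qkv(fused, is_global_layer):
    """Split a (q, k, v) triple, marking v_proj as absent on global layers."""
    q, k, v = fused
    return {
        "q_proj": q,
        "k_proj": k,
        "v_proj": ABSENT_PROJECTION if is_global_layer else v,
    }


def export_adapter(split, expected=("q_proj", "k_proj", "v_proj")):
    """Export every expected projection, skipping only explicit sentinels."""
    exported = {}
    for name in expected:
        weight = split[name]  # a forgotten projection fails loudly with KeyError
        if weight is ABSENT_PROJECTION:
            continue  # intentionally absent, so skip without error
        exported[name] = weight
    return exported


global_export = export_adapter(split_qkv(("Wq", "Wk", "Wk"), is_global_layer=True))
sliding_export = export_adapter(split_qkv(("Wq", "Wk", "Wv"), is_global_layer=False))
```

Using an identity check against a dedicated object (rather than `None` or a missing key) keeps "deliberately absent" distinguishable from "accidentally dropped".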
Signed-off-by: Ao Tang <aot@nvidia.com>
…ya/gemma4-vlm-bridge
- Bump transformers version to 5.5.0 in pyproject.toml and uv.lock.
- Add mistral-common dependency with version >=1.10.0 in both files.
- Update uv.lock to include mistral-common package details and its dependencies.
Signed-off-by: Ao Tang <aot@nvidia.com>
Contributor: /ok to test e726f8b
Signed-off-by: Ao Tang <aot@nvidia.com>
Contributor: /ok to test 296d0ed
Contributor: /ok to test 6cfc877
Signed-off-by: Ao Tang <aot@nvidia.com>
Contributor: /ok to test bbeec27
Signed-off-by: Ao Tang <aot@nvidia.com>
1ca65e2 to 4733d50
cuichenx previously approved these changes on May 4, 2026
`def _process_gemma4_inputs(processor, image_path, prompt):`
Contributor
(note to self, not blocking PR merge) these model-specific process functions should be unified
Signed-off-by: Chen Cui <chcui@nvidia.com>
Contributor: /ok to test 81e1c5d
cuichenx previously approved these changes on May 4, 2026
Contributor: @suiyoubi could you fix the failed test? Thank you
Contributor: error was on main
Contributor: /ok to test 672675f
Signed-off-by: Ao Tang <aot@nvidia.com>
Contributor: /ok to test aa3ec70
Contributor: /ok to test ed709a0
Signed-off-by: Ao Tang <aot@nvidia.com>
Contributor: /ok to test feaac92
cuichenx approved these changes on May 6, 2026
gautham-kollu pushed a commit that referenced this pull request on May 12, 2026
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com> Signed-off-by: Ao Tang <aot@nvidia.com> Signed-off-by: Chen Cui <chcui@nvidia.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Ao Tang <aot@nvidia.com> Co-authored-by: Chen Cui <chcui@nvidia.com>
vasunvidia pushed a commit to vasunvidia/Megatron-Bridge that referenced this pull request on Jun 10, 2026
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com> Signed-off-by: Ao Tang <aot@nvidia.com> Signed-off-by: Chen Cui <chcui@nvidia.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Ao Tang <aot@nvidia.com> Co-authored-by: Chen Cui <chcui@nvidia.com> Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Summary
- … (`google/gemma-4-26B-A4B`) and the vision-language model (`google/gemma-4-26B-A4B-it`)

New Files
- `gemma/gemma4_bridge.py`
- `gemma/gemma4_provider.py`
- `gemma_vl/gemma4_vl_bridge.py`: … `ABSENT_PROJECTION` sentinel for absent `v_proj` on global layers
- `gemma_vl/gemma4_vl_provider.py`
- `gemma_vl/modeling_gemma4_vl.py`
- `recipes/gemma4_vl/gemma4_vl.py`
- `examples/models/vlm/gemma4_vl/`

Key Design Decisions
- … `head_dim=512` (not the standard `dim=128`). Fixed via `Gemma4RotaryEmbedding` override.
- … `v_proj`: K and V share the same projection. At checkpoint import the bridge copies K rows into V. During fine-tuning, `Gemma4SelfAttention.get_query_key_value_tensors` enforces `value = key` on every forward pass. During LoRA adapter export, `_split_qkv_linear_out_weight` returns `ABSENT_PROJECTION` for `v_proj` on global layers so the export loop skips it explicitly (a missing key raises `KeyError`; only an explicit `ABSENT_PROJECTION` value silently skips).
- `_compute_attention_mask` builds a causal + bidirectional mask for contiguous image token groups, matching HF behaviour.
- `vision_tower.forward()` returns `last_hidden_state` (already pooled + standardized), then `embed_vision` projects to language hidden dim, matching HF's `Gemma4Model.get_image_features`.

Validation Results
Expected training dynamics: W&B report

Test Plan
- … (`tests/unit_tests/models/gemma/`, `tests/unit_tests/models/gemma_vl/`)

Requirements
Requires `transformers>=5.5.0`: `uv pip install -q --upgrade 'transformers>=5.5.0'`

🤖 Generated with Claude Code
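The causal + bidirectional masking for contiguous image token groups, described under Key Design Decisions, can be sketched as below. This is a minimal illustration and not the PR's `_compute_attention_mask`: the function name `build_attention_mask`, the boolean-list input, and the convention that `True` means "may attend" are all assumptions made for the example (Megatron masks often use the opposite convention).

```python
def build_attention_mask(is_image_token):
    """Return mask[i][j] == True when position i may attend to position j.

    Text tokens attend causally. Image tokens additionally attend to every
    token in their own contiguous run of image tokens, including later ones.
    """
    n = len(is_image_token)

    # Give each contiguous run of image tokens its own group id; text gets -1.
    group = [-1] * n
    current = -1
    for i, flag in enumerate(is_image_token):
        if flag:
            if i == 0 or not is_image_token[i - 1]:
                current += 1  # a new run of image tokens starts here
            group[i] = current

    return [
        [j <= i or (group[i] >= 0 and group[i] == group[j]) for j in range(n)]
        for i in range(n)
    ]


# Layout: text, image, image, text, image -> two separate image groups.
mask = build_attention_mask([False, True, True, False, True])
```

In this toy layout, position 1 can see position 2 (same image group) but not position 4 (a different group), while text positions stay strictly causal.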