Skip to content

feat: add Gemma4 support#2224

Open
sharonyu-115 wants to merge 30 commits into
NVIDIA-NeMo:mainfrom
sharonyu-115:gemma4-support
Open

feat: add Gemma4 support#2224
sharonyu-115 wants to merge 30 commits into
NVIDIA-NeMo:mainfrom
sharonyu-115:gemma4-support

Conversation

@sharonyu-115

@sharonyu-115 sharonyu-115 commented Apr 7, 2026

Copy link
Copy Markdown
Contributor

What does this PR do ?

Adds Gemma 4 support to NeMo-RL with DAPO and GRPO recipes across dense and MoE variants, plus a VLM recipe.

Issue

List issues that this PR closes:
#2212

Summary of code changes:

Core source changes (5 files)

nemo_rl/models/policy/utils.py

Registers gemma4AutoModelForImageTextToText in both the HF and
NeMo-AutoModel factory maps (2-line addition). Makes Gemma4 load as a
VLM-capable architecture.

nemo_rl/models/generation/vllm/vllm_worker.py

Adds Gemma4ForConditionalGeneration to the two arch lists that gate special
vLLM handling (alongside Gemma3 and Qwen3.5).

nemo_rl/models/automodel/train.py

Two Gemma4 forward-path fixes:

  • New _needs_kv_cache_for_shared_layers() helper + threading a use_cache
    arg through model_forward / forward_with_post_processing_fn. Gemma4 E2B
    uses KV-sharing (num_kv_shared_layers > 0) and needs use_cache=True so
    DynamicCache feeds K/V from anchor layers to shared layers — otherwise
    shared layers fall back to untrained projections and produce garbage.
    Temporary; removable after transformers ≥5.5.2 (HF #45312).
  • Injects mm_token_type_ids (zeros) for model_type == "gemma4" even on
    text-only inputs, mirroring existing Gemma3 token_type_ids handling.

nemo_rl/models/automodel/setup.py

  • New _disable_automodel_checkpoint_dtype_restore(): monkeypatches
    Automodel's _restore_loaded_model_dtype to a no-op so NeMo-RL's
    torch_dtype=float32 master-weight load isn't silently downcast back to
    bf16 (which broke AdamW — the nano-v2-12b reward-stuck-at-0.18 bug).
    Added unit test to detect when the patch is on longer needed.
  • Removes the unconditional visual-encoder freeze for text-only training —
    now handled declaratively via recipe freeze_config. (This is the
    behavioral change flagged in PR review, now documented in
    docs/model-quirks.md.)

nemo_rl/models/automodel/checkpoint.py

New _patch_qwen_vl_vision_key_mapping() import-time monkeypatch re-adding the
^visual.model.visual. key rename that transformers 5.5.0 dropped
(regression from #44627, fixed upstream in 5.6 via #45358). Without it the
Qwen2.5-VL vision tower loads randomly initialized → token_mult_prob_error.
Idempotent, escape-hatch env var (NRL_DISABLE_QWENVL_VISION_PATCH), removal
tripwire test.


Dependencies

  • pyproject.toml:
    • transformers floor raised 5.3.05.5.0 (base extra), automodel extra
      pinned >=5.5.0,<5.6.0 (Gemma4 requires it)
    • adds mistral-common>=1.11.0
    • bumps DeepEP bfded34829d31c09 across all 4 extras + metadata label to be compatible of nvshmem which got updated in an earlier PR.
  • Automodel submodule: 92635e746de0c361
  • uv.lock: regenerated to match

New recipes (5)

DAPO E2B / 26B-A4B / 31B (LLM) + VLM GRPO E4B on geo3k, with matching launcher
.sh scripts under tests/test_suites/llm/ and .../vlm/.

Modified existing recipes (7)

These are not Gemma4 recipes — they fall into three change categories:

(a) router_aux_loss_coef: 00.0 — HF strict config validation declares
this field as float and rejects YAML integer 0. Applied to 4 MoE recipes:

  • llm/dpo-nanov3-30B3AB-1n8g-fsdp8ep8-automodel.yaml
  • llm/grpo-nanov3-30BA3B-2n8g-fsdp2.yaml
  • llm/sft-nanov3-30BA3B-2n8g-fsdp2.yaml
  • llm/sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel.yaml

(b) BackendConfig key migration + freeze_config — replaces the old
enable_deepep: true with the new experts: gmm / dispatcher: deepep keys
(matches Automodel 6de0c361 BackendConfig API) and adds a freeze_config
block (vision/audio frozen, language model trained) to replace the removed
implicit visual-encoder freeze. Applied to 3 Qwen3.5 recipes:

  • llm/grpo-qwen3.5-35ba3b-2n8g-automodel-ep16.yaml
  • llm/grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml (also reorders/quotes
    truncated_importance_sampling_type: "tis")
  • vlm/vlm_grpo-qwen3.5-35ba3b-geo3k-2n8g-automodel-ep16.yaml (vision tower
    not frozen — it's a VLM run)

(c) OOM fixdpo-nanov3-30B3AB-1n8g-fsdp8ep8-automodel.yaml additionally
sets dtensor_cfg.activation_checkpointing: true to avoid the full-vocab
log_softmax OOM in DPO loss-input prep (~step 8) on 80 GiB GPUs. (This file
gets both (a) and (c).)

Tests & docs

  • Nightly/release suites: nightly.txt (+9), release.txt (+4) register
    the new tests (31B moved to release with a GPU-hour cap bump).
    tests/unit/test_recipes_and_test_suites.py and
    tests/unit/environments/test_reward_model_environment.py adjusted.
  • 3 new unit-test files under tests/unit/models/automodel/ covering the
    checkpoint patch, the dtype-restore workaround, and the KV-cache/shared-layer
    logic — each with a removal tripwire so the workarounds get cleaned up when
    upstream fixes land.
  • docs/model-quirks.md: new freeze_config caveats section.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

Training metrics for reference

E2B-it DAPO
image

31B-it DAPO:
image

MoE 26B-A4B-it DAPO:
image

@copy-pr-bot

copy-pr-bot Bot commented Apr 7, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@sharonyu-115

Copy link
Copy Markdown
Contributor Author

/ok to test b3b4d3c

@sharonyu-115 sharonyu-115 added the CI:L1 Run doctests, unit tests, and functional tests label Apr 8, 2026
@zpqiu zpqiu changed the title Gemma4 support feat: add Gemma4 support Apr 8, 2026
@zpqiu zpqiu added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Apr 8, 2026
@sharonyu-115 sharonyu-115 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Apr 8, 2026
@zpqiu zpqiu marked this pull request as ready for review April 8, 2026 05:36
@zpqiu zpqiu requested review from a team as code owners April 8, 2026 05:36
@zpqiu zpqiu added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Apr 8, 2026
@zpqiu zpqiu marked this pull request as draft April 8, 2026 05:37
@zpqiu

zpqiu commented Apr 8, 2026

Copy link
Copy Markdown
Contributor

/ok to test 360cb8a

@sharonyu-115

Copy link
Copy Markdown
Contributor Author

/ok to test 7353904

@sharonyu-115

Copy link
Copy Markdown
Contributor Author

/ok to test e90e80c

@sharonyu-115

Copy link
Copy Markdown
Contributor Author

/ok to test 04fc41c

@sharonyu-115

Copy link
Copy Markdown
Contributor Author

/ok to test 9d9fd36

sharonyu-115 and others added 27 commits June 13, 2026 20:49
Replaces the inline visual-encoder freeze in automodel/setup.py
(removed in the gemma4 feature commit) with explicit YAML freeze_config
entries on the Qwen3.5 recipes that share the same code path.

Co-authored-by: jQizhang <jqizhang@users.noreply.github.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
Nightly test scripts for E2B-it, 26B-A4B-it, and 31B-it DAPO recipes,
and the VLM E4B GRPO recipe. Includes 31B DAPO threshold calibration
from off-policy baseline.

Co-authored-by: jQizhang <jqizhang@users.noreply.github.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
…; enable Liger for Gemma4-31B

All four DAPO-style recipes had truncated_importance_sampling_ratio: 2
set but inherited truncated_importance_sampling_type: null from
grpo_math_1B.yaml. nemo_rl/algorithms/loss/loss_functions.py raises
ValueError("Invalid truncated importance sampling type: None") on the
first training step in that case. Setting the type to "tis" matches the
ratio=2 semantics (clamp IS weights to [0, max]) and unblocks the
recipes.

Affected:
- dapo-gemma4-e2b-it-1n8g-fsdp2-automodel.yaml
- dapo-gemma4-26ba4b-it-4n8g-fsdp2-automodel.yaml
- dapo-gemma4-31b-it-4n8g-fsdp2-automodel.yaml
- grpo-qwen3.5-35ba3b-dapo-4n8g-automodel.yaml

For Gemma4-31B specifically, also add use_liger_kernel: true and
use_sdpa_patching: false under dtensor_cfg.automodel_kwargs, following
the upstream Automodel example at
3rdparty/Automodel-workspace/Automodel/examples/vlm_finetune/gemma4/gemma4_31b.yaml
(though Liger needs a follow-up to ensure liger-kernel is in the policy
worker venv).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
Signed-off-by: larkzhang-nv <larkz@nvidia.com>
Signed-off-by: larkzhang-nv <larkz@nvidia.com>
Signed-off-by: larkzhang-nv <larkz@nvidia.com>
…qwen3.5 DAPO recipe

Signed-off-by: Shuang Yu <shuangy@nvidia.com>
…est -x

Signed-off-by: Shuang Yu <shuangy@nvidia.com>
Nightly total with all 3 Gemma4 suites was 1501 GPU-h, over the 1360 cap.
Move the 31B 4n8g run (128 GPU-h) to release alongside 26ba4b; keep E2B
(12 GPU-h) and VLM E4B (16 GPU-h) in nightly. Total drops to 1373 GPU-h.
Bump cap from 1360 to 1380 to give 7 GPU-h headroom.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
Remove train_global_batch_size that equals the base default and unquote
the tis literal, so the configs-minimize-check pre-commit hook passes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
Tighten the automodel extra's transformers constraint to >=5.5.0,<5.6.0
(nemo-automodel pins ==5.5.0 exactly, so this has no resolution effect
but prevents uv from drifting to a newer fork in a fresh Docker build).

Regenerate uv.lock with submodules at their tracked commits. The previous
lock was generated with a locally-modified Megatron-Bridge (ahead of the
tracked commit), which caused its Megatron-LM to inject a custom
nvidia-resiliency-ext pin (15a851565) that only existed in the local
working tree. In CI, Docker initializes submodules to their tracked
commits (no custom pin), so uv resolved nvidia-resiliency-ext to the
root pin (6c5f2a13), producing a lock mismatch and a container build
failure.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
The transformers 5.3.0 -> 5.5.0 bump in this PR slightly shifts the
logits of Skywork/Skywork-Reward-V2-Qwen3-0.6B, moving the first reward
from -5.2500 to -5.4062 (delta 0.156 > atol=1e-1). Update the hardcoded
baseline to the values produced under transformers 5.5.0.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
…ormers 5.5.x

transformers #44627 (v5.5.0) centralized VLM checkpoint key-conversions and
dropped the Qwen2.5-VL `^visual -> model.visual` rename; transformers #45358
restored it only in v5.6.0. NeMo-RL pins transformers<5.6.0, and Automodel's
get_combined_key_mapping only mirrors transformers WeightRenaming entries, so on
v5.5.x the vision-tower checkpoint keys (visual.*) stay unmapped and are dropped
by FSDP2 set_model_state_dict(strict=False) in load_base_model -> the vision
tower is left randomly initialized. The training forward then diverges from
vLLM (correct vision), producing the vlm_grpo token_mult_prob_error CI failure.

Wrap get_combined_key_mapping to re-inject `^visual\. -> model.visual.` for
qwen2_5_vl/qwen2_vl. Idempotent (skips if a model.visual rule already exists),
so it auto-noops on transformers >=5.6.0 or once Automodel adopts PR NVIDIA-NeMo#2431.
Disable via NRL_DISABLE_QWENVL_VISION_PATCH=1.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
Add a temporary Nemotron-H workaround that casts Mamba out_proj inputs to the projection weight dtype.

Automodel PR NVIDIA-NeMo#1631 preserves BF16 checkpoint dtypes via _restore_loaded_model_dtype, which exposes the transformers Nemotron-H cuda_kernels_forward bug fixed upstream by transformers PR #46487.

Apply the hook only for Nemotron-H configs and make it opt-out via NRL_DISABLE_NEMOTRON_H_DTYPE_PATCH=1.

Signed-off-by: Shuang Yu <shuangy@nvidia.com>
Use 0.0 instead of 0 for router_aux_loss_coef overrides.

Hugging Face strict config validation declares router_aux_loss_coef as a float and rejects YAML's integer 0. Add comments beside the overrides to avoid regressing this back to an int.

Signed-off-by: Shuang Yu <shuangy@nvidia.com>
Clarify that transformers PR #44627 introduced the v5.5.0 Qwen-VL visual key mapping regression and transformers PR #45358 fixed it in v5.6.

The local patch remains necessary while NeMo-RL pins to an Automodel commit that still depends on transformers v5.5.0. It can be removed after Automodel upgrades its transformers dependency to include #45358.

Signed-off-by: Shuang Yu <shuangy@nvidia.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
The previous commit bumped the DeepEP git rev to 29d31c095 in all
dependency lists but left the [[tool.uv.dependency-metadata]] version
label and the lock at the old commit's short SHA (1.2.1+bfded34).
DeepEP's setup.py derives its version as 1.2.1+<git short HEAD>, so the
new commit builds as 1.2.1+29d31c0. CI's fresh `uv sync --locked` build
computed 1.2.1+29d31c0 and rejected the stale lock; local `uv lock`
reused the static override and masked the mismatch.

Update the override to v1.2.1+29d31c0 and regenerate the lock.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
- train.py: remove the obsolete use_cache/activation-checkpointing
  incompatibility note. Automodel NVIDIA-NeMo#1705 (pinned 6de0c361) keeps use_cache=True
  for KV-sharing models under activation checkpointing, so the E4B VLM recipe's
  activation_checkpointing: true is safe.
- dtensor_policy_worker.py (v1): remove the Gemma4 mm_token_type_ids injection.
  The v1 DTensor worker is being deprecated; all shipped Gemma4 recipes use
  _v2: true, which threads use_cache/mm_token_type_ids correctly.
- setup.py: drop the Nemotron-H projection-dtype patch. A module forward-hook
  cannot reach the fused Mamba kernel's internal out_proj F.linear, so it cannot
  make nemotron-h LoRA train; the proper fix is the Automodel r0.5.0 restore-dtype
  change (tracked as a separate migration).
- recipes: migrate enable_deepep: true -> experts: gmm + dispatcher: deepep for
  the gemma4/qwen3.5 automodel recipes (enable_deepep is deprecated in Automodel
  BackendConfig; behavior-preserving). Verified: 26B-A4B trains 20 steps, gen_kl
  0.0009, gates pass.
- tests: harden the E4B VLM gate with median(token_mult_prob_error) < 1.05
  (observed 1.011 in CI); add a reward-ordering invariant to the reward-model
  env test; add hermetic unit tests for _needs_kv_cache_for_shared_layers and the
  Gemma4 mm_token_type_ids injection.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
NVIDIA-NeMo#2419 workaround)

Automodel's _restore_loaded_model_dtype (HF/force_hf load path) re-casts loaded
params back to the bf16 checkpoint dtype, silently undoing NeMo-RL's intended
torch_dtype=float32 master-weight load. With bf16 master weights, AdamW updates
underflow and the policy never learns: grpo-nano-v2-12b reward[30] stuck ~0.18
(vs ~0.54) and sft-nanov3-30BA3B loss plateaus. Only force_hf models (NemotronH
nano-v2/nano-v3) are affected; custom-impl models (gemma4, Llama) load via the DCP
copy path that preserves fp32.

Add _disable_automodel_checkpoint_dtype_restore() to no-op that restore before
from_pretrained so the requested fp32 is honored. Validated: nano-v2-12b reward[30]
0.176 -> 0.541 PASS; nanov3-30BA3B-lora loss[20] 2.027 PASS.

This is temporary until the automodel pin includes NVIDIA-NeMo/Automodel#2419
(rewrites _restore_loaded_model_dtype to honor an explicit torch_dtype). Add an
obsolescence tripwire test that fails when NVIDIA-NeMo#2419 lands so the workaround is removed
timely, plus an analogous tripwire for the existing Qwen-VL vision-tower
key-mapping workaround (fires when transformers #45358 / >=5.6 reaches the pin).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
Reapply gemma4 dependency overrides (transformers 5.5.0, vllm 0.20.0,
deep_ep 29d31c09) on top of upstream's lock baseline, which now carries
the PyJWT/mlflow CVE bumps (NVIDIA-NeMo#2752). Resolved 445 packages in-container.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
The `ruff check --select I --fix` pre-commit hook reorders imports in
test_automodel_checkpoint.py and test_automodel_setup.py: third-party
`nemo_automodel` must precede first-party `nemo_rl`, and straight
`import` precedes `from ... import`. Apply the fixes so the lint check
(and the downstream CI quality check) pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
The automodel fp32-master-weight tripwire test
(test_automodel_dtype_restore_workaround_still_needed) failed in CI as a
false positive. _disable_automodel_checkpoint_dtype_restore() globally and
irreversibly replaces _restore_loaded_model_dtype with a no-op; earlier
setup_model_and_optimizer tests in the same process leave that no-op
installed, so the tripwire exercised the no-op (which preserves fp32)
instead of Automodel's real downgrading function. Stash the original on the
no-op and have the test recover it via _nrl_original.

Also pass requested_dtype=fp32 to the function when its signature accepts
it, so the tripwire actually fires once Automodel NVIDIA-NeMo#2419 is pinned: the
rewritten function honors the explicit fp32 request only via that new
parameter (promote_types), not via hf_config/load_kwargs.

Correct the Skywork reward baseline (-5.4062 -> -5.2500) to the value the
CI build produces (also the historical pre-refresh value); the
incorrect-answer score is sensitive to the transformers/torch/kernel build.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
Gemma4 nightly coverage (dapo-gemma4-e2b 12 GPU-hrs +
vlm_grpo-gemma4-e4b 16 GPU-hrs) pushes the nightly total to 1897,
over the 1890 cap. Bump the budget to 1900, following the established
pattern (most recently NVIDIA-NeMo#2777, 1820 -> 1890).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
…emoval

Add a unit test that fails once transformers>=5.5.2 (PR #45312) lands,
which fixes KV sharing without requiring use_cache=True and makes the
_needs_kv_cache_for_shared_layers workaround in
nemo_rl/models/automodel/train.py obsolete. The test keys on the
transformers version the workaround's TODO names; it only runs under
--automodel-only (transformers 5.5.0), so it stays green today and fires
exactly when the automodel pin advances past the fix.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
The dpo-nanov3-30B3AB-1n8g-fsdp8ep8-automodel recipe OOMs at ~step 8 on
80 GiB GPUs: the full-vocab log_softmax in DPO loss-input prep
(get_next_token_logprobs_from_logits) spikes ~3.7 GiB on an already
near-full budget. Enabling DTensor activation checkpointing frees enough
activation memory to clear it; validated end-to-end (15/15 steps, all
check_metrics thresholds pass, steady ~1.6s/step).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
Note the removal of the implicit visual-encoder freeze for text-only
training: AutoModel only freezes when freeze_config is present (no default
auto-freeze), and a typo in a freeze_* key silently falls back to unfrozen.
Both can produce optimizer state for never-grad'd params and a
checkpoint-resume key mismatch on custom configs.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Shuang Yu <shuangy@nvidia.com>
@github-actions

Copy link
Copy Markdown

✅ Submodule Fast-Forward Check Results

Check based on commit: f207dc2 (PR #2224 from gemma4-support)

✅ Submodules that are properly updated:

Automodel: ✅ PR branch is ahead of main branch (fast-forward)

All submodule changes look good! ✨

@sharonyu-115

Copy link
Copy Markdown
Contributor Author

/ok to test f207dc2

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CI:L1 Run doctests, unit tests, and functional tests Documentation Improvements or additions to documentation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants