
[Bug] Gemma4 VLM from_pretrained(torch_dtype=float32) leaves nested submodules in bf16 and breaks FSDP2 #2017

@jQizhang

Summary

NeMoAutoModelForImageTextToText.from_pretrained(..., torch_dtype=torch.float32) does not build google/gemma-4-E4B-it with uniform float32 parameters.

The top-level HF config is changed to float32, but nested Gemma4 configs (text_config, vision_config, audio_config) keep the checkpoint dtype (bfloat16). As a result, modules built through AutoModel.from_config(sub_config) become bfloat16, while directly constructed modules such as embed_vision, embed_audio, and lm_head become float32.

FSDP2 then fails on the first forward:

AssertionError: FSDP expects uniform original parameter dtype but got {torch.bfloat16, torch.float32}

Reproduction

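Start from examples/vlm_finetune/gemma4/gemma4_4b_mock.yaml and apply the following changes (the run command below assumes the modified copy is saved as gemma4_4b_mock_vlm_repro.yaml in the same directory):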

-  torch_dtype: torch.bfloat16
+  torch_dtype: torch.float32
   ...
 freeze_config:
-  freeze_vision_tower: true
+  freeze_vision_tower: false

Run:

torchrun --nproc-per-node=8 --nnodes=1 \
  nemo_automodel/recipes/vlm/finetune.py \
  -c examples/vlm_finetune/gemma4/gemma4_4b_mock_vlm_repro.yaml

Observed:

File "torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 238, in _init_mp_dtypes
AssertionError: FSDP expects uniform original parameter dtype but got {torch.bfloat16, torch.float32}

Root Cause

nemo_automodel/_transformers/model_init.py currently applies the requested dtype only to the top-level config:

if torch_dtype != "auto":
    hf_config.torch_dtype = torch_dtype

For Gemma4 VLM checkpoints such as google/gemma-4-E4B-it, this misses the nested configs that HF reads during construction:

  • language_model, vision_tower, and audio_tower use AutoModel.from_config(sub_config) and read the nested torch_dtype.
  • embed_vision, embed_audio, and lm_head are direct nn.Linear modules and inherit torch.get_default_dtype().

Temporary dtype probes show the mismatch exists immediately after construction:

requested torch_dtype:      torch.float32
top.torch_dtype:           torch.float32
text_config.torch_dtype:   torch.bfloat16
vision_config.torch_dtype: torch.bfloat16
audio_config.torch_dtype:  torch.bfloat16

after_construct: [('torch.bfloat16', 1158), ('torch.float32', 3)]
float32 params:
  model.embed_vision.embedding_projection.weight
  model.embed_audio.embedding_projection.weight
  lm_head.weight

The checkpoint itself is uniformly BF16; the mixed dtype is introduced during model construction.
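A probe along the following lines could produce counts like the ones above. This is a minimal sketch, not the instrumentation actually used for this report: the helper names dtype_histogram and params_with_dtype are hypothetical, and the code assumes only that the model exposes named_parameters() the way torch.nn.Module does.

```python
from collections import Counter


def dtype_histogram(model):
    """Count parameters by dtype name.

    `model` is assumed to expose named_parameters() yielding
    (name, param) pairs where param has a .dtype attribute.
    """
    return Counter(str(param.dtype) for _, param in model.named_parameters())


def params_with_dtype(model, dtype_name):
    """List the names of parameters whose dtype prints as `dtype_name`."""
    return [
        name
        for name, param in model.named_parameters()
        if str(param.dtype) == dtype_name
    ]
```

Calling dtype_histogram(model) right after construction and again after checkpoint loading would show at which step the mixed dtypes first appear; params_with_dtype(model, "torch.float32") then names the outliers.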

Expected Behavior

from_pretrained(torch_dtype=X) should construct all parameters with dtype X, including parameters controlled by nested multimodal sub-configs. With torch_dtype=torch.float32, FSDP2's uniform original-dtype check should pass.

Possible Fix

One possible fix is to propagate explicit dtype overrides into nested multimodal configs before constructing the model:

if torch_dtype != "auto":
    hf_config.torch_dtype = torch_dtype

    for sub_config_name in ("text_config", "vision_config", "audio_config"):
        sub_config = getattr(hf_config, sub_config_name, None)
        if sub_config is not None and hasattr(sub_config, "torch_dtype"):
            sub_config.torch_dtype = torch_dtype

With this local fix, the nested configs in the same reproducer report float32 and the model is constructed with a uniform dtype:

text_config.torch_dtype:   torch.float32
vision_config.torch_dtype: torch.float32
audio_config.torch_dtype:  torch.float32
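The propagation logic can also be exercised without loading a checkpoint. The sketch below wraps it in a hypothetical helper, propagate_dtype, and uses SimpleNamespace objects with string dtypes as stand-ins for the HF config and for torch dtypes. Neither the helper name nor the stand-ins come from nemo_automodel; they exist only to illustrate the behavior.

```python
from types import SimpleNamespace

# Nested config attribute names taken from the issue text.
NESTED_CONFIG_NAMES = ("text_config", "vision_config", "audio_config")


def propagate_dtype(hf_config, torch_dtype, nested_names=NESTED_CONFIG_NAMES):
    """Apply an explicit dtype override to a config and its nested sub-configs.

    "auto" means "keep whatever the checkpoint declares", so nothing changes.
    """
    if torch_dtype == "auto":
        return hf_config
    hf_config.torch_dtype = torch_dtype
    for name in nested_names:
        sub_config = getattr(hf_config, name, None)
        # Skip sub-configs that are absent or do not carry a dtype field.
        if sub_config is not None and hasattr(sub_config, "torch_dtype"):
            sub_config.torch_dtype = torch_dtype
    return hf_config


# Stand-in config: strings replace torch dtypes, audio_config is absent.
config = SimpleNamespace(
    torch_dtype="bfloat16",
    text_config=SimpleNamespace(torch_dtype="bfloat16"),
    vision_config=SimpleNamespace(torch_dtype="bfloat16"),
    audio_config=None,
)
propagate_dtype(config, "float32")
```

A fixed tuple of names mirrors the proposal above. Whether other multimodal models use additional nested config names is not covered here, so the tuple is left as a parameter.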
