Summary
NeMoAutoModelForImageTextToText.from_pretrained(..., torch_dtype=torch.float32) does not build google/gemma-4-E4B-it with uniform float32 parameters.
The top-level HF config is changed to float32, but nested Gemma4 configs (text_config, vision_config, audio_config) keep the checkpoint dtype (bfloat16). As a result, modules built through AutoModel.from_config(sub_config) become bfloat16, while directly constructed modules such as embed_vision, embed_audio, and lm_head become float32.
FSDP2 then fails on the first forward:
```
AssertionError: FSDP expects uniform original parameter dtype but got {torch.bfloat16, torch.float32}
```
Reproduction
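Start from examples/vlm_finetune/gemma4/gemma4_4b_mock.yaml, apply the following changes, and save the result as examples/vlm_finetune/gemma4/gemma4_4b_mock_vlm_repro.yaml:
```diff
- torch_dtype: torch.bfloat16
+ torch_dtype: torch.float32
 ...
 freeze_config:
- freeze_vision_tower: true
+ freeze_vision_tower: false
```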
Run:
```bash
torchrun --nproc-per-node=8 --nnodes=1 \
  nemo_automodel/recipes/vlm/finetune.py \
  -c examples/vlm_finetune/gemma4/gemma4_4b_mock_vlm_repro.yaml
```
Observed:
```
File "torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 238, in _init_mp_dtypes
AssertionError: FSDP expects uniform original parameter dtype but got {torch.bfloat16, torch.float32}
```
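The assertion above reports only the set of dtypes, not which parameters carry them. A minimal sketch of a pre-shard check that also names the offending parameters could look like this. `find_mixed_param_dtypes` and `assert_uniform_param_dtype` are hypothetical helpers (not part of NeMo Automodel or PyTorch), intended to be called on the model before it is handed to `fully_shard`.

```python
from __future__ import annotations

from collections import defaultdict

import torch
from torch import nn


def find_mixed_param_dtypes(module: nn.Module) -> dict[torch.dtype, list[str]]:
    """Group parameter names by dtype (hypothetical helper)."""
    by_dtype: dict[torch.dtype, list[str]] = defaultdict(list)
    for name, param in module.named_parameters():
        by_dtype[param.dtype].append(name)
    return dict(by_dtype)


def assert_uniform_param_dtype(module: nn.Module, max_names: int = 5) -> torch.dtype | None:
    """Raise with parameter names if dtypes are mixed; otherwise return the single dtype."""
    by_dtype = find_mixed_param_dtypes(module)
    if len(by_dtype) > 1:
        # List the smallest groups first; the odd ones out are usually few.
        groups = sorted(by_dtype.items(), key=lambda item: len(item[1]))
        details = "; ".join(
            f"{dtype} ({len(names)} params, e.g. {names[:max_names]})" for dtype, names in groups
        )
        raise AssertionError(f"non-uniform parameter dtypes: {details}")
    # None for a module that has no parameters at all.
    return next(iter(by_dtype), None)


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(4, 4).to(torch.bfloat16), nn.Linear(4, 2))
    try:
        assert_uniform_param_dtype(model)
    except AssertionError as err:
        print(err)
```

Running such a check right after construction surfaces the same mismatch as FSDP2, but with parameter names and before any process group or sharding is involved.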
Root Cause
nemo_automodel/_transformers/model_init.py currently applies the requested dtype only to the top-level config:
```python
if torch_dtype != "auto":
    hf_config.torch_dtype = torch_dtype
```
For Gemma4 VLMs this misses the nested configs that HF reads during construction. For google/gemma-4-E4B-it:
- language_model, vision_tower, and audio_tower are built with AutoModel.from_config(sub_config), so they take their dtype from the nested torch_dtype (bfloat16).
- embed_vision, embed_audio, and lm_head are built directly from nn.Linear layers, so they take their dtype from torch.get_default_dtype() (float32 here).
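The construction mechanism can be reproduced without Gemma4 or transformers. The sketch below assumes that building a sub-model through `AutoModel.from_config(sub_config)` amounts to temporarily switching torch's default dtype to the sub-config's `torch_dtype`; `ToyVLM`, `default_dtype`, and `dtype_counts` are hypothetical names chosen only to mirror the layout above, and the ambient default dtype is assumed to be float32.

```python
from contextlib import contextmanager
from typing import Iterator

import torch
from torch import nn


@contextmanager
def default_dtype(dtype: torch.dtype) -> Iterator[None]:
    """Temporarily change torch's default floating-point dtype, then restore it."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


class ToyVLM(nn.Module):
    """Hypothetical stand-in for the Gemma4 layout; not the real HF model class."""

    def __init__(self, nested_dtype: torch.dtype, hidden: int = 8, vocab: int = 16) -> None:
        super().__init__()
        # Built "via from_config": parameters follow the nested sub-config dtype (assumption).
        with default_dtype(nested_dtype):
            self.language_model = nn.Linear(hidden, hidden)
            self.vision_tower = nn.Linear(hidden, hidden)
        # Built directly: parameters follow the ambient default dtype.
        self.embed_vision = nn.Linear(hidden, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)


def dtype_counts(module: nn.Module) -> dict[str, int]:
    """Count parameters per dtype, in registration order."""
    counts: dict[str, int] = {}
    for param in module.parameters():
        key = str(param.dtype)
        counts[key] = counts.get(key, 0) + 1
    return counts


print(dtype_counts(ToyVLM(nested_dtype=torch.bfloat16)))  # {'torch.bfloat16': 4, 'torch.float32': 3}
print(dtype_counts(ToyVLM(nested_dtype=torch.float32)))   # {'torch.float32': 7}
```

Only the nested dtype differs between the two calls: once it matches the ambient default, every parameter ends up with the same dtype.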
Temporary dtype probes show that the mismatch already exists immediately after construction:
```
requested torch_dtype: torch.float32
top.torch_dtype: torch.float32
text_config.torch_dtype: torch.bfloat16
vision_config.torch_dtype: torch.bfloat16
audio_config.torch_dtype: torch.bfloat16
after_construct: [('torch.bfloat16', 1158), ('torch.float32', 3)]
float32 params:
model.embed_vision.embedding_projection.weight
model.embed_audio.embedding_projection.weight
lm_head.weight
```
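A probe producing output in this format could look roughly like the following. `probe_dtypes` is a hypothetical debugging helper, not the probe actually used for the output above; it assumes the config exposes `torch_dtype` and the three sub-config attribute names listed above, and it names every parameter whose dtype differs from the most common one.

```python
from collections import Counter
from typing import Any, List

import torch
from torch import nn

SUB_CONFIG_NAMES = ("text_config", "vision_config", "audio_config")


def probe_dtypes(requested: Any, hf_config: Any, model: nn.Module) -> List[str]:
    """Hypothetical debugging probe; returns report lines in the format shown above."""
    lines = [
        f"requested torch_dtype: {requested}",
        f"top.torch_dtype: {getattr(hf_config, 'torch_dtype', None)}",
    ]
    for name in SUB_CONFIG_NAMES:
        sub_config = getattr(hf_config, name, None)
        if sub_config is not None:
            lines.append(f"{name}.torch_dtype: {getattr(sub_config, 'torch_dtype', None)}")

    counts = Counter(str(param.dtype) for param in model.parameters())
    lines.append(f"after_construct: {sorted(counts.items())}")

    if len(counts) > 1:
        # Name every parameter that is not in the most common dtype.
        majority, _ = counts.most_common(1)[0]
        for dtype_name in sorted(counts):
            if dtype_name == majority:
                continue
            lines.append(f"{dtype_name.removeprefix('torch.')} params:")
            lines.extend(
                name for name, param in model.named_parameters() if str(param.dtype) == dtype_name
            )
    return lines
```

Calling it between construction and checkpoint loading separates construction-time dtype issues from anything the checkpoint itself contributes.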
The checkpoint itself is uniformly BF16; the mixed dtype is introduced during model construction.
Expected Behavior
from_pretrained(torch_dtype=X) should construct all parameters with dtype X, including parameters controlled by nested multimodal sub-configs. With torch_dtype=torch.float32, FSDP2's uniform original-dtype check should pass.
Possible Fix
One possible fix is to propagate an explicit dtype override into the nested multimodal configs before constructing the model:
```python
if torch_dtype != "auto":
    hf_config.torch_dtype = torch_dtype
    # Nested sub-configs are read by AutoModel.from_config(sub_config) during
    # construction, so they need the same override as the top-level config.
    for sub_config_name in ("text_config", "vision_config", "audio_config"):
        sub_config = getattr(hf_config, sub_config_name, None)
        if sub_config is not None and hasattr(sub_config, "torch_dtype"):
            sub_config.torch_dtype = torch_dtype
```
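If other multimodal models nest their configs under different attribute names, a more generic variant could walk the config tree instead of using a fixed tuple. A sketch, assuming that any attribute value which itself has a `torch_dtype` attribute is a sub-config (duck typing); `propagate_torch_dtype` is a hypothetical helper, not an existing NeMo Automodel or transformers function.

```python
from __future__ import annotations

from typing import Any

import torch


def propagate_torch_dtype(
    config: Any, dtype: torch.dtype, prefix: str = "", _seen: set[int] | None = None
) -> list[str]:
    """Set ``torch_dtype`` on ``config`` and, recursively, on nested config-like attributes.

    Hypothetical helper. Any attribute value that has a ``torch_dtype`` attribute
    is treated as a sub-config. Returns the dotted paths of updated nested configs.
    """
    seen = set() if _seen is None else _seen
    seen.add(id(config))  # guard against cycles and shared sub-configs
    config.torch_dtype = dtype
    updated: list[str] = []
    for name, value in list(getattr(config, "__dict__", {}).items()):
        if id(value) in seen or not hasattr(value, "torch_dtype"):
            continue
        path = f"{prefix}.{name}" if prefix else name
        updated.append(path)
        updated += propagate_torch_dtype(value, dtype, path, seen)
    return updated


# Possible use inside the `torch_dtype != "auto"` branch, instead of the fixed tuple:
#     propagate_torch_dtype(hf_config, torch_dtype)
```

The fixed tuple is narrower and easier to audit; the recursive walk also covers sub-configs stored under other names, at the risk of touching any object that merely happens to expose a `torch_dtype` attribute.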
With this local fix, the same reproducer constructs uniformly:
```
text_config.torch_dtype: torch.float32
vision_config.torch_dtype: torch.float32
audio_config.torch_dtype: torch.float32
```