fix(precision): dtype contract bug fixes for FSDP2 mixed-dtype loads #2419

Merged
akoumpa merged 13 commits into main from yuhez/fix/dtype-bugfixes
Jun 8, 2026

Conversation

@yuhezhang-ai

@yuhezhang-ai yuhezhang-ai commented Jun 4, 2026

Contributor

Summary

Split out of #2379 (the broader fp32-master-weight effort) so the bug-fix portion can merge soon, separate from the default-dtype behavior change (which needs broad OOM/config validation and will follow up).

This PR contains only changes that are no-ops for existing example configs (no config sets model.torch_dtype: float32, so there is no memory/perf regression). It fixes correctness issues in how dtypes are loaded, propagated, and sharded.

Bug fixes / refactors

  • _restore_loaded_model_dtype is now dtype-aware (model_init.py): unifies each floating tensor to promote_types(checkpoint, requested) instead of always restoring the checkpoint dtype. Honors an explicit fp32 request while preserving intrinsically-fp32 checkpoint params (e.g. A_log) under bf16; no-op for the bf16/auto path. Fixes FSDP2 uniform-dtype tripping on HF mixed-dtype loads.
  • FSDP compute dtype resolved per-param, decoupled from storage (parallelizer_utils.py, parallelizer.py): fully_shard_by_dtype groups params by resolved compute dtype with clear precedence (pinned fp32 → HF-recorded _hf_compute_dtype → fallback).
  • Frozen modules cast to compute dtype (models/common/utils.py, infrastructure.py): general cast_frozen_modules_to_compute_dtype casts fully-frozen, non-sharded submodules to the FSDP compute dtype, respecting _keep_in_fp32_modules(_strict). Removes the now-redundant gemma4-specific projector hook.
  • Pipeline dtype defaults to FSDP activation dtype (_dist_setup.py) instead of falling back incorrectly.
  • Qwen3.5 declares _keep_in_fp32_modules_strict for its _fp32_params holder.
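
As a rough illustration of the first two bullets, the sketch below models dtypes as plain strings. It is not the code in model_init.py or parallelizer_utils.py: `promote`, `restore_dtype`, `resolve_compute_dtype`, and `group_by_compute_dtype` are hypothetical names, and `promote` only stands in for `torch.promote_types` restricted to float16, bfloat16, and float32.

```python
from collections import defaultdict


def promote(a: str, b: str) -> str:
    """Stand-in for torch.promote_types on float16/bfloat16/float32 only.

    Among these three dtypes, any two different dtypes promote to float32.
    """
    return a if a == b else "float32"


def restore_dtype(checkpoint_dtype: str, requested_dtype: str) -> str:
    """Dtype a floating tensor ends up with after load (hypothetical helper)."""
    return promote(checkpoint_dtype, requested_dtype)


def resolve_compute_dtype(pinned_fp32: bool, hf_compute_dtype, fallback: str) -> str:
    """Precedence from the PR text: pinned fp32, then the HF-recorded dtype, then fallback."""
    if pinned_fp32:
        return "float32"
    if hf_compute_dtype is not None:
        return hf_compute_dtype
    return fallback


def group_by_compute_dtype(params: dict, fallback: str) -> dict:
    """params maps name -> (pinned_fp32, hf_compute_dtype); returns dtype -> [names]."""
    groups = defaultdict(list)
    for name, (pinned, hf_dtype) in params.items():
        groups[resolve_compute_dtype(pinned, hf_dtype, fallback)].append(name)
    return dict(groups)


# An fp32 checkpoint param (such as A_log) survives a bf16 request,
# and an explicit fp32 request is honored for a bf16 checkpoint.
a_log_dtype = restore_dtype("float32", "bfloat16")
weight_dtype = restore_dtype("bfloat16", "float32")
```

Under these assumptions the bf16/bf16 case returns bfloat16 unchanged, which matches the "no-op for the bf16/auto path" claim above.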

Dormant helper (not wired)

  • resolve_storage_dtype() is added and unit-tested in precision_warnings.py but not called from any recipe. The four call sites carry a breadcrumb comment; wiring (the actual default-behavior change) lands in the follow-up PR.

Docs

  • Model-onboarding skill documents the _keep_in_fp32_modules_strict contract.

Test plan

  • tests/unit_tests/_transformers/test_auto_model.py
  • tests/unit_tests/distributed/test_fp32_compute_contract.py
  • tests/unit_tests/distributed/test_parallelizer_utils.py
  • tests/unit_tests/models/common/test_cast_model_to_dtype.py
  • tests/unit_tests/recipes/test_dist_setup.py
  • tests/unit_tests/components/training/test_precision_warnings.py
  • ruff check clean
  • GitLab CI (memory/perf sanity across touched models)

Nightly CI run: https://gitlab-master.nvidia.com/dl/JoC/nemo-ci/-/pipelines/53813531
Verified that every failing case also fails on main. No new failed tests.

Notes

  • Base branch is zpqiu/fp32-master-weights-custom-moe.
  • Follow-up PR (default fp32 master-weight behavior + EAGLE/diffusion configs + docs) will rebase onto this once merged.

@copy-pr-bot

copy-pr-bot Bot commented Jun 4, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@yuhezhang-ai

Contributor Author

/ok to test 6f628c0

compute_dtype = getattr(getattr(model_wrapper, "mp_policy", None), "param_dtype", None)
if compute_dtype is not None:
    for mp in model.parts if hasattr(model, "parts") else [model]:
        cast_frozen_modules_to_compute_dtype(mp, compute_dtype)

Contributor

Hi @yuhezhang-ai, does this handle multiple compute dtypes in the frozen part? Also, I see that it iterates over the model/model.parts: does it explicitly process the frozen part, or does it iterate over the whole model and skip the wrapped parts? I feel explicit is better, because we can't hard-assume FSDP2 as the downstream wrapper.

Contributor Author

Thanks. Based on my latest code:

It casts a frozen subtree to a single compute dtype (FSDP's param_dtype), with fp32 carve-outs for _keep_in_fp32_modules(_strict).

Iterate-and-skip vs. explicit / FSDP2 assumption: fair point. The pass walks the whole model and casts each fully-frozen submodule — buffers always, plain params always. The one assumption is the DTensor branch: a sharded param is skipped on the theory that the wrapper re-gathers and down-casts it to the compute dtype, which is true for FSDP2 but not guaranteed for every DTensor-producing wrapper. In practice it's scoped to FSDP today because the whole pass only runs when there's an FSDP mp_policy.param_dtype, so other wrappers don't hit this path.
But I agree explicit is cleaner: the robust fix is to handle the frozen tower explicitly (exclude it from sharding so it stays plain and gets cast directly), which may also help perf. I think that's out of scope for this PR and gemma4 may need a small redesign — see my later comments on gemma4.
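
To make the walk-and-skip behavior described above concrete, here is a minimal sketch. It is not cast_frozen_modules_to_compute_dtype itself: `Tensor`, `Module`, `is_fully_frozen`, and `cast_frozen` are hypothetical stand-ins, dtypes are plain strings, and the `sharded` flag models a DTensor parameter that the wrapper is assumed to down-cast on its own.

```python
from dataclasses import dataclass, field


@dataclass
class Tensor:
    """Hypothetical stand-in for a parameter or buffer."""
    dtype: str
    requires_grad: bool = False
    sharded: bool = False  # True models a DTensor owned by the sharding wrapper


@dataclass
class Module:
    name: str
    params: dict = field(default_factory=dict)
    buffers: dict = field(default_factory=dict)
    children: list = field(default_factory=list)


def all_params(m: Module):
    """Yield every parameter in the subtree rooted at m."""
    yield from m.params.values()
    for child in m.children:
        yield from all_params(child)


def is_fully_frozen(m: Module) -> bool:
    """True when no parameter in the subtree requires grad (vacuously true if none)."""
    return all(not p.requires_grad for p in all_params(m))


def cast_frozen(m: Module, compute_dtype: str, keep_fp32=()) -> None:
    """Walk the whole tree; cast tensors owned by fully-frozen modules."""
    if m.name in keep_fp32:
        return  # fp32 carve-out: leave this subtree untouched
    if is_fully_frozen(m):
        for b in m.buffers.values():
            b.dtype = compute_dtype  # buffers are always cast
        for p in m.params.values():
            if not p.sharded:  # sharded params are left to the wrapper
                p.dtype = compute_dtype
    for child in m.children:
        cast_frozen(child, compute_dtype, keep_fp32)


tower = Module(
    "vision_tower",
    params={"w": Tensor("float32"), "w_sharded": Tensor("float32", sharded=True)},
    buffers={"std_bias": Tensor("float32")},
    children=[Module("norm_fp32", params={"g": Tensor("float32")})],
)
head = Module(
    "lm_head",
    params={"w": Tensor("float32", requires_grad=True)},
    buffers={"b": Tensor("float32")},
)
root = Module("root", children=[tower, head])
cast_frozen(root, "bfloat16", keep_fp32=("norm_fp32",))
```

In this toy tree the frozen tower's plain parameter and buffer become bfloat16, while the sharded parameter, the carved-out norm, and the trainable head keep float32. The sharded-parameter skip is exactly the wrapper-specific assumption raised in the review.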

  if pixel_values is not None:
-     image_features = get_gemma4_image_features_with_projector_dtype(
-         self.model, pixel_values, image_position_ids=image_position_ids
+     image_features = self.model.get_image_features(

Contributor

Hi @yuhezhang-ai, have you confirmed this is working?

Contributor Author

Hi @akoumpa, thanks!

On gemma4 specifically: it was indeed problematic. I dug into this. The frozen vision tower is actually sharded (it lands in the root FSDP unit, since gemma4 runs with wrap_outer_model=True), so FSDP down-casts its params to bf16 at all-gather. The actual leak is its buffers: HF's Gemma4VisionModel keeps fp32 std_bias/std_scale buffers and FSDP mixed precision never casts buffers, so (hidden_states - std_bias) * std_scale promotes the bf16 activation back to fp32 right before the bf16 multimodal projector. That's the mismatch. The fix 0915fd1 casts frozen buffers to the compute dtype, which removes it (verified the rest of the path — pooler, projector norm — all cast back).
Now it's running correctly after my fix: https://wandb.ai/Nemo-automodel/yuhez_workspace/groups/gemma4-26b-a4b-dtype-verify-rerun/workspace?nw=nwuseryuhez
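
The buffer leak described above comes down to type promotion, which the following sketch illustrates. It models dtypes as strings and assumes the buffers are ordinary (non-scalar) tensors, so that mixing bf16 with fp32 promotes to fp32. `promote` and `standardize_dtype` are hypothetical names, not code from the PR.

```python
def promote(a: str, b: str) -> str:
    """Stand-in for tensor-tensor type promotion on float16/bfloat16/float32."""
    return a if a == b else "float32"


def standardize_dtype(hidden: str, std_bias: str, std_scale: str) -> str:
    """Result dtype of (hidden_states - std_bias) * std_scale."""
    return promote(promote(hidden, std_bias), std_scale)


# Before the fix: fp32 buffers pull the bf16 activation up to fp32.
before_fix = standardize_dtype("bfloat16", "float32", "float32")
# After casting frozen buffers to the compute dtype, the activation stays bf16.
after_fix = standardize_dtype("bfloat16", "bfloat16", "bfloat16")
```

With fp32 buffers the result is float32, which then meets a bf16 projector; with the buffers cast to bfloat16 the result stays bfloat16.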

Follow-up (separate PR) if we want to exclude vision_tower from FSDP sharding: two things combine here. First, gemma4 runs with wrap_outer_model=True. Second, the MoE parallelizer's frozen-tower skip (moe/parallelizer.py, apply_fsdp(), around line 331) only checks audio_tower/visual, not vision_tower (which is also nested at model.model.vision_tower). So the frozen tower slips past the skip, and the outer wrap then sweeps it into the root unit (sharded). To truly exclude it we'd need to fix both: add vision_tower to the skip and reconsider wrap_outer_model here (I'm not sure flipping it is safe or what else it affects). That's really a question for whoever owns the gemma4 model design. For now, the buffer cast fixes correctness without touching the sharding layout.
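
For the skip-list half of that follow-up, a name-based lookup might look roughly like the sketch below. This is only an illustration and not the logic in apply_fsdp(): `find_towers`, `TOWER_NAMES`, and the `max_depth` parameter are hypothetical, and the model objects are plain namespaces.

```python
from types import SimpleNamespace

# audio_tower and visual are the names checked today per the comment above;
# vision_tower is the proposed addition.
TOWER_NAMES = ("audio_tower", "visual", "vision_tower")


def find_towers(model, names=TOWER_NAMES, max_depth=2):
    """Return dotted paths of tower attributes on model and on nested .model holders."""
    found = []
    holder, prefix = model, "model"
    for _ in range(max_depth):
        if holder is None:
            break
        for name in names:
            if getattr(holder, name, None) is not None:
                found.append(f"{prefix}.{name}")
        holder = getattr(holder, "model", None)  # descend one level
        prefix += ".model"
    return found


# Toy layout with the tower nested one level down, as in model.model.vision_tower.
gemma_like = SimpleNamespace(model=SimpleNamespace(vision_tower=object()))
with_fix = find_towers(gemma_like)
without_fix = find_towers(gemma_like, names=("audio_tower", "visual"))
```

With the original two names the nested tower is missed; adding vision_tower and looking one level down finds it. Whether to also change wrap_outer_model is left open, as noted above.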

Comment thread nemo_automodel/components/models/common/utils.py Outdated
@yuhezhang-ai yuhezhang-ai marked this pull request as ready for review June 5, 2026 19:47
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
@yuhezhang-ai

Contributor Author

/ok to test f3bc80f


3 participants