fix(precision): dtype contract bug fixes for FSDP2 mixed-dtype loads #2419
/ok to test 6f628c0
compute_dtype = getattr(getattr(model_wrapper, "mp_policy", None), "param_dtype", None)
if compute_dtype is not None:
    for mp in model.parts if hasattr(model, "parts") else [model]:
        cast_frozen_modules_to_compute_dtype(mp, compute_dtype)
Hi @yuhezhang-ai, does this handle multiple compute dtypes in the frozen part? Also, I see that it iterates over the model/model.parts, so does it explicitly process the frozen part, or does it iterate over the whole model and skip the wrapped parts? I feel explicit is better, because we can't hard-assume FSDP2 as the downstream wrapper.
Thanks. Based on my latest code:
It casts a frozen subtree to a single compute dtype (FSDP's param_dtype), with fp32 carve-outs for _keep_in_fp32_modules(_strict).
Iterate-and-skip vs. explicit / FSDP2 assumption: fair point. The pass walks the whole model and casts each fully-frozen submodule — buffers always, plain params always. The one assumption is the DTensor branch: a sharded param is skipped on the theory that the wrapper re-gathers and down-casts it to the compute dtype, which is true for FSDP2 but not guaranteed for every DTensor-producing wrapper. In practice it's scoped to FSDP today because the whole pass only runs when there's an FSDP mp_policy.param_dtype, so other wrappers don't hit this path.
But I agree explicit is cleaner: the robust fix is to handle the frozen tower explicitly (exclude it from sharding so it stays plain and gets cast directly), which may also help perf. I think that's out of scope for this PR and gemma4 may need a small redesign — see my later comments on gemma4.
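To make the walk-and-skip behaviour described above concrete, here is a minimal sketch. It is not the PR's `cast_frozen_modules_to_compute_dtype`; the function name `cast_frozen_sketch`, the substring-based `keep_fp32` matching, and the definition of "fully frozen" (no parameter beneath the module requires grad) are all assumptions made for illustration.

```python
import torch
from torch import nn

try:
    from torch.distributed.tensor import DTensor
except Exception:
    DTensor = ()  # isinstance(x, ()) is always False, so nothing counts as sharded


def cast_frozen_sketch(model, compute_dtype, keep_fp32=()):
    """Hypothetical stand-in for the pass discussed above, not the PR's code."""
    for name, module in model.named_modules():
        # Assumption: a module is "fully frozen" when no parameter beneath it trains.
        if any(p.requires_grad for p in module.parameters()):
            continue
        # Assumption: fp32 carve-outs are matched by substring of the module name.
        if any(key in name for key in keep_fp32):
            continue
        for p in module.parameters(recurse=False):
            if isinstance(p, DTensor):
                continue  # sharded: left to the wrapper (holds for FSDP2 only)
            if p.is_floating_point():
                p.data = p.data.to(compute_dtype)
        for b in module.buffers(recurse=False):
            if b.is_floating_point():
                b.data = b.data.to(compute_dtype)  # buffers are always cast


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.tower = nn.Linear(4, 4)  # frozen below
        self.tower.register_buffer("std_scale", torch.ones(4))
        self.norm = nn.LayerNorm(4)   # frozen below, kept in fp32
        self.head = nn.Linear(4, 2)   # stays trainable


toy = Toy()
for frozen in (toy.tower, toy.norm):
    for p in frozen.parameters():
        p.requires_grad_(False)

cast_frozen_sketch(toy, torch.bfloat16, keep_fp32=("norm",))
print({n: p.dtype for n, p in toy.named_parameters()})
```

The DTensor branch is the one place where the sketch, like the pass it imitates, relies on the wrapper to handle the down-cast; with a different DTensor-producing wrapper that `continue` would leave the param in its storage dtype.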
if pixel_values is not None:
    image_features = get_gemma4_image_features_with_projector_dtype(
        self.model, pixel_values, image_position_ids=image_position_ids
    image_features = self.model.get_image_features(
Hi @yuhezhang-ai, have you confirmed this is working?
Hi @akoumpa, thanks!
On gemma4 specifically: it was indeed problematic. I dug into this. The frozen vision tower is actually sharded (it lands in the root FSDP unit, since gemma4 runs with wrap_outer_model=True), so FSDP down-casts its params to bf16 at all-gather. The actual leak is its buffers: HF's Gemma4VisionModel keeps fp32 std_bias/std_scale buffers and FSDP mixed precision never casts buffers, so (hidden_states - std_bias) * std_scale promotes the bf16 activation back to fp32 right before the bf16 multimodal projector. That's the mismatch. The fix 0915fd1 casts frozen buffers to the compute dtype, which removes it (verified the rest of the path — pooler, projector norm — all cast back).
Now it's running correctly after my fix: https://wandb.ai/Nemo-automodel/yuhez_workspace/groups/gemma4-26b-a4b-dtype-verify-rerun/workspace?nw=nwuseryuhez
Follow-up (separate PR) if we want to exclude vision_tower from FSDP sharding: two things combine here. gemma4 runs with wrap_outer_model=True, and the MoE parallelizer's frozen-tower skip (moe/parallelizer.py, apply_fsdp(), around line 331) only checks audio_tower/visual, not vision_tower (which is also nested at model.model.vision_tower). So the frozen tower slips past the skip, and the outer wrap then sweeps it into the root unit (sharded). To truly exclude it we'd need to fix both: add vision_tower to the skip and reconsider wrap_outer_model here (I'm not sure flipping it is safe or what else it affects). That's really a question for whoever owns the gemma4 model design. For now the buffer cast fixes correctness without touching the sharding layout.
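The buffer leak described above comes from PyTorch's ordinary type promotion and can be reproduced on CPU without FSDP. The sketch below uses stand-in tensors named after the buffers mentioned in the comment; the shapes and values are made up for illustration.

```python
import torch

# Stand-ins: a bf16 activation and fp32 buffers, as in the comment above.
hidden_states = torch.randn(2, 4, dtype=torch.bfloat16)
std_bias = torch.zeros(4)   # fp32 by default
std_scale = torch.ones(4)   # fp32 by default

# Mixed bf16/fp32 arithmetic promotes the result to fp32.
leaked = (hidden_states - std_bias) * std_scale

# Casting the buffers to the compute dtype keeps the activation in bf16.
fixed = (hidden_states - std_bias.to(torch.bfloat16)) * std_scale.to(torch.bfloat16)

print(leaked.dtype, fixed.dtype)
```

Because mixed precision in FSDP only manages parameters, this kind of promotion has to be handled on the buffers themselves, which is what the buffer cast does.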
Signed-off-by: Yuhe Zhang <yuhez@nvidia.com>
/ok to test f3bc80f
Summary
Split out of #2379 (the broader fp32-master-weight effort) so the bug-fix portion can merge soon, separate from the default-dtype behavior change (which needs broad OOM/config validation and will follow up).
This PR contains only changes that are no-ops for existing example configs (no config sets model.torch_dtype: float32, so there is no memory/perf regression). It fixes correctness issues in how dtypes are loaded, propagated, and sharded.
Bug fixes / refactors
- _restore_loaded_model_dtype is now dtype-aware (model_init.py): unifies each floating tensor to promote_types(checkpoint, requested) instead of always restoring the checkpoint dtype. Honors an explicit fp32 request while preserving intrinsically-fp32 checkpoint params (e.g. A_log) under bf16; no-op for the bf16/auto path. Fixes FSDP2 uniform-dtype tripping on HF mixed-dtype loads.
- […] (parallelizer_utils.py, parallelizer.py): fully_shard_by_dtype groups params by resolved compute dtype with clear precedence (pinned fp32 → HF-recorded _hf_compute_dtype → fallback).
- […] (models/common/utils.py, infrastructure.py): general cast_frozen_modules_to_compute_dtype casts fully-frozen, non-sharded submodules to the FSDP compute dtype, respecting _keep_in_fp32_modules(_strict). Removes the now-redundant gemma4-specific projector hook.
- […] (_dist_setup.py) instead of falling back incorrectly.
- […] _keep_in_fp32_modules_strict for its _fp32_params holder.
Dormant helper (not wired)
resolve_storage_dtype() is added and unit-tested in precision_warnings.py but not called from any recipe. The four call sites carry a breadcrumb comment; wiring (the actual default-behavior change) lands in the follow-up PR.
Docs
- […] _keep_in_fp32_modules_strict contract.
Test plan
- tests/unit_tests/_transformers/test_auto_model.py
- tests/unit_tests/distributed/test_fp32_compute_contract.py
- tests/unit_tests/distributed/test_parallelizer_utils.py
- tests/unit_tests/models/common/test_cast_model_to_dtype.py
- tests/unit_tests/recipes/test_dist_setup.py
- tests/unit_tests/components/training/test_precision_warnings.py
- ruff check clean
Nightly CI run: https://gitlab-master.nvidia.com/dl/JoC/nemo-ci/-/pipelines/53813531
Verified any failure cases are the same as main. No new failed tests.
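As an illustration of the promote_types(checkpoint, requested) rule from the bug-fix list, the sketch below applies `torch.promote_types` per parameter. `restore_dtypes_sketch`, the `checkpoint_dtypes` mapping, and the `Block` module are hypothetical names for this example; the real `_restore_loaded_model_dtype` may differ in how it records and applies checkpoint dtypes.

```python
import torch
from torch import nn


def restore_dtypes_sketch(model, checkpoint_dtypes, requested):
    """Hypothetical helper: unify each floating param to promote_types(ckpt, requested)."""
    for name, p in model.named_parameters():
        if not p.is_floating_point():
            continue
        ckpt = checkpoint_dtypes.get(name, p.dtype)  # assumed per-name dtype record
        target = torch.promote_types(ckpt, requested)
        if p.dtype != target:
            p.data = p.data.to(target)


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.A_log = nn.Parameter(torch.zeros(4))  # intrinsically fp32 in the checkpoint


ckpt = {"proj.weight": torch.bfloat16, "proj.bias": torch.bfloat16, "A_log": torch.float32}

# bf16 request: bf16 params stay bf16, the fp32 checkpoint param is preserved.
bf16_model = Block().to(torch.bfloat16)
restore_dtypes_sketch(bf16_model, ckpt, torch.bfloat16)

# Explicit fp32 request: every floating param ends up fp32.
fp32_model = Block().to(torch.bfloat16)
restore_dtypes_sketch(fp32_model, ckpt, torch.float32)

print({n: p.dtype for n, p in bf16_model.named_parameters()})
print({n: p.dtype for n, p in fp32_model.named_parameters()})
```

Promotion only ever widens, so a bf16 request cannot down-cast a param the checkpoint stores in fp32, and an fp32 request is honored for everything.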
Notes
zpqiu/fp32-master-weights-custom-moe.