fix(mlx): warn on bf16 -> fp16 downcast in FastMLXModel loader #670
Code Review
This pull request introduces a warning when downcasting bfloat16 weights to float16 in the MLX loader to prevent silent precision loss, specifically addressing issues observed with models like Gemma3. A new test suite has been added to verify the warning logic across various casting scenarios. Feedback identifies an optimization opportunity in the _convert_mlx_dtype function to restore an early exit in the parameter iteration loop, ensuring the code remains efficient while still detecting the conditions necessary to trigger the warning.
      for k, v in tree_flatten(model.parameters()):
          if cast_pred(k) and mx.issubdtype(v.dtype, mx.floating) and v.dtype != target_dtype:
              needs_cast = True
    -         break
    +         if v.dtype == mx.bfloat16:
    +             has_bf16 = True
The current implementation removes the early break that was present in the original code, causing the loop to always iterate over all model parameters. While the number of parameters is usually manageable, we can restore the early exit optimization: if we are not casting to float16, we only need to know if any cast is required; if we are casting to float16, we can stop as soon as we find a bfloat16 weight (which triggers the warning).
    for k, v in tree_flatten(model.parameters()):
        if cast_pred(k) and mx.issubdtype(v.dtype, mx.floating) and v.dtype != target_dtype:
            needs_cast = True
            if v.dtype == mx.bfloat16:
                has_bf16 = True
            # Optimization: break early if we have enough information.
            # If target is not fp16, we don't care about has_bf16.
            # If target is fp16, we stop once we find a bf16 weight.
            if has_bf16 or target_dtype != mx.float16:
                break
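The early-exit logic suggested above can be illustrated without MLX. The following is a minimal sketch, assuming parameters are plain `(name, dtype_name)` pairs; `scan_params`, `FLOAT_DTYPES`, and the `visited` counter are hypothetical stand-ins for illustration, not the loader's actual code.

```python
from typing import Callable, Iterable, Tuple

# Hypothetical stand-in for mx.issubdtype(v.dtype, mx.floating).
FLOAT_DTYPES = {"float16", "bfloat16", "float32"}


def scan_params(
    params: Iterable[Tuple[str, str]],
    target_dtype: str,
    cast_pred: Callable[[str], bool] = lambda name: True,
) -> Tuple[bool, bool, int]:
    """Return (needs_cast, has_bf16, visited) using the suggested early exit."""
    needs_cast = False
    has_bf16 = False
    visited = 0
    for name, dtype in params:
        visited += 1
        if cast_pred(name) and dtype in FLOAT_DTYPES and dtype != target_dtype:
            needs_cast = True
            if dtype == "bfloat16":
                has_bf16 = True
            # Stop once nothing further can change the outcome: either the
            # target is not fp16 (has_bf16 is irrelevant), or a bf16 weight
            # has already been found.
            if has_bf16 or target_dtype != "float16":
                break
    return needs_cast, has_bf16, visited
```

With a float16 target the scan stops at the first bfloat16 weight. An fp32-to-fp16 load with no bf16 weights still visits every parameter, because the loop keeps looking for a bf16 weight that never appears.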
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d9afcea691
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
        "gemma3text",  # EmbeddingGemma / standalone text-only Gemma3
        "gemma3n",
        "gpt_oss",
        "qwen3_5",  # Qwen3.5 GDN layers NaN on fp16
Include Qwen3.5 MoE in float32 warning gate
For Qwen3.5 MoE loads, FastMLXModel.from_pretrained passes the config model_type through to _convert_mlx_dtype, and this repo already treats qwen3_5_moe as a supported Qwen3.5 architecture in unsloth_zoo/mlx/compile.py. Because _is_force_float32_arch does an exact normalized match against this list, qwen3_5_moe will not match the lone qwen3_5 entry here, so bf16→fp16 downcasts of those GDN-based models skip the warning this change is adding.
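To make the mismatch concrete, here is a minimal sketch of a normalized exact-match lookup of the kind described. It assumes the helper lowercases, strips `-` and `_`, and drops a trailing comma from list entries; `is_force_float32_arch` and the list literal below are illustrative stand-ins assembled from the hunks in this thread, not the repository's code.

```python
# Assumed list contents, copied from the entries quoted in this review thread.
FORCE_FLOAT32 = ["gemma3,", "gemma3text", "gemma3n", "gpt_oss", "qwen3_5"]


def _norm(s: str) -> str:
    # Lowercase, drop a trailing comma marker, and ignore '-' / '_' differences.
    return s.lower().rstrip(",").replace("-", "").replace("_", "")


def is_force_float32_arch(model_type) -> bool:
    """Exact match on the normalized model_type (hypothetical sketch)."""
    if not model_type:
        return False
    wanted = _norm(model_type)
    return any(wanted == _norm(entry) for entry in FORCE_FLOAT32)
```

Under these assumptions `qwen3_5` matches but `qwen3_5_moe` normalizes to `qwen35moe`, which equals no entry, so the MoE variant would need its own list entry (or a prefix rule) to trigger the warning.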
Update: gate warning on FORCE_FLOAT32 + centralize the list
Per @danielhanchen: the unconditional warning fired for every bf16→fp16 cast, which is too noisy. Now:
Tests updated: 9 cases covering the export, all FORCE_FLOAT32 archs warning, non-listed arch staying silent, upcasts/no-op casts staying silent, and the cast still happening after the warning. All pass locally (`pytest tests/test_mlx_dtype_downcast_warning.py` → 9/9).
danielhanchen left a comment
Thank you for the PR! The goal of this PR is to warn callers when FastMLXModel silently downcasts a bfloat16-native model to fp16 (which NaN/Infs on Gemma3 family / gpt_oss / Qwen3.5), and to give the CUDA loader and the MLX loader a shared source of truth for that list. As a summary, this PR adds a dependency-free unsloth_zoo/model_lists.py carrying FORCE_FLOAT32, re-exports it from unsloth_zoo (top-level) and from unsloth_zoo.compiler (back-compat), introduces a _is_force_float32_arch(model_type) helper that normalizes -/_ and respects trailing-comma entries, and gates the new bf16→fp16 warning on that lookup. The companion PR #5610 switches unsloth/models/loader.py to import from there.
Two independent Opus reviewers were run in parallel on this PR.
| Reviewers | Severity | Finding |
|---|---|---|
| 2/2 | Med | The `gemma3,` trailing-comma marker is documented as an exact-match delimiter, but _is_force_float32_arch already strips it; the comma is a no-op in this new helper. Either drop the comma and update the comment, or document that the comma only matters for the CUDA loader's substring path. |
| 1/2 | Med | Warning message names dtype=None and dtype='float32' but omits dtype='bfloat16' and the UNSLOTH_FORCE_FLOAT32 env-var alternative. |
| 2/2 | Nit | The local from ..model_lists import FORCE_FLOAT32 at the top of _convert_mlx_dtype is unused; only _is_force_float32_arch consumes the list. Dead import. |
| 1/2 | Nit | The break after needs_cast = True was removed so the loop can also detect has_bf16. On large models this becomes O(n_params) even when the warning can never fire (e.g. fp32→fp16). Add a fast-exit. |
| 1/2 | Low | _norm does not strip `.`, so a future HF config emitting model_type="qwen3.5" (dot form) would silently miss the warning. |
| 1/2 | Nit | test_gemma3_comma_does_not_match_gemma3n is misnamed — it actually asserts that gemma3n DOES match (via its own list entry). |
Overall: APPROVE_WITH_NITS.
See inline comments for details and suggested fixes.
    # must run in bf16 or fp32. Loaded as float16 they silently NaN/Inf at
    # training time. Shared source of truth for the CUDA loader
    # (unsloth/models/loader.py) and the MLX loader (unsloth_zoo/mlx/loader.py).
    FORCE_FLOAT32 = [
[2/2 reviewers] Med: the trailing-comma marker is a substring-delimiter trick used by unsloth/models/loader.py:1378 (where matching is disable_name.lower() in model_types_all). _is_force_float32_arch in this PR explicitly strips the comma before comparing, so the marker is a no-op in zoo's matcher. Either drop the comma here (and rely on gemma3 vs gemma3n already being separate exact-match entries) or update the comment to make clear it's only load-bearing for the substring-matching CUDA consumer.
    - FORCE_FLOAT32 = [
    + FORCE_FLOAT32 = [
    +     "gemma3",  # exact-match in zoo; trailing-comma kept for back-compat with unsloth/models/loader.py substring matcher (do not remove without coordinating that file).
    +     "gemma3text",  # EmbeddingGemma / standalone text-only Gemma3
    +     "gemma3n",
    +     "gpt_oss",
    +     "qwen3_5",  # Qwen3.5 GDN layers NaN on fp16
    + ]
(or, alternatively, keep "gemma3," literally and rewrite the docstring above to say the comma is a CUDA-side substring delimiter, not an exact-match marker.)
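The difference between the two consumers can be sketched as follows. This assumes, purely for illustration, that the CUDA loader's `model_types_all` is a string in which every model type is followed by a comma; only the `disable_name.lower() in model_types_all` check is taken from the comment above, and the helper names are hypothetical.

```python
def join_model_types(model_types):
    # Assumed shape: "gemma3n," or "llama,gemma3,"; every type ends with a comma.
    return "".join(t.lower() + "," for t in model_types)


def substring_match(disable_name, model_types_all):
    # Mirrors the substring check quoted for the CUDA loader.
    return disable_name.lower() in model_types_all


def exact_match(disable_name, model_type):
    # Mirrors an exact-match helper that strips the trailing comma first.
    return disable_name.lower().rstrip(",") == model_type.lower()


gemma3n_only = join_model_types(["gemma3n"])  # "gemma3n,"
gemma3_only = join_model_types(["gemma3"])    # "gemma3,"
```

In this sketch a bare `gemma3` entry is a substring of `gemma3n,` and would produce a false positive on the substring path, while `gemma3,` is not. On the exact-match path the comma is stripped before comparison, so it changes nothing there.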
    """
    import mlx.core as mx
    from mlx.utils import tree_flatten, tree_map_with_path
    from ..model_lists import FORCE_FLOAT32
[2/2 reviewers] Nit: this from ..model_lists import FORCE_FLOAT32 is dead — FORCE_FLOAT32 is never referenced inside _convert_mlx_dtype. The consumption happens inside _is_force_float32_arch. Drop it.
    - from ..model_lists import FORCE_FLOAT32
    + from mlx.utils import tree_flatten, tree_map_with_path
      needs_cast = True
    - break
    + if v.dtype == mx.bfloat16:
    +     has_bf16 = True
[1/2 reviewers] Nit: break was removed from the original loop so we can also detect bf16 presence. But this means every call now scans all parameters even on the common case where no warning can ever fire (target_dtype == fp16 but no bf16 weights, OR model_type not in FORCE_FLOAT32). For large models this is non-trivial. Short-circuit when both flags are set:
    - has_bf16 = True
    + for k, v in tree_flatten(model.parameters()):
    +     if cast_pred(k) and mx.issubdtype(v.dtype, mx.floating) and v.dtype != target_dtype:
    +         needs_cast = True
    +         if v.dtype == mx.bfloat16:
    +             has_bf16 = True
    +     if has_bf16:
    +         break
        return

    if has_bf16 and target_dtype == mx.float16 and _is_force_float32_arch(model_type):
        warnings.warn(
[1/2 reviewers] Med: the warning text gives dtype=None and dtype='float32' but is missing dtype='bfloat16' (explicit-bf16 keeps native and is the typical answer on M3+) and a pointer to UNSLOTH_FORCE_FLOAT32=1 for users who want the CUDA-style guard. Make it actionable:
    - warnings.warn(
    + warnings.warn(
    +     f"Unsloth: downcasting bfloat16 -> float16 on {model_type!r}, "
    +     "which is known to NaN/Inf in fp16. Pass dtype=None to keep "
    +     "native bf16, dtype='bfloat16' to be explicit on M3+, or "
    +     "dtype='float32' for full precision. Set UNSLOTH_FORCE_FLOAT32=1 "
    +     "to silence this warning if the downcast is intentional.",
    +     stacklevel=2,
    + )
    if not model_type:
        return False
    from ..model_lists import FORCE_FLOAT32
    def _norm(s: str) -> str:
[1/2 reviewers] Low: _norm strips `-` and `_` but not `.`. HF emits model_type values like qwen3.5 in some configs; that variant would silently miss the warning even though qwen3_5 is in FORCE_FLOAT32. Cheap to harden:
    - def _norm(s: str) -> str:
    + def _norm(s: str) -> str:
    +     return s.lower().replace("-", "").replace("_", "").replace(".", "")
    # 'gemma3,' entry doesn't accidentally swallow gemma3n variants by
    # prefix match. gemma3n itself still matches via its own entry.
    assert _is_force_float32_arch("gemma3") is True
    assert _is_force_float32_arch("gemma3n") is True
[1/2 reviewers] Nit: the test name says gemma3, does not match gemma3n but the body actually asserts that gemma3n DOES match (via its own list entry, not via gemma3,). The actual no-false-positive case is gemma3_audio_only_pretend. Rename for clarity:
    - assert _is_force_float32_arch("gemma3n") is True
    + def test_gemma3_exact_match_does_not_swallow_unrelated_gemma3_variants():
    +     """`gemma3` and `gemma3n` are separate exact-match entries; an
    +     invented gemma3 variant not in the list should NOT match."""
    +     import torch
    +     from unsloth_zoo.mlx.loader import _is_force_float32_arch
    +     assert _is_force_float32_arch("gemma3") is True
    +     assert _is_force_float32_arch("gemma3n") is True  # own entry
    +     assert _is_force_float32_arch("gemma3text") is True  # own entry
    +     assert _is_force_float32_arch("gemma3_audio_only_pretend") is False
Doc + test update (commit ae74d0e): make explicit that
Added Tests:
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ae74d0e9c3
    "which is known to NaN/Inf in fp16. Pass dtype=None to keep "
    "native bf16, or dtype='float32' for full precision.",
Avoid suggesting dtype=None when it already selected fp16
On Apple M1/M2 the default from_pretrained(dtype=None) path resolves target_dtype to mx.float16 before calling this helper, so a bf16 Gemma/GPT-OSS/Qwen load will emit this warning even though the user already passed (or omitted) dtype=None. In that scenario the remediation text is misleading because dtype=None does not keep native bf16 on those chips; users would need dtype='float32' or a chip-specific message/explicit-dtype gate.
`_convert_mlx_dtype` silently downcasts native bf16 weights to fp16 when the user passes `dtype="float16"`. fp16's finite range (~6.5e4) is much narrower than bf16's (~3.4e38); models with large activations (e.g. Gemma3-270m) can lose precision or overflow silently.

Empirically (gemma-3-270m-it single-row LoRA memorization, n=15 seeds):
- FastMLXModel(dtype=None) + last-16 layers: 47% greedy-decode pass rate
- FastMLXModel(dtype="float16") + last-16 layers: 15%

The 32pp drop is from the silent bf16 -> fp16 cast (`probe_32` vs `probe_34` in danielhanchen/unsloth-staging-2). Teacher-forced completion loss is 0 in both cases (memorization works), so CI smoke gating per unslothai/unsloth#5537 stays green either way, but the greedy-decode behavior diverges noticeably.

This patch only adds a warning. The cast still happens (users on M1/M2 without native bf16 GPU support genuinely need fp16). The warning surfaces the trade-off so callers can switch to dtype=None / "bfloat16" on M3+ if they didn't intend to downcast.

Tests:
- test_mlx_dtype_downcast_warning.py, five cases: bf16->fp16 warns; bf16->fp32 / fp32->fp16 / no-cast do NOT emit the warning; cast still occurs after the warning.
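The range figures quoted in this commit message can be checked with a few lines of arithmetic. This is a small sketch using only the standard library; `fits_fp16` is a hypothetical helper that relies on `struct`'s half-precision `e` format raising `OverflowError` for finite values beyond the fp16 range.

```python
import struct

# Largest finite values implied by each format's layout:
# fp16 has 5 exponent bits and 10 mantissa bits; bf16 has 8 and 7.
FP16_MAX = (2 - 2**-10) * 2.0**15   # 65504.0
BF16_MAX = (2 - 2**-7) * 2.0**127   # roughly 3.39e38


def fits_fp16(x: float) -> bool:
    """True if x can be packed as IEEE half precision without overflowing."""
    try:
        struct.pack("<e", x)
        return True
    except OverflowError:
        return False
```

A value such as 70000.0 is unremarkable in bf16 but already exceeds the largest finite fp16 value, which is the overflow path the warning is meant to surface.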
Per code-comment policy: keep WHY (range narrowing risk), drop the empirical numbers and probe references — those live in the commit message of 0987d27.
Move FORCE_FLOAT32 — the list of architectures whose activations exceed fp16's finite range — into a new dependency-free module unsloth_zoo/model_lists.py and re-export from both unsloth_zoo (top-level) and unsloth_zoo.compiler (back-compat). unsloth/models/loader.py can now 'from unsloth_zoo import FORCE_FLOAT32' and drop its local copy. Gate _convert_mlx_dtype's bf16->fp16 downcast warning on the model_type being in FORCE_FLOAT32. Llama/Mistral/Qwen2 etc. cast silently as before; only models that actually NaN/Inf in fp16 (Gemma3 family, gpt_oss, Qwen3.5) get the warning. _is_force_float32_arch normalizes -/_ and honors the 'gemma3,' trailing-comma exact-match marker.
Per maintainer feedback: the FORCE_FLOAT32 entries are HuggingFace config.json model_type values (the same strings returned by unsloth_zoo.hf_utils.get_transformers_model_type). Make that contract explicit in the module docstring with worked examples for each entry, and add a parity test that pins _is_force_float32_arch against the real-world model_type strings on the Hub.
ae74d0e to c3af3b2
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c3af3b23ee
    def test_force_float32_list_exported():
        """FORCE_FLOAT32 is importable from the top-level unsloth_zoo namespace."""
        import unsloth_zoo
Avoid importing the top-level package in this unit test
When this new test module is run in the repo's CPU-only harness without the separate unsloth package installed, this import executes unsloth_zoo.__init__ and hits its find_spec("unsloth") guard before the assertion runs; I reproduced this with pytest tests/test_mlx_dtype_downcast_warning.py -q where this test fails while the submodule-based tests pass. Importing the dependency-free unsloth_zoo.model_lists module directly would keep the test exercising the new list without requiring a full Unsloth install.
Summary
`_convert_mlx_dtype` silently downcasts native bf16 weights to fp16 when the user passes `dtype="float16"` to `FastMLXModel.from_pretrained`.

Why
Empirical (gemma-3-270m-it, single-row LoRA memorization, n=15 seeds, otherwise-identical setup):
- `None` (keeps native bf16): 47% greedy-decode pass rate
- `"float16"` (silent bf16 -> fp16): 15%

The 32pp drop is entirely from the silent bf16 -> fp16 cast. Teacher-forced completion loss is `0` in both cases: the model memorizes either way; only the first-token greedy argmax distribution diverges. CI smoke gating per `unslothai/unsloth#5537` stays green either way, but greedy-decode behavior diverges noticeably enough that a user comparing fp16 vs bf16 runs would suspect a different bug.

Tracked alongside two earlier MLX-parity PRs:
- `unslothai/unsloth-zoo#669`: `finetune_last_n_layers` knob (layer-selection mismatch).
- `unslothai/unsloth#5564`: same knob, CUDA path.

This PR addresses factor (4) of the four-factor bisection ($\Delta$ pass rate vs `mlx_lm.load` + manual loop + last-16 layers):
- ... `unsloth-zoo#669`)
- `MLXTrainer` overhead vs manual loop (-14pp)
- `FastMLXModel` loader patches (-10pp)

Behavior
- `dtype="float16"` against a bf16-native model: cast still happens, warning logged via `warnings.warn(...)`.
- `dtype="bfloat16"` / `dtype=None` / `dtype="float32"`: no warning, no change.
- `dtype="float16"` against a fp32-native or fp16-native model: no warning (the bf16 -> fp16 specific regression doesn't apply).

The warning message points the user at `dtype=None` on bf16-capable Apple Silicon (M3+) and `dtype="float32"` for full precision.

Test plan
- `tests/test_mlx_dtype_downcast_warning.py`, five cases:
  - `bf16 -> fp16` emits the warning.
  - `bf16 -> fp32` (upcast) does NOT emit it.
  - `fp32 -> fp16` (different lossy regime) does NOT emit it.
  - ... `model.parameters()` after the warning.
- `pytest tests/test_mlx_dtype_downcast_warning.py -v` -> 5 passed.
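The cases listed under Behavior above can be sketched with a stand-in that works on dtype names only. `warn_if_bf16_downcast` and `count_warnings` are hypothetical helpers, not the loader function or the test file from this PR, and they model only the decision to warn, not the cast itself.

```python
import warnings


def warn_if_bf16_downcast(source_dtypes, target_dtype):
    """Emit a UserWarning when any bfloat16 weight would be cast to float16."""
    if target_dtype == "float16" and "bfloat16" in source_dtypes:
        warnings.warn(
            "downcasting bfloat16 -> float16; consider dtype='float32' "
            "or keeping native bf16 where the hardware supports it",
            stacklevel=2,
        )
        return True
    return False


def count_warnings(source_dtypes, target_dtype):
    # Record warnings instead of printing them so a test can count them.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        warn_if_bf16_downcast(source_dtypes, target_dtype)
    return len(caught)
```

Recording with `warnings.catch_warnings(record=True)` plus `simplefilter("always")` keeps the check independent of any warning filters already active in the test process.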