Skip to content

fix(mlx): warn on bf16 -> fp16 downcast in FastMLXModel loader#670

Merged
danielhanchen merged 4 commits into
mainfrom
warn-bf16-fp16-downcast
May 19, 2026
Merged

fix(mlx): warn on bf16 -> fp16 downcast in FastMLXModel loader#670
danielhanchen merged 4 commits into
mainfrom
warn-bf16-fp16-downcast

Conversation

@danielhanchen

Copy link
Copy Markdown
Member

Summary

  • _convert_mlx_dtype silently downcasts native bf16 weights to fp16 when the user passes dtype="float16" to FastMLXModel.from_pretrained.
  • fp16's finite range (~6.5e4) is much narrower than bf16's (~3.4e38); models with large activations (e.g. Gemma3) can lose precision or overflow silently.
  • This PR adds a one-line warning when that downcast is about to happen; the cast still runs (users on M1/M2 without native bf16 GPU support need fp16).

Why

Empirical (gemma-3-270m-it, single-row LoRA memorization, n=15 seeds, otherwise-identical setup):

dtype= greedy-decode pass cf_loss=0 (15 seeds)
None (keeps native bf16) 47% 15/15
"float16" (silent bf16 -> fp16) 15% 15/15

The 32pp drop is entirely from the silent bf16 -> fp16 cast. Teacher-forced completion loss is 0 in both cases — the model memorizes either way; only the first-token greedy argmax distribution diverges. CI smoke gating per unslothai/unsloth#5537 stays green either way, but greedy-decode behavior diverges noticeably enough that a user comparing fp16 vs bf16 runs would suspect a different bug.

Tracked alongside two earlier MLX-parity PRs:

  • unslothai/unsloth-zoo#669finetune_last_n_layers knob (layer-selection mismatch).
  • unslothai/unsloth#5564 — same knob, CUDA path.

This PR addresses factor (4) of the four-factor bisection ($\Delta$ pass rate vs mlx_lm.load + manual loop + last-16 layers):

  1. Layer selection (unsloth-zoo#669)
  2. MLXTrainer overhead vs manual loop (-14pp)
  3. FastMLXModel loader patches (-10pp)
  4. bf16 -> fp16 downcast (-28pp) — this PR

Behavior

  • dtype="float16" against a bf16-native model: cast still happens, warning logged via warnings.warn(...).
  • dtype="bfloat16" / dtype=None / dtype="float32": no warning, no change.
  • dtype="float16" against a fp32-native or fp16-native model: no warning (the bf16 -> fp16 specific regression doesn't apply).

The warning message points the user at dtype=None on bf16-capable Apple Silicon (M3+) and dtype="float32" for full precision.

Test plan

  • tests/test_mlx_dtype_downcast_warning.py — five cases:
    • bf16 -> fp16 emits the warning.
    • bf16 -> fp32 (upcast) does NOT emit it.
    • fp32 -> fp16 (different lossy regime) does NOT emit it.
    • No-cast-needed early return does NOT emit it.
    • Cast still mutates model.parameters() after the warning.
  • Local: pytest tests/test_mlx_dtype_downcast_warning.py -v -> 5 passed.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a warning when downcasting bfloat16 weights to float16 in the MLX loader to prevent silent precision loss, specifically addressing issues observed with models like Gemma3. A new test suite has been added to verify the warning logic across various casting scenarios. Feedback identifies an optimization opportunity in the _convert_mlx_dtype function to restore an early exit in the parameter iteration loop, ensuring the code remains efficient while still detecting the conditions necessary to trigger the warning.

Comment thread unsloth_zoo/mlx/loader.py
Comment on lines 110 to +114
for k, v in tree_flatten(model.parameters()):
if cast_pred(k) and mx.issubdtype(v.dtype, mx.floating) and v.dtype != target_dtype:
needs_cast = True
break
if v.dtype == mx.bfloat16:
has_bf16 = True

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation removes the early break that was present in the original code, causing the loop to always iterate over all model parameters. While the number of parameters is usually manageable, we can restore the early exit optimization: if we are not casting to float16, we only need to know if any cast is required; if we are casting to float16, we can stop as soon as we find a bfloat16 weight (which triggers the warning).

    for k, v in tree_flatten(model.parameters()):
        if cast_pred(k) and mx.issubdtype(v.dtype, mx.floating) and v.dtype != target_dtype:
            needs_cast = True
            if v.dtype == mx.bfloat16:
                has_bf16 = True
            # Optimization: break early if we have enough information.
            # If target is not fp16, we don't care about has_bf16.
            # If target is fp16, we stop once we find a bf16 weight.
            if has_bf16 or target_dtype != mx.float16:
                break

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: d9afcea691

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

"gemma3text", # EmbeddingGemma / standalone text-only Gemma3
"gemma3n",
"gpt_oss",
"qwen3_5", # Qwen3.5 GDN layers NaN on fp16

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Include Qwen3.5 MoE in float32 warning gate

For Qwen3.5 MoE loads, FastMLXModel.from_pretrained passes the config model_type through to _convert_mlx_dtype, and this repo already treats qwen3_5_moe as a supported Qwen3.5 architecture in unsloth_zoo/mlx/compile.py. Because _is_force_float32_arch does an exact normalized match against this list, qwen3_5_moe will not match the lone qwen3_5 entry here, so bf16→fp16 downcasts of those GDN-based models skip the warning this change is adding.

Useful? React with 👍 / 👎.

@danielhanchen

Copy link
Copy Markdown
Member Author

Update: gate warning on FORCE_FLOAT32 + centralize the list

Per @danielhanchen — the unconditional warning fired for every bf16→fp16 cast, which is too noisy. Now:

  1. The bf16→fp16 warning is gated on `model_type` being in the `FORCE_FLOAT32` list (Gemma3 family, gpt\_oss, Qwen3.5). Llama / Mistral / Qwen2 / etc. cast silently as before.

  2. `FORCE_FLOAT32` was previously defined inline in `unsloth/models/loader.py`. It now lives in a new dependency-free module `unsloth_zoo/model_lists.py` and is re-exported as `unsloth_zoo.FORCE_FLOAT32`. The new helper `is_force_float32_arch` normalizes `-`/`` and respects the trailing-comma exact-match marker (`gemma3,`).

  3. Companion unsloth PR #5610 switches `unsloth/models/loader.py` to `from unsloth_zoo import FORCE_FLOAT32` and deletes its local copy. Verified `unsloth.models.loader.FORCE_FLOAT32 is unsloth_zoo.FORCE_FLOAT32` after both PRs.

Tests updated: 9 cases covering the export, all FORCE_FLOAT32 archs warning, non-listed arch staying silent, upcasts/no-op casts staying silent, and the cast still happening after the warning. All pass locally (`pytest tests/test_mlx_dtype_downcast_warning.py` → 9/9).

@danielhanchen danielhanchen left a comment

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR! The goal of this PR is to warn callers when FastMLXModel silently downcasts a bfloat16-native model to fp16 (which NaN/Infs on Gemma3 family / gpt_oss / Qwen3.5), and to give the CUDA loader and the MLX loader a shared source of truth for that list. As a summary, this PR adds a dependency-free unsloth_zoo/model_lists.py carrying FORCE_FLOAT32, re-exports it from unsloth_zoo (top-level) and from unsloth_zoo.compiler (back-compat), introduces a _is_force_float32_arch(model_type) helper that normalizes -/_ and respects trailing-comma entries, and gates the new bf16→fp16 warning on that lookup. The companion PR #5610 switches unsloth/models/loader.py to import from there.

Two independent Opus reviewers were run in parallel on this PR.

Reviewers Severity Finding
2/2 Med The gemma3, trailing-comma marker is documented as an exact-match delimiter, but _is_force_float32_arch already strips it; the comma is a no-op in this new helper. Either drop the comma and update the comment, or document that the comma only matters for the CUDA loader's substring path.
1/2 Med Warning message names dtype=None and dtype='float32' but omits dtype='bfloat16' and the UNSLOTH_FORCE_FLOAT32 env-var alternative.
2/2 Nit The local from ..model_lists import FORCE_FLOAT32 at the top of _convert_mlx_dtype is unused; only _is_force_float32_arch consumes the list. Dead import.
1/2 Nit The break after needs_cast = True was removed so the loop can also detect has_bf16. On large models this becomes O(n_params) even when the warning can never fire (e.g. fp32→fp16). Add a fast-exit.
1/2 Low _norm does not strip ., so a future HF config emitting model_type="qwen3.5" (dot form) would silently miss the warning.
1/2 Nit test_gemma3_comma_does_not_match_gemma3n is misnamed — it actually asserts that gemma3n DOES match (via its own list entry).

Overall: APPROVE_WITH_NITS.

See inline comments for details and suggested fixes.

# must run in bf16 or fp32. Loaded as float16 they silently NaN/Inf at
# training time. Shared source of truth for the CUDA loader
# (unsloth/models/loader.py) and the MLX loader (unsloth_zoo/mlx/loader.py).
FORCE_FLOAT32 = [

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[2/2 reviewers] Med: the trailing-comma marker is a substring-delimiter trick used by unsloth/models/loader.py:1378 (where matching is disable_name.lower() in model_types_all). _is_force_float32_arch in this PR explicitly strips the comma before comparing, so the marker is a no-op in zoo's matcher. Either drop the comma here (and rely on gemma3 vs gemma3n already being separate exact-match entries) or update the comment to make clear it's only load-bearing for the substring-matching CUDA consumer.

Suggested change
FORCE_FLOAT32 = [
FORCE_FLOAT32 = [
"gemma3", # exact-match in zoo; trailing-comma kept for back-compat with unsloth/models/loader.py substring matcher (do not remove without coordinating that file).
"gemma3text", # EmbeddingGemma / standalone text-only Gemma3
"gemma3n",
"gpt_oss",
"qwen3_5", # Qwen3.5 GDN layers NaN on fp16
]

(or, alternatively, keep "gemma3," literally and rewrite the docstring above to say the comma is a CUDA-side substring delimiter, not an exact-match marker.)

Comment thread unsloth_zoo/mlx/loader.py
"""
import mlx.core as mx
from mlx.utils import tree_flatten, tree_map_with_path
from ..model_lists import FORCE_FLOAT32

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[2/2 reviewers] Nit: this from ..model_lists import FORCE_FLOAT32 is dead — FORCE_FLOAT32 is never referenced inside _convert_mlx_dtype. The consumption happens inside _is_force_float32_arch. Drop it.

Suggested change
from ..model_lists import FORCE_FLOAT32
from mlx.utils import tree_flatten, tree_map_with_path

Comment thread unsloth_zoo/mlx/loader.py
needs_cast = True
break
if v.dtype == mx.bfloat16:
has_bf16 = True

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] Nit: break was removed from the original loop so we can also detect bf16 presence. But this means every call now scans all parameters even on the common case where no warning can ever fire (target_dtype == fp16 but no bf16 weights, OR model_type not in FORCE_FLOAT32). For large models this is non-trivial. Short-circuit when both flags are set:

Suggested change
has_bf16 = True
for k, v in tree_flatten(model.parameters()):
if cast_pred(k) and mx.issubdtype(v.dtype, mx.floating) and v.dtype != target_dtype:
needs_cast = True
if v.dtype == mx.bfloat16:
has_bf16 = True
if has_bf16:
break

Comment thread unsloth_zoo/mlx/loader.py
return

if has_bf16 and target_dtype == mx.float16 and _is_force_float32_arch(model_type):
warnings.warn(

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] Med: the warning text gives dtype=None and dtype='float32' but is missing dtype='bfloat16' (explicit-bf16 keeps native and is the typical answer on M3+) and a pointer to UNSLOTH_FORCE_FLOAT32=1 for users who want the CUDA-style guard. Make it actionable:

Suggested change
warnings.warn(
warnings.warn(
f"Unsloth: downcasting bfloat16 -> float16 on {model_type!r}, "
"which is known to NaN/Inf in fp16. Pass dtype=None to keep "
"native bf16, dtype='bfloat16' to be explicit on M3+, or "
"dtype='float32' for full precision. Set UNSLOTH_FORCE_FLOAT32=1 "
"to silence this warning if the downcast is intentional.",
stacklevel=2,
)

Comment thread unsloth_zoo/mlx/loader.py
if not model_type:
return False
from ..model_lists import FORCE_FLOAT32
def _norm(s: str) -> str:

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] Low: _norm strips - and _ but not .. HF emits model_type values like qwen3.5 in some configs; that variant would silently miss the warning even though qwen3_5 is in FORCE_FLOAT32. Cheap to harden:

Suggested change
def _norm(s: str) -> str:
def _norm(s: str) -> str:
return s.lower().replace("-", "").replace("_", "").replace(".", "")

# 'gemma3,' entry doesn't accidentally swallow gemma3n variants by
# prefix match. gemma3n itself still matches via its own entry.
assert _is_force_float32_arch("gemma3") is True
assert _is_force_float32_arch("gemma3n") is True

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1/2 reviewers] Nit: the test name says gemma3, does not match gemma3n but the body actually asserts that gemma3n DOES match (via its own list entry, not via gemma3,). The actual no-false-positive case is gemma3_audio_only_pretend. Rename for clarity:

Suggested change
assert _is_force_float32_arch("gemma3n") is True
def test_gemma3_exact_match_does_not_swallow_unrelated_gemma3_variants():
"""`gemma3` and `gemma3n` are separate exact-match entries; an
invented gemma3 variant not in the list should NOT match."""
import torch
from unsloth_zoo.mlx.loader import _is_force_float32_arch
assert _is_force_float32_arch("gemma3") is True
assert _is_force_float32_arch("gemma3n") is True # own entry
assert _is_force_float32_arch("gemma3text") is True # own entry
assert _is_force_float32_arch("gemma3_audio_only_pretend") is False

@danielhanchen

Copy link
Copy Markdown
Member Author

Doc + test update (commit ae74d0e): make explicit that FORCE_FLOAT32 entries are HuggingFace config.json model_type strings (the same values returned by unsloth_zoo.hf_utils.get_transformers_model_type).

unsloth_zoo/model_lists.py now documents each entry against its real-world model_type:

Hub family config.json model_type FORCE_FLOAT32 entry
google/gemma-3-* "gemma3" "gemma3," (substring delimiter for the CUDA loader)
google/embeddinggemma-* "gemma3_text" "gemma3text" (strips _ in MLX matcher)
google/gemma-3n-* "gemma3n" "gemma3n"
openai/gpt-oss-* "gpt_oss" "gpt_oss"
Qwen/Qwen3.5-* "qwen3_5" "qwen3_5"

Added test_force_float32_matches_config_json_model_types which asserts _is_force_float32_arch returns True for the five real-world model_type strings above, so a future entry-rename can't silently miss a real Hub model.

Tests: pytest tests/test_mlx_dtype_downcast_warning.py -> 10/10 pass.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ae74d0e9c3

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth_zoo/mlx/loader.py
Comment on lines +136 to +137
"which is known to NaN/Inf in fp16. Pass dtype=None to keep "
"native bf16, or dtype='float32' for full precision.",

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid suggesting dtype=None when it already selected fp16

On Apple M1/M2 the default from_pretrained(dtype=None) path resolves target_dtype to mx.float16 before calling this helper, so a bf16 Gemma/GPT-OSS/Qwen load will emit this warning even though the user already passed (or omitted) dtype=None. In that scenario the remediation text is misleading because dtype=None does not keep native bf16 on those chips; users would need dtype='float32' or a chip-specific message/explicit-dtype gate.

Useful? React with 👍 / 👎.

`_convert_mlx_dtype` silently downcasts native bf16 weights to fp16
when the user passes `dtype="float16"`. fp16's finite range (~6.5e4)
is much narrower than bf16's (~3.4e38); models with large activations
(e.g. Gemma3-270m) can lose precision or overflow silently.

Empirically (gemma-3-270m-it single-row LoRA memorization, n=15 seeds):
- FastMLXModel(dtype=None) + last-16 layers: 47% greedy-decode pass rate
- FastMLXModel(dtype="float16") + last-16 layers: 15%

The 32pp drop is from the silent bf16 -> fp16 cast (`probe_32` vs
`probe_34` in danielhanchen/unsloth-staging-2). Teacher-forced
completion loss is 0 in both cases (memorization works), so CI smoke
gating per unslothai/unsloth#5537 stays green either way — but the
greedy-decode behavior diverges noticeably.

This patch only adds a warning. The cast still happens (users on
M1/M2 without native bf16 GPU support genuinely need fp16). The
warning surfaces the trade-off so callers can switch to dtype=None /
"bfloat16" on M3+ if they didn't intend to downcast.

Tests:
- test_mlx_dtype_downcast_warning.py — five cases: bf16->fp16 warns;
  bf16->fp32 / fp32->fp16 / no-cast do NOT emit the warning; cast
  still occurs after the warning.
Per code-comment policy: keep WHY (range narrowing risk), drop the
empirical numbers and probe references — those live in the commit
message of 0987d27.
Move FORCE_FLOAT32 — the list of architectures whose activations exceed
fp16's finite range — into a new dependency-free module
unsloth_zoo/model_lists.py and re-export from both unsloth_zoo (top-level)
and unsloth_zoo.compiler (back-compat). unsloth/models/loader.py can now
'from unsloth_zoo import FORCE_FLOAT32' and drop its local copy.

Gate _convert_mlx_dtype's bf16->fp16 downcast warning on the model_type
being in FORCE_FLOAT32. Llama/Mistral/Qwen2 etc. cast silently as before;
only models that actually NaN/Inf in fp16 (Gemma3 family, gpt_oss,
Qwen3.5) get the warning. _is_force_float32_arch normalizes -/_ and
honors the 'gemma3,' trailing-comma exact-match marker.
Per maintainer feedback: the FORCE_FLOAT32 entries are HuggingFace
config.json model_type values (the same strings returned by
unsloth_zoo.hf_utils.get_transformers_model_type). Make that contract
explicit in the module docstring with worked examples for each entry,
and add a parity test that pins _is_force_float32_arch against the
real-world model_type strings on the Hub.
@danielhanchen danielhanchen force-pushed the warn-bf16-fp16-downcast branch from ae74d0e to c3af3b2 Compare May 19, 2026 12:35

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c3af3b23ee

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".


def test_force_float32_list_exported():
"""FORCE_FLOAT32 is importable from the top-level unsloth_zoo namespace."""
import unsloth_zoo

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid importing the top-level package in this unit test

When this new test module is run in the repo's CPU-only harness without the separate unsloth package installed, this import executes unsloth_zoo.__init__ and hits its find_spec("unsloth") guard before the assertion runs; I reproduced this with pytest tests/test_mlx_dtype_downcast_warning.py -q where this test fails while the submodule-based tests pass. Importing the dependency-free unsloth_zoo.model_lists module directly would keep the test exercising the new list without requiring a full Unsloth install.

Useful? React with 👍 / 👎.

@danielhanchen danielhanchen merged commit b89d8c4 into main May 19, 2026
1 of 14 checks passed
@danielhanchen danielhanchen deleted the warn-bf16-fp16-downcast branch May 19, 2026 12:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant