fix(omlx): faithful bge serving on MLX — reranker bf16 load + embedding eval-mode & CLS pooling #1767

Merged
jundot merged 3 commits into jundot:main from paalolav:fix/reranker-bf16-mx-load
Jun 10, 2026

@paalolav paalolav commented Jun 9, 2026

Contributor

Problem

Loading an XLMRobertaForSequenceClassification reranker whose weights are stored in bfloat16 crashes at load time:

File "omlx/models/reranker.py", line 166, in _load_xlm_roberta
    weights[key] = f.get_tensor(key)
TypeError: data type 'bfloat16' not understood
POST /v1/rerank → 500 (unhandled): data type 'bfloat16' not understood

safetensors.safe_open(wf, framework="mlx").get_tensor(key) still routes the tensor through numpy, and numpy has no bfloat16 dtype — so any bf16 reranker (e.g. a bf16 BAAI/bge-reranker-v2-m3 export) returns 500 on every /v1/rerank call.
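To check whether a given checkpoint would hit this path, the dtypes can be read from the safetensors header without loading any tensor data. The following is a minimal sketch based on the published safetensors file layout (an 8-byte little-endian header length followed by a JSON header); the helper names `safetensors_dtypes` and `has_bf16` are hypothetical and not part of oMLX.

```python
import json
import struct
from collections import Counter


def safetensors_dtypes(path):
    """Count tensors per dtype by reading only the safetensors header."""
    with open(path, "rb") as f:
        # First 8 bytes: little-endian unsigned 64-bit length of the JSON header.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    # "__metadata__" is an optional free-form entry, not a tensor.
    return Counter(
        entry["dtype"] for name, entry in header.items() if name != "__metadata__"
    )


def has_bf16(path):
    """Return True if any tensor in the file is stored as bfloat16 ("BF16")."""
    return safetensors_dtypes(path).get("BF16", 0) > 0
```

A file for which `has_bf16` returns True is one that would be expected to trigger the `TypeError` above on the `safe_open(...).get_tensor()` path.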

Fix

Use mx.load(str(path)), which reads a safetensors file directly into MLX arrays and supports bfloat16 natively. mx is already imported in both functions, so the change is minimal. fp16/fp32 weights are unaffected (mx.load handles all three), and the now-unused local from safetensors import safe_open imports are removed.

Two call sites share the identical pattern and are both fixed:

  • _load_xlm_roberta (the reported reranker path)
  • the JinaForRanking projector loader
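The substitution at both call sites can be sketched roughly as follows. This is not the code from the PR: the function name `load_safetensors_weights` and the injectable `load_fn` parameter are assumptions added here so the sketch can be exercised without MLX installed. The only MLX call it relies on is `mx.load`, which returns a name-to-array dict for `.safetensors` files.

```python
def load_safetensors_weights(paths, load_fn=None):
    """Merge one or more safetensors shards into a single name -> array dict.

    load_fn is a hypothetical hook for testing; by default mx.load is used.
    """
    if load_fn is None:
        # Imported lazily so this module can be imported without MLX installed.
        import mlx.core as mx

        load_fn = mx.load

    weights = {}
    for path in paths:
        # Replaces the per-key safe_open(...).get_tensor(key) loop: the whole
        # shard is read in one call and never passes through numpy.
        weights.update(load_fn(str(path)))
    return weights
```

On Apple Silicon this would be called as `load_safetensors_weights(["model.safetensors"])`, and bf16 tensors should come back as MLX arrays with dtype `mx.bfloat16`.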

Testing

  • _load_xlm_roberta: verified on Apple Silicon (M2) with a bf16 bge-reranker-v2-m3 export — /v1/rerank previously 500'd on every call, and after this change returns 200 with correct relevance scores (relevant doc 0.84 vs irrelevant ~0.00007). Back-to-back 24-document reranks are stable at Metal speed.
  • JinaForRanking projector: not hardware-tested (no bf16 JinaForRanking model on hand), but it is the identical safe_open + get_tensor → mx.load substitution and is dtype-agnostic.

paalolav added 3 commits June 9, 2026 12:05
XLMRobertaForSequenceClassification reranker models (and the JinaForRanking
projector) fail to load when their weights are stored in bfloat16:

    File "omlx/models/reranker.py", in _load_xlm_roberta
    TypeError: data type 'bfloat16' not understood

safetensors safe_open(framework="mlx").get_tensor() routes the tensor through
numpy, which has no bfloat16 dtype. mx.load() reads the safetensors file
directly into MLX arrays and supports bfloat16 natively. Behaviour for
fp16/fp32 weights is unchanged (mx.load handles all three).

Native XLMRoberta/BERT embedding models were loaded without switching to
eval mode, leaving dropout (p>0) active. Every /v1/embeddings call then
applied random dropout, producing non-deterministic, corrupted vectors
(same input -> cosine ~0.90 between calls). Set model.train(False) after
load so embeddings are deterministic.
…LM-R/BERT

The native embedding path hardcoded mean pooling. Models trained for CLS
pooling (e.g. BAAI bge-m3 dense — mean pooling significantly degrades it,
per the BGE authors) were served with the wrong pooling, so vectors did
not match the reference implementation. Honor the existing-but-unused
ModelArgs.pooling_config.pooling_mode; default 'mean' preserves behavior.
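
The difference between the two pooling modes can be illustrated with a framework-free sketch. It uses plain Python lists rather than MLX arrays, and the function name `pool` and its signature are hypothetical, not the oMLX implementation.

```python
def pool(hidden_states, attention_mask, mode="mean"):
    """Reduce per-token vectors (seq_len x dim) to a single sentence vector."""
    if mode == "cls":
        # CLS pooling: take the first token's vector
        # ([CLS] for BERT, <s> for XLM-RoBERTa).
        return list(hidden_states[0])
    if mode == "mean":
        # Mean pooling: average over non-padding tokens only.
        kept = [h for h, m in zip(hidden_states, attention_mask) if m]
        n = len(kept)
        return [sum(col) / n for col in zip(*kept)]
    raise ValueError(f"unsupported pooling mode: {mode!r}")


hidden = [[1.0, 0.0], [3.0, 2.0], [9.0, 9.0]]  # last row is padding
mask = [1, 1, 0]
print(pool(hidden, mask, "cls"))   # [1.0, 0.0]
print(pool(hidden, mask, "mean"))  # [2.0, 1.0]
```

The two modes give different vectors for the same hidden states, which is why serving a CLS-trained model with mean pooling does not reproduce the reference embeddings.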
@paalolav paalolav changed the title from "fix(reranker): load safetensors via mx.load to support bfloat16 weights" to "fix(omlx): faithful bge serving on MLX — reranker bf16 load + embedding eval-mode & CLS pooling" Jun 9, 2026

paalolav commented Jun 9, 2026

Contributor Author

Added two embedding-path fixes on top of the reranker bf16 change (same theme: serve BAAI bge models faithfully on MLX):

1. fed62be — disable dropout in native embedding load. MLXEmbeddingModel._load_native never switched the model to eval mode, so XLM-RoBERTa/BERT dropout (p>0) stayed active. Every /v1/embeddings call applied random dropout → non-deterministic output: the same input returned cosine ~0.90 between calls. Fix: model.train(False) after load. (Verified: determinism cosine 1.0 after.)

2. 6498a23 — honor pooling_config.pooling_mode. The native path hardcoded mean pooling, but ModelArgs already carries an (unused) pooling_config. BAAI bge-m3 dense is trained on the [CLS] token — the BGE authors note mean pooling significantly degrades it — so mean-pooled vectors don't match the reference. Now honors pooling_config.pooling_mode (default mean preserves existing behavior; set cls in the model's config.json).

Together these make oMLX's bge-m3 embeddings bit-exact to FlagEmbedding (query parity cosine 1.0 vs the reference CLS@512), which is what let us move query embedding onto Metal with no re-ingest. Happy to split into a separate PR if you'd prefer to keep this one reranker-only.
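
The non-determinism described in fix 1 can be reproduced with a toy model of dropout. This is only an illustration under stated assumptions: a single inverted-dropout layer applied to a random 1024-dimensional vector with p = 0.1, written in plain Python rather than MLX. The `dropout` and `cosine` helpers are hypothetical names.

```python
import math
import random


def dropout(x, p, training, rng):
    """Inverted dropout: zero each element with probability p, rescale the rest."""
    if not training or p == 0.0:
        return list(x)
    scale = 1.0 / (1.0 - p)
    return [v * scale if rng.random() >= p else 0.0 for v in x]


def cosine(a, b):
    dot = sum(u * v for u, v in zip(a, b))
    return dot / (math.sqrt(sum(u * u for u in a)) * math.sqrt(sum(v * v for v in b)))


rng = random.Random(0)
x = [rng.gauss(0.0, 1.0) for _ in range(1024)]

# Two "calls" on the same input, first with dropout active, then in eval mode.
train_cos = cosine(dropout(x, 0.1, True, rng), dropout(x, 0.1, True, rng))
eval_cos = cosine(dropout(x, 0.1, False, rng), dropout(x, 0.1, False, rng))
print(f"training mode: {train_cos:.3f}  eval mode: {eval_cos:.3f}")
```

In training mode each call draws a different mask, so the two outputs for the same input diverge. In eval mode dropout is the identity and the outputs match, which is the effect of calling `model.train(False)` after load.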

jundot commented Jun 10, 2026

Owner

Thanks for the fix and for the detailed verification notes.

I verified the core bf16 issue: mx.load() handles bf16 safetensors correctly, while safe_open(...).get_tensor() still fails on that dtype. The eval-mode and CLS pooling changes for native embeddings also make sense.

I found two related follow-up gaps while checking the same paths: the native embedding loader still uses the same safe_open bf16 load path, and the native XLM-RoBERTa reranker currently stays in training mode, so dropout can make repeated rerank scores nondeterministic once the bf16 load succeeds.

This is useful as-is, and I will merge it. I will fold those consistency fixes into a follow-up on main.

@jundot jundot merged commit a2cb234 into jundot:main Jun 10, 2026
khsd6327 added a commit to khsd6327/omlx that referenced this pull request Jun 10, 2026
The native embedding loader only read keys from top-level config.json, so
ModelArgs.pooling_config stayed None and bge-m3 silently fell back to mean
pooling instead of CLS (the jundot#1767 'faithful bge serving' fix was dead code).
Resolve pooling_mode from 1_Pooling/config.json (cls/mean/max mapping),
best-effort and fail-soft. Adds regression tests.
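
A fail-soft resolution of this kind might look like the sketch below. It assumes the sentence-transformers convention for `1_Pooling/config.json`, where boolean flags such as `pooling_mode_cls_token` select the mode. The function name `resolve_pooling_mode` and the exact flag-to-mode table are assumptions and not necessarily what the referenced commit does.

```python
import json
from pathlib import Path

# Assumed mapping from sentence-transformers pooling flags to mode names.
_FLAG_TO_MODE = {
    "pooling_mode_cls_token": "cls",
    "pooling_mode_mean_tokens": "mean",
    "pooling_mode_max_tokens": "max",
}


def resolve_pooling_mode(model_dir, default="mean"):
    """Best-effort lookup of the pooling mode; any failure returns the default."""
    cfg_path = Path(model_dir) / "1_Pooling" / "config.json"
    try:
        cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # Missing file, unreadable file, or malformed JSON: fall back quietly.
        return default
    if not isinstance(cfg, dict):
        return default
    for flag, mode in _FLAG_TO_MODE.items():
        if cfg.get(flag) is True:
            return mode
    return default
```

Returning the default on any error keeps models without a `1_Pooling` directory on the existing mean-pooling behavior.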