fix: Support scikit-learn 1.9 in ZeroShotClassification #4790

Merged
KennethEnevoldsen merged 2 commits into embeddings-benchmark:main from tejasnaladala:fix/zeroshot-sklearn-19
Jun 10, 2026

@tejasnaladala

Contributor

Fixes #4784

scikit-learn 1.9 raises on mixed-type inputs to classification metrics (scikit-learn#33086). AbsTaskZeroShotClassification hits exactly that case: predictions are always integer indices into the candidate labels, while the dataset label column can contain strings. On main with scikit-learn 1.9.0, tests/test_abstasks/test_predictions.py::test_predictions[task2-expected2] fails with:

ValueError: Mix of label input types (string and number); Got ['label1' 'label2'] and [0 1]

The pre-1.9 behavior was arguably worse: accuracy_score(["label1", "label2"], [0, 1]) compared strings to ints, never matched, and silently reported 0.0 accuracy.
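
A plain-Python stand-in makes the silent failure concrete. The helper `naive_accuracy` below is hypothetical and is not scikit-learn's implementation; it only mirrors the elementwise equality check that an accuracy metric reduces to.

```python
def naive_accuracy(y_true, y_pred):
    """Fraction of positions where the label equals the prediction.

    Hypothetical helper for illustration only, not scikit-learn code.
    """
    matches = sum(t == p for t, p in zip(y_true, y_pred))
    return matches / len(y_true)


# String labels never compare equal to integer predictions ...
mixed = naive_accuracy(["label1", "label2"], [0, 1])
# ... even though the same predictions are perfect against index labels.
aligned = naive_accuracy([0, 1], [0, 1])
print(mixed, aligned)
```

The predictions are identical in both calls; only the representation of the ground truth differs, which is why the mismatch is easy to miss when no error is raised.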

Changes

  • AbsTaskZeroShotClassification._normalize_labels: integer labels pass through unchanged, string labels are mapped to their index in get_candidate_labels(), and strings that match no candidate raise a ValueError describing the label contract. For external task authors this is a behavior change: unmappable string labels now fail loudly on every scikit-learn version instead of silently scoring 0.0.
  • Removed the scikit-learn<1.9.0 stopgap pin introduced in #4783 (fix: Update lock and remove python limit fo pylate and colbert_engine) and regenerated the lock with uv lock --upgrade-package scikit-learn. Since scikit-learn 1.9 dropped Python 3.10, the lock forks to 1.9.0 on Python >= 3.11 and stays on 1.7.2 for 3.10, so the CI matrix exercises both sides.
  • Mocks: the image zeroshot mock now uses integer labels like the audio and video mocks; the text zeroshot mock keeps string labels (matching its candidates) so the mapping path stays covered in CI.
  • New tests/test_abstasks/test_zeroshot_classification.py: unit tests for the passthrough, mapping, and error contract, plus an end-to-end regression test that evaluates mteb/baseline-random-encoder on the string-label mock and asserts accuracy == 1.0 (deterministic; the encoder seeds embeddings per input string).
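
The label contract described in the list above can be sketched roughly as follows. This is a minimal standalone illustration, not the actual `_normalize_labels` method: the function name `normalize_labels`, its signature, and the error wording are assumptions, and the real method reads its candidates from `get_candidate_labels()` rather than taking them as an argument.

```python
def normalize_labels(labels, candidate_labels):
    """Map string labels to candidate indices; leave other labels untouched.

    Hypothetical sketch of the contract only. Raises ValueError for a
    string label that matches no candidate.
    """
    index_of = {name: i for i, name in enumerate(candidate_labels)}
    normalized = []
    for label in labels:
        if isinstance(label, str):
            if label not in index_of:
                raise ValueError(
                    f"Label {label!r} is not one of the candidate labels "
                    f"{list(candidate_labels)}; labels must be integer "
                    "indices or strings that match a candidate."
                )
            normalized.append(index_of[label])
        else:
            # Integer labels pass through unchanged.
            normalized.append(label)
    return normalized


candidates = ["label1", "label2"]
print(normalize_labels([0, 1], candidates))
print(normalize_labels(["label2", "label1"], candidates))
```

Raising on an unmappable string, rather than dropping it or assigning a sentinel index, keeps a mislabeled dataset from producing a plausible-looking but meaningless score.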

Leaderboard impact

None. Predictions are untouched and label normalization is an identity for integer labels. All registered zeroshot tasks store labels as integer indices: most derive candidates from ClassLabel features, SciMMIR maps strings to ints in dataset_transform, and the remaining label columns were verified as int64 via the HF datasets server (mteb/esc50, mteb/SpeechCommandsZeroshotv0.01, mteb/urbansound8K, mteb/wds_imagenet1k, and others). The string path only ever fired for the test mocks.

Verification

  • Reproduced the failure on main with scikit-learn 1.9.0, then confirmed it passes with this branch (red/green on the same test).
  • tests/test_abstasks (154 passed), tests/test_integrations + prompt validation (314 passed), tests/test_evaluators, tests/test_evaluate.py, tests/test_result_cache.py (87 passed) all green locally on scikit-learn 1.9.0.
  • ruff format --check, ruff check, typos, and mypy mteb (mypy 2.1.0) clean; uv lock --check passes.

scikit-learn 1.9 raises "ValueError: Mix of label input types" when
classification metrics receive string y_true with numeric y_pred.
Zeroshot predictions are always integer indices into the candidate
labels, so string dataset labels are now mapped to their candidate
index before scoring. Unmappable string labels raise a clear error
instead of silently scoring 0.0, which is what scikit-learn < 1.9 did.

Removes the <1.9.0 pin introduced as a stopgap in embeddings-benchmark#4783.

Fixes embeddings-benchmark#4784
@Samoed

Samoed commented Jun 9, 2026

Member

Can you run some tasks to verify that scores are matching?

@tejasnaladala

Contributor Author

Ran openai/clip-vit-base-patch32 on two real zeroshot tasks (fresh runs, cache=None, same machine):

| Task | main (5cfdb8e) + scikit-learn 1.7.2 | this PR + scikit-learn 1.9.0 |
| --- | --- | --- |
| RenderedSST2 | 0.5848434925864909 | 0.5848434925864909 |
| DTDZeroShot | 0.4186170212765957 | 0.4186170212765957 |

Identical to full float precision, as expected: _normalize_labels is an identity passthrough for integer labels, all registered zeroshot tasks have integer label columns, and predictions are untouched.

The RenderedSST2 value also matches the published result for this model in embeddings-benchmark/results (results/openai__clip-vit-base-patch32/3d74acf9.../RenderedSST2.json, accuracy 0.5848434925864909 from mteb 1.14.15) digit for digit. The published DTDZeroShot value (0.41914893617021276) was produced on dataset revision d2afa97d, while the task now pins 96726183, so that one is not the same data; main and this branch agree exactly on the current revision.

Repro:

import mteb

model = mteb.get_model_meta("openai/clip-vit-base-patch32")
tasks = mteb.get_tasks(tasks=["RenderedSST2", "DTDZeroShot"])
results = mteb.evaluate(model, tasks, cache=None, co2_tracker=False)
print({r.task_name: r.scores["test"][0]["accuracy"] for r in results})
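
To compare the dictionaries printed by two such runs, an exact equality check is enough, since the claim is bit-for-bit agreement rather than agreement within a tolerance. The sketch below hard-codes the values reported in this comment; the variable names are illustrative assumptions.

```python
# Scores copied from the two runs reported above (hypothetical variable names).
main_scores = {
    "RenderedSST2": 0.5848434925864909,
    "DTDZeroShot": 0.4186170212765957,
}
branch_scores = {
    "RenderedSST2": 0.5848434925864909,
    "DTDZeroShot": 0.4186170212765957,
}

# Collect every task whose score differs, or is missing, on the branch.
mismatches = {
    task: (score, branch_scores.get(task))
    for task, score in main_scores.items()
    if score != branch_scores.get(task)
}
assert not mismatches, f"scores differ: {mismatches}"
print("all scores identical")
```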

@Samoed Samoed requested a review from KennethEnevoldsen June 9, 2026 23:38
@tejasnaladala

Contributor Author

Note on the one red check: test-dockerfile also fails on main itself (pushes at 16:06 and 16:48 UTC today) and on the other open PRs, so it is the pre-existing #4292 flake, not this diff. Everything else is green, including the py3.11-3.14 jobs that run scikit-learn 1.9.0.

@KennethEnevoldsen KennethEnevoldsen enabled auto-merge (squash) June 10, 2026 10:52
@KennethEnevoldsen KennethEnevoldsen merged commit 5e8f0e1 into embeddings-benchmark:main Jun 10, 2026
13 of 14 checks passed

Development

Successfully merging this pull request may close these issues.

Sklearn 1.9 is not compatible with ZeroShotClassification
