
Fix tokenizer auto_map being ignored for custom models#43219

Merged
vasqu merged 3 commits into huggingface:main from Anri-Lombard:fix-tokenizer-auto-map-regression
Jan 22, 2026

Conversation

@Anri-Lombard
Contributor

Fixes #43202

PR #42894 introduced an early-exit to TokenizersBackend when tokenizer_class doesn't match the registered tokenizer for a model_type. However, this check was placed before the auto_map extraction, causing custom tokenizers (with trust_remote_code=True) to be ignored.

Reproduction:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("IQuestLab/IQuest-Coder-V1-40B-Loop-Instruct", trust_remote_code=True)
print(tokenizer.decode(tokenizer.encode("This is a test")))
# Expected: "This is a test"
# Actual (bug): "Th is <unk> is <unk> a <unk> te st"
```

For models with unregistered model_type (like iquestloopcoder), the condition TOKENIZER_MAPPING_NAMES.get(model_type, "") != tokenizer_config_class is always True, causing early-exit to TokenizersBackend without checking if auto_map exists.
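The always-True behavior is easy to see in isolation. A toy sketch of the mismatch check (the registry contents and class names below are hypothetical stand-ins, not the real transformers mapping):

```python
# Hypothetical stand-in for transformers' TOKENIZER_MAPPING_NAMES registry.
TOKENIZER_MAPPING_NAMES = {"llama": "LlamaTokenizerFast"}

# An unregistered model_type falls back to the default "", so the
# inequality holds regardless of what tokenizer_class the config declares.
model_type = "iquestloopcoder"              # not in the registry
tokenizer_config_class = "IQuestTokenizer"  # hypothetical class name

mismatch = TOKENIZER_MAPPING_NAMES.get(model_type, "") != tokenizer_config_class
print(mismatch)  # True: the early-exit always fires for unregistered model types
```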

This fix moves the auto_map extraction before the early-exit check and adds tokenizer_auto_map is None to the condition.
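The reordering can be sketched as a small dispatch function. This is a simplified illustration of the logic described above (toy registry, invented return values, not the actual transformers code):

```python
def select_backend(tokenizer_config: dict) -> str:
    """Simplified sketch of the fixed dispatch order described in this PR."""
    # Hypothetical stand-in for the real registry.
    TOKENIZER_MAPPING_NAMES = {"llama": "LlamaTokenizerFast"}

    model_type = tokenizer_config.get("model_type")
    tokenizer_class = tokenizer_config.get("tokenizer_class")

    # The fix: extract auto_map from the config BEFORE the early-exit check...
    auto_map = tokenizer_config.get("auto_map", {}).get("AutoTokenizer")

    # ...and add the "auto_map is None" clause, so custom tokenizers are no
    # longer swallowed by the TokenizersBackend early-exit.
    if TOKENIZER_MAPPING_NAMES.get(model_type, "") != tokenizer_class and auto_map is None:
        return "TokenizersBackend"
    if auto_map is not None:
        return "dynamic_module"  # trust_remote_code path loads the custom class
    return "registered"

# A custom tokenizer with auto_map now routes to dynamic module loading.
print(select_backend({
    "model_type": "iquestloopcoder",
    "tokenizer_class": "IQuestTokenizer",
    "auto_map": {"AutoTokenizer": ["tokenization_iquest.IQuestTokenizer", None]},
}))  # dynamic_module
```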

Added regression test test_custom_tokenizer_with_mismatched_tokenizer_class.

…3202)

PR huggingface#42894 added an early-exit to TokenizersBackend when tokenizer_class
doesn't match the registered tokenizer for a model_type. However, this
early-exit was placed before the auto_map check, causing custom tokenizers
with trust_remote_code to be ignored.

This fix moves the auto_map extraction before the early-exit check and adds
tokenizer_auto_map is None to the condition, so models with custom tokenizers
properly use the dynamic module loading path.
@vasqu
Contributor

vasqu commented Jan 12, 2026

cc @itazap @ArthurZucker

Collaborator

@ArthurZucker ArthurZucker left a comment


This looks good to me!
Thanks for also adding a test!

@awni

awni commented Jan 21, 2026

@Anri-Lombard @ArthurZucker thanks for the fix here. We are depending on it in mlx-lm and looking forward to the next RC.

@ArthurZucker
Collaborator

Can you fix the quality checks please?! 🤗

@vasqu
Contributor

vasqu commented Jan 22, 2026

@bot /style

@github-actions
Contributor

Style fix is beginning .... View the workflow run here.

@vasqu vasqu enabled auto-merge (squash) January 22, 2026 17:16
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu vasqu merged commit 5410465 into huggingface:main Jan 22, 2026
25 checks passed
@vasqu
Contributor

vasqu commented Jan 22, 2026

I took the liberty of running the style fix myself so this could merge, thanks for contributing ❤️

SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026
…3219)

* Fix tokenizer auto_map being ignored for custom models (huggingface#43202)

PR huggingface#42894 added an early-exit to TokenizersBackend when tokenizer_class
doesn't match the registered tokenizer for a model_type. However, this
early-exit was placed before the auto_map check, causing custom tokenizers
with trust_remote_code to be ignored.

This fix moves the auto_map extraction before the early-exit check and adds
tokenizer_auto_map is None to the condition, so models with custom tokenizers
properly use the dynamic module loading path.

* style

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: vasqu <antonprogamer@gmail.com>


Development

Successfully merging this pull request may close these issues.

tokenizer.decode producing bad results in some cases from 5.0.0rc1 to 5.0.0rc2

5 participants