System Info
- transformers version: 5.8.0 (also reproduced on 5.0.0 through 5.7.0)
- Platform: Linux-5.14.0-503.11.1.el9_5.x86_64-x86_64-with-glibc2.34
- Python version: 3.12.13
- Huggingface_hub version: 1.14.0
- Safetensors version: 0.7.0
- Tokenizers version: 0.22.2
- PyTorch version: not installed (tokenizer-only reproduction)
- Using distributed or parallel set-up in script?: No
Who can help?
@ArthurZucker and @itazap
Reproduction
from transformers import AutoTokenizer, PreTrainedTokenizerFast
model_id = "ibm-granite/granite-4.1-8b" # any Granite 4+ model
tok_auto = AutoTokenizer.from_pretrained(model_id)
tok_correct = PreTrainedTokenizerFast.from_pretrained(model_id)
# Numeric strings are most visibly affected (digit splitting differs)
print(tok_auto.encode("2023", add_special_tokens=False)) # [508, 1419] ← WRONG
print(tok_correct.encode("2023", add_special_tokens=False)) # [2366, 18] ← correct
print(tok_auto.encode("650841823", add_special_tokens=False)) # [13655, 5833, 972, 1419] ← WRONG
print(tok_correct.encode("650841823", add_special_tokens=False)) # [13655, 25496, 23848] ← correct
print(tok_auto.encode("ISO 9001:2015", add_special_tokens=False)) # [25141, 220, 24, 4119, 25, 679, 20] ← WRONG
print(tok_correct.encode("ISO 9001:2015", add_special_tokens=False)) # [25141, 220, 7467, 16, 25, 679, 20] ← correct
Pre-tokenizer mismatch:
print(tok_auto.backend_tokenizer.pre_tokenizer)
# ByteLevel(add_prefix_space=False, trim_offsets=True, use_regex=True) ← WRONG
print(tok_correct.backend_tokenizer.pre_tokenizer)
# Sequence([Split(regex, ...), ByteLevel(..., use_regex=False)]) ← correct (matches tokenizer.json)
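To check more inputs than the three shown above, a small helper along these lines could list every probe string on which two tokenizers disagree. This is only a sketch: `find_divergent_inputs` is a hypothetical name, not part of transformers, and it assumes nothing beyond an `encode(text, add_special_tokens=False)` method like the one used in the reproduction.

```python
def find_divergent_inputs(tok_a, tok_b, texts):
    """Return (text, ids_a, ids_b) for each text the two tokenizers encode differently.

    Hypothetical helper for illustration; both arguments only need an
    ``encode(text, add_special_tokens=False)`` method.
    """
    diverging = []
    for text in texts:
        ids_a = tok_a.encode(text, add_special_tokens=False)
        ids_b = tok_b.encode(text, add_special_tokens=False)
        if ids_a != ids_b:
            diverging.append((text, ids_a, ids_b))
    return diverging
```

With the objects from the reproduction, it would be called as `find_divergent_inputs(tok_auto, tok_correct, ["2023", "650841823", "ISO 9001:2015"])`. An empty list would mean the two loaders agree on every probe.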
Expected behavior
AutoTokenizer should honor the pre-tokenizer defined in tokenizer.json and return the same token IDs as PreTrainedTokenizerFast.
Granite models ship a tokenizer.json with a Sequence(Split(\p{N}{1,3}) + ByteLevel) pre-tokenizer (tiktoken/cl100k style: digit runs are split into chunks of at most 3 characters before BPE). AutoTokenizer, however, routes Granite to GPT2Tokenizer, whose __init__ hardcodes ByteLevel(use_regex=True). That applies the GPT-2 regex (\p{N}+), which keeps each digit run together.
Because the BPE merges were trained with \p{N}{1,3} splitting, a different pre-tokenizer at inference produces wrong merges and wrong token IDs.
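The difference between the two digit rules can be seen in isolation with the standard library. The sketch below is a simplification, not the actual pre-tokenizer: it models only the digit part of each regex, and it uses `\d` as a stand-in for `\p{N}` because the stdlib `re` module does not support Unicode property classes. The names `CHUNKED_DIGITS`, `WHOLE_DIGITS` and `digit_pretokens` are made up for illustration.

```python
import re

# Simplified stand-ins for the digit rules only. The real pre-tokenizer regexes
# also cover letters, whitespace and punctuation, and use \p{N} rather than \d.
CHUNKED_DIGITS = re.compile(r"\d{1,3}")  # tokenizer.json style: at most 3 digits per chunk
WHOLE_DIGITS = re.compile(r"\d+")        # GPT-2 style: the whole digit run stays together


def digit_pretokens(text, pattern):
    """Return the digit chunks that `pattern` carves out of `text`, left to right."""
    return pattern.findall(text)


for sample in ("2023", "650841823", "ISO 9001:2015"):
    print(sample,
          digit_pretokens(sample, CHUNKED_DIGITS),
          digit_pretokens(sample, WHOLE_DIGITS))
```

For "650841823" the chunked rule yields three pieces ('650', '841', '823'), which lines up with the three IDs in the correct output above, while the GPT-2 style rule hands BPE a single nine-digit run.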
Workaround
Use PreTrainedTokenizerFast directly (works in both v4 and v5):
from transformers import PreTrainedTokenizerFast
tok = PreTrainedTokenizerFast.from_pretrained("ibm-granite/granite-4.1-8b")
Proposed Fix
Change TOKENIZER_MAPPING_NAMES for Granite model types from "GPT2Tokenizer" to "TokenizersBackend":
- ("granite", "GPT2Tokenizer"),
- ("granitemoe", "GPT2Tokenizer"),
- ("granitemoehybrid", "GPT2Tokenizer"),
- ("granitemoeshared", "GPT2Tokenizer"),
+ ("granite", "TokenizersBackend" if is_tokenizers_available() else None),
+ ("granitemoe", "TokenizersBackend" if is_tokenizers_available() else None),
+ ("granitemoehybrid", "TokenizersBackend" if is_tokenizers_available() else None),
+ ("granitemoeshared", "TokenizersBackend" if is_tokenizers_available() else None),
This triggers the existing mismatch detection (mapping says "TokenizersBackend" ≠ hub's "GPT2Tokenizer"), which falls through to TokenizersBackend.from_pretrained() — loading tokenizer.json faithfully.
Why not MODELS_WITH_INCORRECT_HUB_TOKENIZER_CLASS? That override set is only consulted inside the mismatch block (line 733) — which requires TOKENIZER_MAPPING_NAMES[model_type] to disagree with the hub's tokenizer_class. For Granite, both say "GPT2Tokenizer", so no mismatch is detected and the override set is never checked.
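The control flow described above can be modelled with a toy function. This is a deliberately simplified sketch and not the transformers source: `resolve_tokenizer_class`, `current_mapping`, `proposed_mapping` and `override_set` are hypothetical names, and the only behavior encoded is what this issue states (the override set is reachable only when the mapping and the hub class disagree, and a mismatch falls through to TokenizersBackend).

```python
def resolve_tokenizer_class(model_type, hub_class, mapping, override_set):
    """Toy model of the dispatch described in this issue (not the real implementation).

    Returns (chosen_class_name, override_set_was_consulted).
    """
    mapped = mapping.get(model_type)
    if mapped is not None and mapped != hub_class:
        # Mismatch block: per the issue, the only place the override set is looked at.
        _listed = model_type in override_set  # consulted, but only reachable here
        return "TokenizersBackend", True
    # No mismatch: the agreed class is used and the override set is never checked.
    return hub_class, False


current_mapping = {"granite": "GPT2Tokenizer"}
proposed_mapping = {"granite": "TokenizersBackend"}
override_set = {"granite"}  # hypothetical: pretend Granite had been added to the set

print(resolve_tokenizer_class("granite", "GPT2Tokenizer", current_mapping, override_set))
print(resolve_tokenizer_class("granite", "GPT2Tokenizer", proposed_mapping, override_set))
```

In this model, adding Granite to the override set changes nothing under the current mapping, because the branch that reads the set is never entered. Changing the mapping entry is what makes the mismatch branch reachable.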
Related Issues
- LlamaTokenizer hardcodes Metaspace, breaks DeepSeek V3/R1 (GSM8K drops 63.7% → 26.1%)
- AutoTokenizer ignores tokenizer.json pre-tokenizer for deepseek-coder