
Ministral-3-8B-Instruct tokenizer doesn't handle BPE markers properly #42796

@yanancai

System Info

from transformers import AutoModelForImageTextToText, AutoProcessor
import torch

base_model = "mistralai/Ministral-3-8B-Instruct-2512-BF16"

model = AutoModelForImageTextToText.from_pretrained(base_model, dtype=torch.bfloat16)
model = model.to("cuda:1")
tokenizer = AutoProcessor.from_pretrained(base_model)

user_prompt = "hello how are you?"
messages = [
    {"role": "user", "content": user_prompt},
]

text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text=text, return_tensors="pt").to(model.device)
generate_ids = model.generate(**inputs, max_new_tokens=50, do_sample=False)
decoded_output = tokenizer.batch_decode(generate_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True)[0]
print(decoded_output)

Output:

Hello!ĠðŁĺĬĠI'mĠjustĠaĠvirtualĠassistant,ĠsoĠIĠdon'tĠhaveĠfeelings,ĠbutĠI'mĠhereĠandĠreadyĠtoĠhelpĠyouĠwithĠanythingĠyouĠneed!ĠHowĠaboutĠyouâĢĶhowĠareĠ*you*ĠdoingĠtoday?ĠAnythingĠfunĠorĠinterestingĠon
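
These markers look like the GPT-2-style byte-to-unicode symbols that byte-level BPE tokenizers use internally (Ġ is byte 0x20, i.e. a space). Here is a minimal sketch, assuming the standard GPT-2 byte table rather than the actual transformers decode path: reversing that mapping recovers the intended text, which suggests the decoder simply never applies this final step:

# Sketch only (assumes the standard GPT-2 byte-level BPE table, not the
# actual transformers internals). Each character in the garbled output
# stands for one raw byte; mapping them back and UTF-8-decoding yields
# the clean text.
def bytes_to_unicode():
    # Printable bytes map to themselves; the remaining bytes are shifted
    # to code points above 255 so every byte has a visible character.
    bs = (list(range(ord("!"), ord("~") + 1))
          + list(range(ord("¡"), ord("¬") + 1))
          + list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, map(chr, cs)))

unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()}

def undo_bpe_markers(text: str) -> str:
    # 'Ġ' -> 0x20 (space), 'ðŁĺĬ' -> F0 9F 98 8A (the 😊 emoji), etc.
    return bytes(unicode_to_bytes[c] for c in text).decode("utf-8")

print(undo_bpe_markers("Hello!ĠðŁĺĬĠI'mĠjustĠaĠvirtualĠassistant"))
# -> Hello! 😊 I'm just a virtual assistant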

Environment:
Python: 3.12.7
transformers: 5.0.0.dev0 (installed from main)
torch: 2.9.0
mistral_common: 1.8.6

The same code works when the tokenizer is loaded with MistralCommonBackend:
Code:

import torch
from transformers import AutoModelForImageTextToText, MistralCommonBackend

base_model = "mistralai/Ministral-3-8B-Instruct-2512-BF16"

tokenizer = MistralCommonBackend.from_pretrained(base_model)
model = AutoModelForImageTextToText.from_pretrained(base_model, dtype=torch.bfloat16)
model = model.to("cuda:2")

user_prompt = "hello how are you?"
messages = [
    {"role": "user", "content": user_prompt},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text=text, return_tensors="pt").to(model.device)
generate_ids = model.generate(**inputs, max_new_tokens=50, do_sample=False)
decoded_output = tokenizer.batch_decode(generate_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True)[0]
print(decoded_output)

Output:

 Hello! I'm just a program, so I don't have feelings, but I'm here and ready to help you with anything you need. How about you? How are you doing today?[😊]

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Run the first snippet above (under System Info): decoding the generated ids with the AutoProcessor-loaded tokenizer returns raw byte-level BPE markers instead of clean text.

Expected behavior

Clean decoded text, with the byte-level BPE markers (Ġ for spaces, byte sequences for emoji and punctuation) converted back to their original characters, matching the MistralCommonBackend output above.
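
As a stopgap until this is fixed, reusing undo_bpe_markers from the sketch above on the fast-tokenizer result recovers the clean text (assuming the entire decoded string is still in byte-representation form, as in the output shown):

# Stopgap only; undo_bpe_markers is the hypothetical helper defined in
# the sketch above, applied to decoded_output from the first snippet.
print(undo_bpe_markers(decoded_output))
# -> Hello! 😊 I'm just a virtual assistant, ...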
