Skip to content

[generate] add faster stop_strings stopping criteria#40520

Open
gante wants to merge 6 commits into
huggingface:mainfrom
gante:stopstring
Open

[generate] add faster stop_strings stopping criteria#40520
gante wants to merge 6 commits into
huggingface:mainfrom
gante:stopstring

Conversation

@gante

@gante gante commented Aug 28, 2025

Copy link
Copy Markdown
Contributor

What does this PR do?

Adds StopStringTextMatchCriteria, a faster alternative to StopStringCriteria. Unlike StopStringCriteria, StopStringTextMatchCriteria can't be compiled.

Some additional context:

  • when we added StopStringCriteria, we were looking forward having end-to-end generate compilation, so it made sense to focus on compilable options;
  • As a user mentioned on this issue, StopStringCriteria can be really slow in some contexts. More specifically, at initialization time (see benchamrks below);
  • In general, StopStringTextMatchCriteria is faster, so it's the new default. StopStringCriteria is kept for torch.compile users.

Thank you @MaxBourdon for surfacing the problem

Benchmarks

TL;DR StopStringCriteria is very slow to initialize on new stop_strings inputs, >2s on my machine. This is cached, so successive calls with the same stop_strings are not as bad. However, it's particularly troublesome when trying small models, as this initialization may take much more than the generation time. Excluding init time, the new StopStringTextMatchCriteria is also slightly faster.

Benchmark script
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList, StopStringTextMatchCriteria, StopStringCriteria
from time import time
import torch

N_RUNS = 100
MAX_NEW_TOKENS = 100
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
STOP_STRINGS = ["Potato", "Carrots", "Onions", "Garlic", "Tomatoes", "Lettuce", "Cucumbers"]
BATCH_SIZE = 1

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", dtype=torch.bfloat16)
inputs = tokenizer(["The quick brown"] * BATCH_SIZE, return_tensors="pt").to(model.device)

# warmup
for i in range(10):
    gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=MAX_NEW_TOKENS, min_new_tokens=MAX_NEW_TOKENS)
    assert gen_out.shape[1] == MAX_NEW_TOKENS + inputs.input_ids.shape[1]

# ------------------------------------------------
# No stopping criteria
# ------------------------------------------------
all_times = []
for i in range(N_RUNS):
    start_time = time()
    gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=MAX_NEW_TOKENS, min_new_tokens=MAX_NEW_TOKENS)
    assert gen_out.shape[1] == MAX_NEW_TOKENS + inputs.input_ids.shape[1]
    end_time = time()
    all_times.append(end_time - start_time)

avg_time = sum(all_times) / N_RUNS
print(f"[No stopping criteria] Average generation time: {avg_time} seconds")


# ------------------------------------------------
# StopStringCriteria
# ------------------------------------------------
# IMPORTANT NOTE: initializing this for the first time is slow, >1s. Prior to this PR, this initialization was
# done inside `generate` when `stop_strings` is set
init_start_time = time()
custom_stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer, STOP_STRINGS)])
init_end_time = time()
print(f"[StopStringCriteria] first init time: {init_end_time - init_start_time} seconds")

init_start_time = time()
custom_stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer, STOP_STRINGS)])
init_end_time = time()
print(f"[StopStringCriteria] second init time: {init_end_time - init_start_time} seconds")

all_times = []
for i in range(N_RUNS):
    start_time = time()
    gen_out = model.generate(
        **inputs,
        do_sample=False,
        stopping_criteria=custom_stopping_criteria,
        max_new_tokens=MAX_NEW_TOKENS,
        min_new_tokens=MAX_NEW_TOKENS,
    )
    assert gen_out.shape[1] == MAX_NEW_TOKENS + inputs.input_ids.shape[1]
    end_time = time()
    all_times.append(end_time - start_time)

avg_time = sum(all_times) / N_RUNS
print(f"[StopStringCriteria] Average generation time: {avg_time} seconds")

# ------------------------------------------------
# StopStringTextMatchCriteria
# ------------------------------------------------
init_start_time = time()
custom_stopping_criteria = StoppingCriteriaList([StopStringTextMatchCriteria(tokenizer, STOP_STRINGS)])
init_end_time = time()
print(f"[StopStringTextMatchCriteria] first init time: {init_end_time - init_start_time} seconds")

init_start_time = time()
custom_stopping_criteria = StoppingCriteriaList([StopStringTextMatchCriteria(tokenizer, STOP_STRINGS)])
init_end_time = time()
print(f"[StopStringTextMatchCriteria] second init time: {init_end_time - init_start_time} seconds")

all_times = []
for i in range(N_RUNS):
    start_time = time()
    gen_out = model.generate(
        **inputs,
        do_sample=False,
        stopping_criteria=custom_stopping_criteria,
        max_new_tokens=MAX_NEW_TOKENS,
        min_new_tokens=MAX_NEW_TOKENS,
    )
    assert gen_out.shape[1] == MAX_NEW_TOKENS + inputs.input_ids.shape[1]
    end_time = time()
    all_times.append(end_time - start_time)

avg_time = sum(all_times) / N_RUNS
print(f"[StopStringTextMatchCriteria] Average generation time: {avg_time} seconds")


# ------------------------------------------------
# generate with stop strings (using `StopStringTextMatchCriteria` under the hood)
# ------------------------------------------------
all_times = []
for i in range(N_RUNS):
    start_time = time()
    gen_out = model.generate(
        **inputs,
        do_sample=False,
        stop_strings=STOP_STRINGS,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
        min_new_tokens=MAX_NEW_TOKENS,
    )
    assert gen_out.shape[1] == MAX_NEW_TOKENS + inputs.input_ids.shape[1]
    end_time = time()
    all_times.append(end_time - start_time)

avg_time = sum(all_times) / N_RUNS
print(f"[Default `stop_strings` criteria] Average generation time: {avg_time} seconds")

Benchmark results on my machine:

[No stopping criteria] Average generation time: 1.3314339590072632 seconds
[StopStringCriteria] first init time: 2.4044578075408936 seconds  # <------- this is the issue, init time > generation time!
[StopStringCriteria] second init time: 0.0518953800201416 seconds
[StopStringCriteria] Average generation time: 1.3567428421974181 seconds
[StopStringTextMatchCriteria] first init time: 6.4373016357421875e-06 seconds
[StopStringTextMatchCriteria] second init time: 1.9073486328125e-06 seconds
[StopStringTextMatchCriteria] Average generation time: 1.343175311088562 seconds
[Default `stop_strings` criteria] Average generation time: 1.3437320828437804 seconds  # <---- uses `StopStringTextMatchCriteria`

self.assertEqual(len(stopping_criteria), 1)

def test_stop_string_criteria(self):
@parameterized.expand(

@gante gante Aug 28, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(reuses the extensive tests for StopStringCriteria on StopStringTextMatchCriteria, ensuring 1:1 compatibility)

@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment thread src/transformers/generation/stopping_criteria.py Outdated
Comment thread src/transformers/generation/stopping_criteria.py Outdated
Comment thread src/transformers/generation/stopping_criteria.py
@MaxBourdon

Copy link
Copy Markdown

Thank you for your PR!
I added some comments on the implementation details, hope they could be useful 😄

@gante

gante commented Aug 29, 2025

Copy link
Copy Markdown
Contributor Author

@MaxBourdon thank you for the feedback, corrected 😉

(there was another edge case failing: the last token fully contains the stop string, but doesn't start with stop string characters; added a test)

Comment on lines +566 to +567
last_two_tokens_text = self.tokenizer.decode(input_ids[batch_idx, -2:])
last_tokens_with_prefix_text = self.tokenizer.decode(input_ids[batch_idx, -1:])

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pretty sure we shuld be using the decode stream here!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is that? 👀

I see some related docs (https://huggingface.co/docs/tokenizers/v0.20.3/en/api/decoders#tokenizers.decoders.DecodeStream), but they lead nowhere

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://huggingface.co/docs/tokenizers/main/en/api/decoders#tokenizers.decoders.DecodeStream
What I am saying is that decoding stuff like a brute here is risking not following the "stream" (acting on tokens and not strings).

decode_stream = DecodeStream(_inputs.tolist(), False)
you init with prefix, then you step into it to get the string.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants