Skip to content

🚨 [v5] generate delegates default cache initialization to the model#41505

Merged
gante merged 10 commits into
huggingface:mainfrom
gante:delegate_cache_init
Oct 13, 2025
Merged

🚨 [v5] generate delegates default cache initialization to the model#41505
gante merged 10 commits into
huggingface:mainfrom
gante:delegate_cache_init

Conversation

@gante

@gante gante commented Oct 10, 2025

Copy link
Copy Markdown
Contributor

What does this PR do?

See PR title.

Now that all traces of legacy caches were removed, we can trust the model to initialize its own cache! This means we no longer need to set cache_implementation="xxx" defaults in new models, assuming the model's forward pass defaults to the right cache class.

Also fixes related bugs, uncovered by not feeding a cache to the model.

@gante gante changed the title [generate] delegate default cache initialization to the model 🚨 [v5] generate delegates default cache initialization to the model Oct 10, 2025
@gante gante mentioned this pull request Oct 10, 2025
@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

if requires_cross_attention_cache and not isinstance(model_kwargs[cache_name], EncoderDecoderCache):
model_kwargs[cache_name] = EncoderDecoderCache(
model_kwargs[cache_name], # self-attention cache
if (

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update to this branch -- we only want to convert the cache to EncoderDecoderCache here if:

  1. the user has set custom cache args in generate
  2. the model is encoder-decoder
  3. (implicitly, all encoder-decoder models have past_key_values as their cache name)

In all other cases, we delegate cache init to the model itself

@@ -546,8 +546,10 @@ def prepare_inputs_for_generation(
model_inputs["cache_position"] = cache_position

# 2. Generic cache-dependent input preparation

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes in new L549-594 are related to recurrent_gemma: prior to the deletion of the old L2002-2007, generate was preparing a DynamicCache for recurrent_gemma. This cache was never used in forward, but it was inducing the correct behavior in prepare_inputs_for_generation (cache_position-based input slicing)

With the new logic, use_cache=True implies cache_position-based input slicing, even if the model is not using a standard cache.

past_key_values = (
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
if encoder_hidden_states is not None
if encoder_hidden_states is not None or self.config.is_encoder_decoder

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On encoder-decoder models that we may want to use as decoder-only, we want EncoderDecoderCache in two possible situations:

  1. [missing] the model is encoder-decoder
  2. the model is not encoder-decoder, but encoder_hidden_states passed (which means we will compute the cross-attention, and thus we should cache it)

choice_labels,
):
config = copy.deepcopy(config)
config.is_decoder = True

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RoFormerForCausalLM won't use the cache if config.is_decoder!=True, and this test tests cache usage 🙃

(A warning was being thrown)

Comment on lines +2194 to +2197
# build `cache_position` on the fly
seq_length = inputs["input_ids"].shape[1]
inputs = self.model._get_initial_cache_position(seq_length, self.model.device, inputs)
# prepare other inputs

@gante gante Oct 11, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whisper has custom generation structure that doesn't follow our code patterns -> one of the issues is that it has stateful LogitsProcessor -> this state-related function uses prepare_inputs_for_generation out of the usual order -> changes related to this PR exposed that it was missing cache_position as an input here

(if we have bandwidth, we should revisit whisper generate to streamline its code)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(if we have bandwidth, we should revisit whisper generate to streamline its code)

💯, revisiting audio modality generation will be super helpful

@Cyrilvallez Cyrilvallez left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!!

Comment thread src/transformers/generation/utils.py Outdated
Comment on lines +1808 to +1810
if model_kwargs.get("past_key_values") is not None:
if model_kwargs.get("past_key_values", None) is not None:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically not needed, get has None as a default (and I thought ruff was now enforcing this one but apparently not 🤔 too many changes to the ruff rules recently)

Comment on lines -619 to -622
# initialize `past_key_values`
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EncoderDecoderCache initialized in a Decoder only module... 🥲🥲🥲 good catch!

@zucchini-nlp zucchini-nlp left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, left one question that seems to be a typo or just my bad understading

Comment on lines +2194 to +2197
# build `cache_position` on the fly
seq_length = inputs["input_ids"].shape[1]
inputs = self.model._get_initial_cache_position(seq_length, self.model.device, inputs)
# prepare other inputs

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(if we have bandwidth, we should revisit whisper generate to streamline its code)

💯, revisiting audio modality generation will be super helpful

Comment thread src/transformers/generation/utils.py Outdated
model_inputs["cache_position"] = cache_position

# 2. Generic cache-dependent input preparation
use_cache = kwargs.get("use_cache", False) or getattr(self.config, "use_cache", False)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will result in True even when the user-kwargs set caching as False. IIUC we just want to give priority to user-defined cache and not assume that caching is used whenever it is set to True in any of the places

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zucchini-nlp that's a good catch, user-defined generate kwargs >> config values!

Will update accordingly

use_cache = kwargs.get("use_cache", False) or getattr(self.config, "use_cache", False)
if past_key_values is not None:
model_inputs["past_key_values"] = past_key_values
if past_key_values is None or use_cache:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe i am missing smth, do we apply cache slicing when the past_key_values is None? Looks not intuitive from first sight, so let's add a comment explaining why

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stateful models like recurrent_gemma assume that slicing happens, but don't have a Cache cache -- will add a comment :)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so they don't have a Cache object and also if use_cache does not catch those cases?

@github-actions

Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: bart, bert, bert_generation, bigbird_pegasus, blenderbot, blenderbot_small, camembert, data2vec, electra, ernie, fsmt, kosmos2, marian, mbart, mvp, pegasus

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants