🚨 [v5] generate delegates default cache initialization to the model #41505
| if requires_cross_attention_cache and not isinstance(model_kwargs[cache_name], EncoderDecoderCache): | ||
| model_kwargs[cache_name] = EncoderDecoderCache( | ||
| model_kwargs[cache_name], # self-attention cache | ||
| if ( |
Update to this branch -- we only want to convert the cache to `EncoderDecoderCache` here if:
- the user has set custom cache args in `generate`
- the model is encoder-decoder
- (implicitly, all encoder-decoder models have `past_key_values` as their cache name)
In all other cases, we delegate cache init to the model itself.
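A minimal sketch of the branch condition described above, using stand-in classes rather than the real `transformers` cache classes. The names `DynamicCacheStub`, `EncoderDecoderCacheStub`, `maybe_wrap_cache`, and the `user_set_cache_args` flag are hypothetical and exist only for illustration; the actual code in `generate` may be structured differently.

```python
from dataclasses import dataclass, field


@dataclass
class DynamicCacheStub:
    """Hypothetical stand-in for a self-attention cache."""
    layers: list = field(default_factory=list)


@dataclass
class EncoderDecoderCacheStub:
    """Hypothetical stand-in that pairs a self-attention and a cross-attention cache."""
    self_attention_cache: DynamicCacheStub
    cross_attention_cache: DynamicCacheStub


def maybe_wrap_cache(model_kwargs, is_encoder_decoder, user_set_cache_args,
                     cache_name="past_key_values"):
    """Wrap a prepared cache only when the user customized it AND the model is encoder-decoder."""
    cache = model_kwargs.get(cache_name)
    if cache is None:
        # Nothing was prepared in `generate`: the model initializes its own cache.
        return model_kwargs
    needs_wrapping = user_set_cache_args and is_encoder_decoder
    if needs_wrapping and not isinstance(cache, EncoderDecoderCacheStub):
        model_kwargs[cache_name] = EncoderDecoderCacheStub(
            cache,               # self-attention cache
            DynamicCacheStub(),  # fresh cross-attention cache
        )
    return model_kwargs
```

In this sketch, a decoder-only model or a call without custom cache args leaves `model_kwargs` untouched, which corresponds to "delegate cache init to the model itself".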
| @@ -546,8 +546,10 @@ def prepare_inputs_for_generation( | |||
| model_inputs["cache_position"] = cache_position | |||
|
|
|||
| # 2. Generic cache-dependent input preparation | |||
Changes in the new L549-594 are related to recurrent_gemma: prior to the deletion of the old L2002-2007, generate was preparing a DynamicCache for recurrent_gemma. This cache was never used in forward, but it induced the correct behavior in prepare_inputs_for_generation (cache_position-based input slicing).
With the new logic, use_cache=True implies cache_position-based input slicing, even if the model is not using a standard cache.
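To illustrate what "cache_position-based input slicing" means here, the following is a rough sketch on plain Python lists. The function `slice_inputs_by_cache_position` is hypothetical and is not the library implementation; it only shows that, when caching is on, the inputs are reduced to the positions that have not been processed yet, whether or not a standard cache object exists.

```python
def slice_inputs_by_cache_position(input_ids, cache_position, use_cache):
    """Keep only the tokens at `cache_position` when caching is enabled.

    input_ids: list of rows, each row a list of token ids (batch x sequence).
    cache_position: indices of the tokens that still need a forward pass.
    """
    if not use_cache:
        # Without caching, the full sequence is re-processed at every step.
        return input_ids
    return [[row[i] for i in cache_position] for row in input_ids]


# Four tokens so far, and only the last one is new.
full_inputs = [[10, 11, 12, 13]]
sliced = slice_inputs_by_cache_position(full_inputs, cache_position=[3], use_cache=True)
unsliced = slice_inputs_by_cache_position(full_inputs, cache_position=[3], use_cache=False)
```

A stateful model such as recurrent_gemma relies on the sliced form even though it keeps its state outside a standard cache.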
| past_key_values = ( | ||
| EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) | ||
| if encoder_hidden_states is not None | ||
| if encoder_hidden_states is not None or self.config.is_encoder_decoder |
On encoder-decoder models that we may want to use as decoder-only, we want `EncoderDecoderCache` in two possible situations:
- [previously missing] the model is encoder-decoder
- the model is not encoder-decoder, but `encoder_hidden_states` is passed (which means we will compute the cross-attention, and thus we should cache it)
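The two situations above reduce to a single boolean check, mirroring the condition in the diff. The helper name `wants_encoder_decoder_cache` is hypothetical; this is only a sketch of the decision, not the model code.

```python
def wants_encoder_decoder_cache(is_encoder_decoder, encoder_hidden_states):
    """Return True when a cross-attention cache is needed alongside the self-attention cache."""
    return encoder_hidden_states is not None or is_encoder_decoder


# (is_encoder_decoder, encoder_hidden_states) -> decision
decisions = {
    "enc-dec, no encoder states yet": wants_encoder_decoder_cache(True, None),
    "decoder-only use, encoder states passed": wants_encoder_decoder_cache(False, [[0.1, 0.2]]),
    "decoder-only use, no encoder states": wants_encoder_decoder_cache(False, None),
}
```

The first case is the one the old condition missed: an encoder-decoder config with no `encoder_hidden_states` passed at that point.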
| choice_labels, | ||
| ): | ||
| config = copy.deepcopy(config) | ||
| config.is_decoder = True |
RoFormerForCausalLM won't use the cache if config.is_decoder != True, and this test tests cache usage 🙃
(A warning was being thrown)
| # build `cache_position` on the fly | ||
| seq_length = inputs["input_ids"].shape[1] | ||
| inputs = self.model._get_initial_cache_position(seq_length, self.model.device, inputs) | ||
| # prepare other inputs |
whisper has a custom generation structure that doesn't follow our code patterns -> one of the issues is that it has a stateful LogitsProcessor -> this state-related function uses prepare_inputs_for_generation out of the usual order -> changes related to this PR exposed that it was missing cache_position as an input here
(if we have bandwidth, we should revisit whisper generate to streamline its code)
> (if we have bandwidth, we should revisit whisper generate to streamline its code)

💯, revisiting audio modality generation will be super helpful
| if model_kwargs.get("past_key_values") is not None: | ||
| if model_kwargs.get("past_key_values", None) is not None: |
Technically not needed, `get` has `None` as its default (and I thought ruff was now enforcing this one but apparently not 🤔 too many changes to the ruff rules recently)
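For reference, a quick check of the point about `dict.get`: the explicit `None` default is redundant because `get` already returns `None` for a missing key. The dictionary contents below are made up for illustration.

```python
# Illustrative kwargs without a cache entry.
model_kwargs = {"attention_mask": [1, 1, 1]}

implicit_default = model_kwargs.get("past_key_values")        # default is None
explicit_default = model_kwargs.get("past_key_values", None)  # same result, more verbose

both_missing = implicit_default is None and explicit_default is None
```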
| # initialize `past_key_values` | ||
| if use_cache and past_key_values is None: | ||
| past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config)) | ||
|
|
EncoderDecoderCache initialized in a decoder-only module... 🥲🥲🥲 good catch!
zucchini-nlp left a comment
Thanks, left one question that seems to be either a typo or just my misunderstanding
| model_inputs["cache_position"] = cache_position | ||
|
|
||
| # 2. Generic cache-dependent input preparation | ||
| use_cache = kwargs.get("use_cache", False) or getattr(self.config, "use_cache", False) |
This will result in True even when the user kwargs set caching to False. IIUC we just want to give priority to the user-defined value, and not assume that caching is used whenever it is set to True in any of these places.

@zucchini-nlp that's a good catch, user-defined generate kwargs >> config values!
Will update accordingly
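A small sketch of the precedence issue discussed in this thread. `buggy_use_cache` mirrors the `or` expression from the diff, while `resolve_use_cache` is a hypothetical helper showing one way to let an explicit user kwarg win over the config; the fix that actually landed in the PR may look different.

```python
from types import SimpleNamespace


def buggy_use_cache(kwargs, config):
    """Mirrors the reviewed line: True if EITHER source says True."""
    return kwargs.get("use_cache", False) or getattr(config, "use_cache", False)


def resolve_use_cache(kwargs, config):
    """Hypothetical fix: an explicit user kwarg takes priority over the config value."""
    user_value = kwargs.get("use_cache")
    if user_value is not None:
        return bool(user_value)
    return bool(getattr(config, "use_cache", False))


# Stand-in config object with caching enabled by default.
config = SimpleNamespace(use_cache=True)

# The user explicitly disables caching.
buggy_result = buggy_use_cache({"use_cache": False}, config)
fixed_result = resolve_use_cache({"use_cache": False}, config)

# The user says nothing, so the config decides.
fallback_result = resolve_use_cache({}, config)
```

With the `or` expression, the user's `use_cache=False` is overridden by the config, which is the behavior flagged in the review.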
| use_cache = kwargs.get("use_cache", False) or getattr(self.config, "use_cache", False) | ||
| if past_key_values is not None: | ||
| model_inputs["past_key_values"] = past_key_values | ||
| if past_key_values is None or use_cache: |
Maybe I am missing something, do we apply cache slicing when past_key_values is None? It doesn't look intuitive at first sight, so let's add a comment explaining why

Stateful models like recurrent_gemma assume that slicing happens, but don't have a Cache object -- will add a comment :)

So they don't have a Cache object, and `if use_cache` does not catch those cases either?
|
[For maintainers] Suggested jobs to run (before merge) run-slow: bart, bert, bert_generation, bigbird_pegasus, blenderbot, blenderbot_small, camembert, data2vec, electra, ernie, fsmt, kosmos2, marian, mbart, mvp, pegasus |
What does this PR do?
See PR title.
Now that all traces of legacy caches were removed, we can trust the model to initialize its own cache! This means we no longer need to set cache_implementation="xxx" defaults in new models, assuming the model's forward pass defaults to the right cache class.

Also fixes related bugs, uncovered by not feeding a cache to the model.
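As a rough illustration of "the model initializes its own cache", here is a toy sketch with stand-in classes. `DynamicCacheStub` and `TinyDecoderStub` are hypothetical and do not reflect the real model code; the sketch only shows the pattern of a forward pass creating its default cache when none is fed in.

```python
class DynamicCacheStub:
    """Hypothetical stand-in for a cache that tracks how many tokens it has seen."""

    def __init__(self):
        self.seen_tokens = 0


class TinyDecoderStub:
    """Hypothetical decoder whose forward pass creates its own default cache."""

    def __init__(self, use_cache_default=True):
        self.use_cache_default = use_cache_default

    def forward(self, input_ids, past_key_values=None, use_cache=None):
        # An explicit argument wins; otherwise fall back to the model default.
        use_cache = self.use_cache_default if use_cache is None else use_cache
        # The caller no longer has to supply a cache: the model picks its own class.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCacheStub()
        if past_key_values is not None:
            past_key_values.seen_tokens += len(input_ids)
        return past_key_values


model = TinyDecoderStub()
cache = model.forward([1, 2, 3])                  # cache created inside forward
cache = model.forward([4], past_key_values=cache)  # same cache reused on the next step
no_cache = model.forward([1, 2, 3], use_cache=False)
```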