Cache: slight change in naming (#32421)
Conversation
gante
left a comment
LGTM, thank you for working on this 💛
One design question about sliding caches
gante
left a comment
LGTM, other than the deprecation cycle and the need to double-check attributes 🤗
    # TODO: deprecate this function in favor of `cache_position`
    raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    def get_max_length(self) -> Optional[int]:
We will need a deprecation cycle for `get_max_length` before we can delete it :P
(in the raised warning, make sure you mention that users should use `get_max_cache_shape` instead)
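The requested deprecation cycle could look roughly like the sketch below. The class and method bodies are simplified stand-ins for the real `transformers` cache classes, and the exact warning text and category are assumptions:

```python
import warnings
from typing import Optional


class Cache:
    """Minimal stand-in for the transformers base cache class (simplified)."""

    def get_max_cache_shape(self) -> Optional[int]:
        # Assumed default: a dynamic cache has no fixed capacity.
        return None

    def get_max_length(self) -> Optional[int]:
        # Deprecated alias kept for one deprecation cycle; it warns and
        # points users to `get_max_cache_shape`, as requested in the review.
        warnings.warn(
            "`get_max_length` is deprecated and will be removed in a future "
            "version. Use `get_max_cache_shape` instead.",
            FutureWarning,
        )
        return self.get_max_cache_shape()
```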
src/transformers/cache_utils.py
Outdated
    """

    is_static = False
    is_sliding = False
(double-check whether this new attribute is correct, after rebasing. I think we've added new cache classes since your original commit. In the encoder-decoder cache, these attributes should be loaded from the decoder cache)
Yes, I ran the generation tests for Llama and all the compile/static tests for Whisper. I rebased on main, and fix-copies propagated the changes to some new models.
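As a rough illustration of the encoder-decoder point above, the wrapper could forward the flags from the decoder cache it holds. The class names and the constructor below are hypothetical stand-ins, not the real `EncoderDecoderCache` API:

```python
class DecoderCacheStub:
    # Stand-in decoder cache carrying the per-class flags discussed above.
    is_static = True
    is_sliding = False


class EncoderDecoderCacheSketch:
    """Hypothetical sketch: an encoder-decoder cache reads the flags from
    the decoder (self-attention) cache it wraps, per the review comment."""

    def __init__(self, self_attention_cache):
        self.self_attention_cache = self_attention_cache

    @property
    def is_static(self) -> bool:
        return self.self_attention_cache.is_static

    @property
    def is_sliding(self) -> bool:
        return self.self_attention_cache.is_sliding
```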
ArthurZucker
left a comment
Thanks for the in-depth changes! Not convinced we have to change the name of the function, and I am also convinced we should have an abstraction over what we call. Checking for `is_static` is IMO not a good way forward. It might be a good internal helper, but in the modeling code we would prefer not having to check that, wdyt?
Also fine with this specific case TBH! Just not sure it's gonna scale.
    if (
        self.config._attn_implementation == "sdpa"
        and past_key_values is not None
        and not past_key_values.is_static
I am not entirely sure we want to add these kinds of checks directly in the modeling code:
`past_key_values.is_static` does not really look specific, but if you allow checking `is_sliding` then we are gonna have `is_encoder_decoder` and so on.
Aligned with my previous comment: we need a general `past_key_value.skip_sdpa_correction()`, for example.
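A minimal sketch of that idea, taking the hook name from the comment above; the classes and the rule that static caches skip the correction are illustrative assumptions, not the actual `transformers` implementation:

```python
class CacheSketch:
    """Illustrative base class: modeling code asks one question instead of
    inspecting cache-type flags like `is_static`."""

    def skip_sdpa_correction(self) -> bool:
        # Assumed default: the SDPA mask correction is still needed.
        return False


class StaticCacheSketch(CacheSketch):
    def skip_sdpa_correction(self) -> bool:
        # Assumption for illustration: a static cache has a fixed shape,
        # so the correction can be skipped.
        return True


def needs_sdpa_correction(attn_implementation: str, past_key_values) -> bool:
    # The modeling-side check from the diff, rewritten against the hook.
    return (
        attn_implementation == "sdpa"
        and past_key_values is not None
        and not past_key_values.skip_sdpa_correction()
    )
```

This keeps cache-specific knowledge inside the cache class, so adding a new cache type never requires touching the modeling files.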
Hmm, maybe we can design a better check so we don't use isinstance() on cache objects. We now have more varieties of cache, and each is mostly either a dynamic or a static variant of some caching method, so it would be nice not to rely on isinstance() checks in general.
We can remove the modeling changes in this PR and use the new attributes in specific cases (linked comment from the PR description). What I need from this PR currently is a `get_max_length` that doesn't return None anymore.
Yeah, I also want to avoid `.is_static`, isinstance, etc. Not sure how to abstract it best tho!
Okay, let's remove it then for now.
Okay, done and rebased on main.
    -    def get_max_length(self) -> Optional[int]:
    -        """Returns the maximum sequence length of the cached states."""
    +    def get_max_cache_shape(self) -> Optional[int]:
Is there a specific reason for the name change?
The only reason is to be more specific about what we mean by max_length, since I wanted to add a max_length to the sliding cache, which currently doesn't have one (returns None). As the comment in the code explains, that was done with the idea that the cache object can technically handle an infinite number of tokens, so there is no max length.
But we never want to know how many tokens the cache can technically handle; our checks are all about the max capacity of a particular cache instance.
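The distinction being described can be sketched with simplified stand-in classes (the real caches store tensors; `max_cache_len` here is just an assumed constructor argument):

```python
from typing import Optional


class DynamicCacheSketch:
    """Grows with the sequence: no fixed capacity, so no max cache shape."""

    def get_max_cache_shape(self) -> Optional[int]:
        # Unbounded: can technically hold any number of tokens.
        return None


class StaticCacheSketch:
    """Pre-allocated to a fixed number of positions."""

    def __init__(self, max_cache_len: int):
        self.max_cache_len = max_cache_len

    def get_max_cache_shape(self) -> Optional[int]:
        # The capacity of this particular instance, which is what callers
        # actually want to know when preparing masks.
        return self.max_cache_len
```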
Let's add this in the comment for example! I was not sure 🤗
    sequence_length = input_tensor.shape[1]
    -    if using_static_cache:
    -        target_length = past_key_values.get_max_length()
    +    if past_key_values is not None and past_key_values.is_static:
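The intent of the diff above can be sketched as a small helper; the stub class, the capacity value, and the helper name are assumptions for illustration, not code from this PR:

```python
class StaticStub:
    # Hypothetical static cache exposing the attribute from this PR and a
    # pre-allocated capacity of 256 positions.
    is_static = True

    def get_max_cache_shape(self) -> int:
        return 256


def target_length(past_key_values, past_seen_tokens: int, sequence_length: int) -> int:
    # Static caches need the mask to span the full pre-allocated length;
    # otherwise the tokens seen so far plus the new chunk are enough.
    if past_key_values is not None and getattr(past_key_values, "is_static", False):
        return past_key_values.get_max_cache_shape()
    return past_seen_tokens + sequence_length
```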
* squash
* codestyle
* Update src/transformers/cache_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* propagate changes to all cache classes
* + whisper
* fix tests
* more fixes
* add deprecation warning
* fix copies
* address comments
* fix mistral also
* these didn't have "copied from"

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
What does this PR do?
Following #32315 (comment), this PR adds two attributes to all cache classes, `is_sliding` and `is_static`. We can now rely on these attributes to crop/prepare the attention mask.
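A toy sketch of how such an attribute might be used to crop a mask, shown on plain lists instead of tensors; the function and the stub class are hypothetical, not code from this PR:

```python
class SlidingStub:
    # Hypothetical sliding-window cache exposing the new attribute.
    is_sliding = True


def crop_attention_mask(attention_mask, past_key_values, sliding_window):
    # Simplified: operates on a list of 0/1 flags rather than a tensor.
    # With a sliding cache, only the last `sliding_window` positions matter.
    if past_key_values is not None and getattr(past_key_values, "is_sliding", False):
        return attention_mask[-sliding_window:]
    return attention_mask
```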