Cache: slight change in naming #32421

Merged
zucchini-nlp merged 15 commits into huggingface:main from zucchini-nlp:kv-cache
Oct 8, 2024
Conversation

@zucchini-nlp
Member

What does this PR do?

Following #32315 (comment), this PR adds two attributes to all cache classes, `is_sliding` and `is_static`. We can now rely on these attributes to crop/prepare the attention mask.
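The idea can be sketched roughly as follows. This is a minimal illustration, not the actual transformers source: the class names mirror the real cache classes, but the bodies are stubbed and the `needs_cropped_mask` helper is hypothetical.

```python
# Hypothetical sketch: class-level flags on cache classes let callers
# branch on cache behavior without isinstance() checks.

class Cache:
    is_sliding = False
    is_static = False

class DynamicCache(Cache):
    pass  # grows with the sequence: neither sliding nor static

class StaticCache(Cache):
    is_static = True  # pre-allocated to a fixed max shape

class SlidingWindowCache(Cache):
    is_sliding = True
    is_static = True  # fixed-size buffer over the last N tokens

def needs_cropped_mask(cache: Cache) -> bool:
    # Attention-mask preparation can branch on the flag, not the class.
    return cache.is_sliding
```

Any new cache class then only has to set the right flags for the mask-preparation logic to pick it up.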

@zucchini-nlp zucchini-nlp requested a review from gante August 5, 2024 05:40
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@gante left a comment


LGTM, thank you for working on this 💛

One design question about sliding caches

Contributor

@gante left a comment


LGTM, other than the deprecation cycle and the need to double-check attributes 🤗

# TODO: deprecate this function in favor of `cache_position`
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

def get_max_length(self) -> Optional[int]:
Contributor


We will need a deprecation cycle for `get_max_length` before we can delete it :P

(in the raised warning, make sure you mention that users should use `get_max_cache_shape` instead)
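The requested deprecation cycle might look like the sketch below. The method names follow this review thread; the warning category, message wording, and `max_cache_len` attribute are assumptions, and the real transformers implementation may differ.

```python
# Sketch of a deprecation shim: keep get_max_length() working for a
# release cycle, but warn users to switch to get_max_cache_shape().
import warnings
from typing import Optional

class StaticCache:
    def __init__(self, max_cache_len: int):
        self.max_cache_len = max_cache_len

    def get_max_cache_shape(self) -> Optional[int]:
        # New name: the capacity of this particular cache instance.
        return self.max_cache_len

    def get_max_length(self) -> Optional[int]:
        # Old name, kept temporarily; forwards to the new method.
        warnings.warn(
            "`get_max_length` is deprecated and will be removed; "
            "use `get_max_cache_shape` instead.",
            FutureWarning,
        )
        return self.get_max_cache_shape()
```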

"""

is_static = False
is_sliding = False
Contributor


(double-check whether this new attribute is correct, after rebasing. I think we've added new cache classes since your original commit. In the encoder-decoder cache, these attributes should be loaded from the decoder cache)
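The encoder-decoder point can be sketched as follows. This is illustrative only: the real `EncoderDecoderCache` has more machinery, and the constructor signature here is an assumption based on the comment.

```python
# Sketch: an encoder-decoder cache wraps two caches, so its type flags
# should be read from the decoder (self-attention) cache, not hardcoded.

class Cache:
    is_sliding = False
    is_static = False

class StaticCache(Cache):
    is_static = True

class EncoderDecoderCache:
    def __init__(self, self_attention_cache, cross_attention_cache):
        self.self_attention_cache = self_attention_cache
        self.cross_attention_cache = cross_attention_cache
        # Inherit the flags from the decoder's self-attention cache.
        self.is_sliding = self_attention_cache.is_sliding
        self.is_static = self_attention_cache.is_static
```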

Member Author


Yes, I ran the generation tests for Llama and all the compile/static-cache tests for Whisper. I rebased on main, and fix-copies propagated the changes to some new models.

@ArthurZucker ArthurZucker requested review from ArthurZucker and removed request for LysandreJik October 1, 2024 07:36
Collaborator

@ArthurZucker left a comment


Thanks for the in-depth changes! I'm not convinced we have to change the name of the function, but I am convinced we should have an abstraction over what we call. Checking for `is_static` IMO is not a good way forward. It might be a good internal helper, but in the modeling code we would prefer not having to check that, wdyt?

Also fine for this specific case TBH! Just not sure it's gonna scale

if (
self.config._attn_implementation == "sdpa"
and past_key_values is not None
and not past_key_values.is_static
Collaborator


I am not entirely sure we want to add these kinds of checks directly in the modeling code: `past_key_values.is_static` does not really look specific, and if you allow checking `is_sliding` then we are gonna have `is_encoder_decoder` and so on.

Collaborator


Aligned with my previous comment: we need a general `past_key_values.skip_sdpa_correction()`, for example.
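The suggested abstraction might look like this sketch. The `skip_sdpa_correction` name comes from the comment above but is purely illustrative, as is the `prepare_mask` helper; neither exists in transformers in this form.

```python
# Sketch: instead of the modeling code probing flags like is_static,
# each cache class answers one intent-revealing question itself.

class Cache:
    is_static = False

    def skip_sdpa_correction(self) -> bool:
        # Default: dynamic caches need the SDPA mask correction.
        return False

class StaticCache(Cache):
    is_static = True

    def skip_sdpa_correction(self) -> bool:
        # Static caches use a pre-built full mask, so no correction.
        return True

def prepare_mask(past_key_values, attn_implementation="sdpa"):
    # The modeling code asks the cache instead of checking its type.
    if attn_implementation == "sdpa" and past_key_values is not None:
        if past_key_values.skip_sdpa_correction():
            return "no-correction"
    return "corrected"
```

New cache variants then override one method rather than forcing every model file to learn a new attribute.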

Member Author


Hmm, maybe we can find a better check so that we don't use `isinstance()` on cache objects. We have more varieties of cache now, and each is mostly either a dynamic or a static type of some caching method, so it would be nice to not rely on `isinstance()` checks in general.

We can remove the modeling changes from this PR and use the new attributes only in specific cases (see the linked comment in the PR description). What I need from this PR right now is a `get_max_length` that doesn't return `None` anymore.

Collaborator


Yeah, I also want to avoid `.is_static`, `isinstance`, etc. Not sure how to abstract it best, though!

Okay, let's remove it for now then.

Member Author


Okay, done and rebased on main.


def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states."""
def get_max_cache_shape(self) -> Optional[int]:
Collaborator


Is there a specific reason for the name change?

Member Author


The only reason is to be more specific about what we mean by max_length, since I wanted to add a max_length to the sliding cache, which currently doesn't have one (it returns None). As the comment in the code explains, that was done with the idea that the cache object technically handles an infinite number of tokens, so it has no max length.

But we never want to know how many tokens the cache can technically handle; our checks are all about the max capacity of a particular cache instance.
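That distinction can be sketched as follows. Class names mirror the real cache classes, but the bodies are stubbed for illustration; the `sliding_window` constructor argument is an assumption.

```python
# Sketch: get_max_cache_shape reports the capacity of this cache
# instance, not how many tokens the cache can ever handle.
from typing import Optional

class DynamicCache:
    def get_max_cache_shape(self) -> Optional[int]:
        return None  # grows without bound: no fixed capacity

class SlidingWindowCache:
    def __init__(self, sliding_window: int):
        self.sliding_window = sliding_window

    def get_max_cache_shape(self) -> Optional[int]:
        # Capacity of the buffer, even though an unlimited number of
        # tokens can pass through it over the course of generation.
        return self.sliding_window
```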

Collaborator


Let's add this in the comment for example! I was not sure 🤗

sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
if past_key_values is not None and past_key_values.is_static:
Collaborator


same here

@zucchini-nlp changed the title from "Cache: add class attributes" to "Cache: slight change in naming" on Oct 7, 2024
Collaborator

@ArthurZucker left a comment


Thanks!

@zucchini-nlp zucchini-nlp merged commit bead0fa into huggingface:main Oct 8, 2024
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* squash

* codestyle

* Update src/transformers/cache_utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* propagate changes to all cache classes

* + whisper

* fix tests

* more fixes

* add deprecation warning

* fix copies

* address comments

* fix mistral also

* these didn't have "copied from"

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
