🚨🚨 GPT2Model StaticCache support#35761

Merged
Cyrilvallez merged 46 commits into huggingface:main from poedator:gpt_static on Apr 24, 2025

Conversation

@poedator
Contributor

@poedator poedator commented Jan 18, 2025

I copied _update_causal_mask() and _prepare_4d_causal_attention_mask_with_cache_position() from LlamaModel

some tests are still failing:

  1. tests/models/gpt2/test_modeling_gpt2.py::GPT2ModelTest::test_custom_4d_attention_mask
  2. test_modeling_vision_encoder_decoder.py::VIT2GPT2Test::test_save_and_load_from_pretrained

Both may be linked to attention implementations. So far I was unable to figure out the reasons for the failures. I'd appreciate advice or help from the maintainers.

cc: @gante

@poedator poedator mentioned this pull request Jan 18, 2025
@poedator poedator force-pushed the gpt_static branch 2 times, most recently from 278bcf7 to dedb154 on January 18, 2025 at 13:30
@poedator poedator marked this pull request as ready for review January 18, 2025 17:43
Collaborator

@ArthurZucker ArthurZucker left a comment

Thanks for the PR!
Not entirely sure it's worth adding, as GPT2 is a super small model, not super optimized anymore, and fairly old, so the amount of work is a bit high...

Let's make sure we test the cross-attention path with the KV cache, as I am not even sure it was supported before.

Comment on lines +265 to +268
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
Collaborator

The big issue with this is that we are breaking backward compatibility for people who use layer_past. We need to deprecate layer_past!

Contributor Author

@poedator poedator Jan 25, 2025

Added the @deprecate_kwarg decorator to this forward and 3 more (incl. in GPT2Model). Noted that it is an inner model class for attention or inner block, and not affecting the external model interface.
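For readers unfamiliar with the decorator, here is a minimal sketch of what a kwarg-deprecation decorator along these lines does. This is an illustrative reimplementation, not the actual transformers utility; the signature and behavior are assumptions based on the usage in this PR.

```python
import functools
import warnings

def deprecate_kwarg(old_name, new_name, version, raise_if_both_names=False):
    """Illustrative sketch: rename a deprecated keyword argument on the fly."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                if new_name in kwargs and raise_if_both_names:
                    raise ValueError(f"Cannot pass both `{old_name}` and `{new_name}`")
                warnings.warn(
                    f"`{old_name}` is deprecated and will be removed in v{version}; "
                    f"use `{new_name}` instead.",
                    FutureWarning,
                )
                # forward the value under the new name
                kwargs[new_name] = kwargs.pop(old_name)
            return func(*args, **kwargs)
        return wrapper
    return decorator

@deprecate_kwarg("layer_past", new_name="past_key_value", version="4.51.0", raise_if_both_names=True)
def forward(hidden_states, past_key_value=None):
    return past_key_value
```

Callers still using the old `layer_past` keyword keep working but see a FutureWarning, which is the backward-compatibility point Arthur raised above.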

Contributor

Noted that it is an inner model class for attention or inner block, and not affecting the external model interface.

We agree in theory, but in practice we've broken many projects by changing the name of internal attributes and arguments. Better safe than sorry 🤗

Contributor Author

I assume that @deprecate_kwarg is sufficient here.
@ArthurZucker, let me know if anything else is required here.

Comment on lines +830 to +859
# based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
return_legacy_cache = False
if use_cache:
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
        elif not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.49.0. "
                "You should pass an instance of `Cache` instead, e.g. "
                "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
            )
            if self.config.add_cross_attention:
                past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    elif past_key_values is None:
        return_legacy_cache = True
        logger.warning_once(
            "Passing `use_cache=True` and `past_key_values=None` will produce cache output in the legacy format. "
            "This behavior is deprecated and will be changed in Transformers v4.49.0. "
            "To obtain the output past_key_values as a `Cache` instance, you should pass an instance of `Cache` instead, e.g. "
            "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
        )
        if self.config.add_cross_attention:
            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
        else:
            past_key_values = DynamicCache()
Collaborator

From the look of it, we are adding quite complex code, which I am not a super fan of.
Let's go with this for now, but it would be nice to have a single warning to just say that one or the other is deprecated. This, as is, is not super readable and you have too many code paths, when you should have:
past_key_values is None -> create DynamicCache
past_key_values is not None -> convert to DynamicCache (not even sure that the cross-attention cache was even supported)
add_cross_attention -> create EncoderDecoderCache with past_key_values and a new DynamicCache

Contributor Author

I simplified this logic according to our outline
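The simplified branching outlined above can be sketched roughly as follows. The cache classes here are bare stubs standing in for the real transformers ones, and `init_cache` is a hypothetical helper name; this is a sketch of the control flow, not the merged implementation.

```python
# Bare stubs standing in for the real transformers cache classes,
# just enough to make the control flow runnable in isolation.
class Cache:
    pass

class DynamicCache(Cache):
    @classmethod
    def from_legacy_cache(cls, legacy_tuples):
        cache = cls()
        cache.legacy = legacy_tuples  # pretend conversion
        return cache

class EncoderDecoderCache(Cache):
    def __init__(self, self_attention_cache, cross_attention_cache):
        self.self_attention_cache = self_attention_cache
        self.cross_attention_cache = cross_attention_cache

def init_cache(past_key_values, add_cross_attention):
    """Hypothetical helper: the three-branch outline from the review."""
    return_legacy_cache = False
    if past_key_values is None:
        # no cache passed in -> start a fresh DynamicCache
        past_key_values = DynamicCache()
    elif not isinstance(past_key_values, Cache):
        # legacy tuple format -> convert once, remember to convert back on return
        return_legacy_cache = True
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    if add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
        # cross-attention wraps whatever we have plus a fresh cache
        past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
    return past_key_values, return_legacy_cache
```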

@poedator poedator force-pushed the gpt_static branch 2 times, most recently from 5cd37a9 to 79821cd on January 25, 2025 at 16:34
@poedator
Contributor Author

poedator commented Jan 25, 2025

Not entirely sure it's worth adding ...

I agree in principle, but my friends use it for Tortoise text-to-speech and intend to compile it (with modifications for static shapes) to accelerate.

Let's make sure we test cross attention path with kv cache as I am not even sure it was supported before

I made an effort to patch the cross-attention parts of the code as well. The relevant tests seem to pass.

@poedator
Contributor Author

poedator commented Jan 28, 2025

@Rocketknight1 @ArthurZucker, could you please approve the remaining check workflows?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@poedator poedator changed the title [WiP] GPT2Model StaticCache support GPT2Model StaticCache support Jan 30, 2025
@poedator
Contributor Author

poedator commented Feb 3, 2025

@Rocketknight1, could you kindly give feedback on this PR, please?

@Rocketknight1
Member

I'm not confident in reviewing this one, so cc @gante for final review!

@jiqing-feng
Contributor

jiqing-feng commented Feb 8, 2025

Hi @poedator. Nice patch! I am trying to enable it on the OPT model by following your code, if you don't mind!

@poedator
Contributor Author

poedator commented Feb 8, 2025

Hi @poedator. Nice patch! I am trying to enable it on the OPT model by following your code, if you don't mind!

Hi @jiqing-feng, good luck in updating OPT! It should be pretty straightforward. Also look at how the new Cache class-related code is done for Llama: (1) that is where I borrowed from; (2) it may be closer to OPT than GPT2.

@ArthurZucker
Collaborator

Sorry @poedator, having a look!

Contributor

@gante gante left a comment

If the slow tests for BERT are green, LGTM 🤗

(otherwise, feel free to ping me on slack so we can quickly sort them)

@gante
Contributor

gante commented Feb 13, 2025

run slow: bert

@gante
Contributor

gante commented Feb 13, 2025

(the comment above should have triggered slow tests in our CI, testing this workflow :) )

EDIT: https://github.com/huggingface/transformers/actions/runs/13306553429

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/bert']
quantizations: [] ...

@poedator
Contributor Author

This is a friendly ping to @Rocketknight1.
Now that @gante approved, could you finalise your review, please?

@Rocketknight1
Member

Hey! I'm not actually the core maintainer, but if @gante is happy then pinging @ArthurZucker @Cyrilvallez for core maintainer review

@Cyrilvallez
Member

I'll have a look this afternoon!

Member

@Cyrilvallez Cyrilvallez left a comment

Ok, here is my review! I only reviewed gpt2, as the other model seems to be only a copy-paste of the same changes (minus some parts), but of course the relevant comments apply there as well! 🤗

Very nice work overall, congrats! Just a few nits, and an issue about the 4d masking but otherwise all good 👌


return attn_output, attn_weights

@deprecate_kwarg("layer_past", new_name="past_key_value", version="4.49.0", raise_if_both_names=True)
Member

version should be 4.51 here! (4.50 for this to be released, 4.51 so that people can adjust until next version)

Contributor Author

addressed in 1372505


is_cross_attention = encoder_hidden_states is not None
is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
self.is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
Member

Why make it an attribute of the class?

Contributor Author

@poedator poedator Feb 24, 2025

this was an error, undoing


self.mlp = GPT2MLP(inner_dim, config)

@deprecate_kwarg("layer_past", new_name="past_key_value", version="4.49.0", raise_if_both_names=True)
Member

same here!

Contributor Author

done

Comment on lines 351 to 357

outputs = (attn_output, present)
outputs = (attn_output,)
if output_attentions:
    outputs += (attn_weights,)
Member

Here, as we are breaking it anyway, we can drop output_attentions keyword explicitly, and always return it as we do for other models we refactored. This makes it much easier to read the code (hard to understand when we do output[1], output[2] later on based on the keywords in GPT2Block).
Then let's use the keyword only in the Block (see Llama if this is unclear)

Contributor Author

we may not be able to drop output_attentions here entirely, since it drives choice between sdpa and eager here:

if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):

I'd much prefer to keep this PR focused on Cache class support, and leave additional refactoring for later.
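The dispatch referred to above can be sketched as follows. `select_attention_impl` is a hypothetical helper name (the real check lives inline in the model code): SDPA cannot return attention weights, so requesting them, or passing a head mask, forces the eager path.

```python
def select_attention_impl(configured_impl, output_attentions=False, head_mask=None):
    """Illustrative helper: SDPA cannot return attention weights, so asking
    for them (or passing a head mask) falls back to the eager implementation."""
    if configured_impl == "sdpa" and (output_attentions or head_mask is not None):
        return "eager"
    return configured_impl
```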

Member

Ha yes sorry, I meant to grab the arg from kwargs, e.g. kwargs.get("output_attentions", False). But anyway, let's always return both, independently of the arg in the Attention: return attn_output, attn_weights. Then use the arg for the return only in Block

Contributor Author

Got you now. Moved all output_attentions logic to the Block
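The resulting pattern can be sketched like this, with placeholder math standing in for real attention (`attention_forward` and `block_forward` are illustrative names, not the PR's actual functions):

```python
def attention_forward(hidden_states):
    """Sketch of the refactored attention layer: it always returns both
    the output and the weights, independently of any flag."""
    attn_output = [h * 2 for h in hidden_states]   # placeholder math
    attn_weights = [1.0 for _ in hidden_states]    # placeholder weights
    return attn_output, attn_weights

def block_forward(hidden_states, output_attentions=False):
    """Only the block decides whether the weights reach the caller."""
    attn_output, attn_weights = attention_forward(hidden_states)
    outputs = (attn_output,)
    if output_attentions:
        outputs += (attn_weights,)
    return outputs
```

This keeps the `output[1]`/`output[2]` indexing logic in a single place, which is the readability point raised in the review.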

Comment on lines 410 to 413
attn_outputs = self.attn(
    hidden_states,
    layer_past=layer_past,
    past_key_value=past_key_value,
    cache_position=cache_position,
    attention_mask=attention_mask,
    head_mask=head_mask,
    use_cache=use_cache,
Member

Then here I'd like to see both outputs explicitly, e.g.: attn_output, attn_weights = self.attn(...), and remove the comment about output shapes 🤗

Contributor Author

same as above, output_attentions drives choice between sdpa and eager

Member

Sure but let's return both anyway!

Contributor Author

done!

Comment on lines +873 to +875
attention_mask = (
    attention_mask.view(batch_size, -1) if attention_mask is not None and attention_mask.ndim < 4 else None
)
Member

Here, if someone passes a 4d mask it's going to evaluate to None and later be overwritten!! If you want to allow passing 4d mask as well in this PR, this is wrong!

Contributor Author

Agreed. Rewrote this part.
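One possible shape for the fix can be sketched with a tiny tensor stand-in (`FakeMask` and `prepare_mask` are hypothetical names; the real code operates on torch tensors): flatten only sub-4d masks, and pass a custom 4d mask through instead of silently discarding it.

```python
class FakeMask:
    """Tiny stand-in for a torch tensor: just .ndim and .view, enough to
    exercise the branching below (illustrative, not a real tensor)."""
    def __init__(self, ndim):
        self.ndim = ndim
        self.was_flattened = False

    def view(self, *shape):
        self.was_flattened = True
        return self

def prepare_mask(attention_mask, batch_size):
    # The original expression sent any 4d mask into the `else None` branch,
    # so it was later overwritten by the causal-mask builder. The sketch of
    # the fix: flatten only lower-dimensional masks, keep 4d masks as-is.
    if attention_mask is None:
        return None
    if attention_mask.ndim < 4:
        return attention_mask.view(batch_size, -1)
    return attention_mask  # custom 4d mask: pass through unchanged
```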

Comment on lines +938 to +1037
None,
causal_mask,
Member

Here I'd rather you pass along the cache_positions even if it is not used, for clarity.

Contributor Author

converted to kwargs in a890899 for clarity

Member

Nope, unfortunately keyword arguments will not work with checkpointing! So they should still be passed as simple args, as before, but just send (past_key_values, cache_positions) instead of (None, None) 🤗

Contributor Author

fixed, as args
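A stdlib sketch of the constraint described above (`fake_checkpoint` is a toy stand-in, assuming a wrapper that, like reentrant gradient checkpointing, only forwards positional `*args`):

```python
def fake_checkpoint(fn, *args):
    """Toy stand-in for a gradient-checkpointing wrapper: it only forwards
    positional *args, so anything a caller tried to pass to `fn` by keyword
    would never reach it. (A real implementation would also stash the inputs
    so the forward pass can be replayed during backward.)"""
    return fn(*args)

def layer_forward(hidden_states, past_key_values=None, cache_position=None):
    return hidden_states, past_key_values, cache_position

# Hence the fix: pass the cache and positions as plain positional arguments
# so the wrapper can forward them.
out = fake_checkpoint(layer_forward, "hidden", "cache", "positions")
```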

Comment on lines +979 to +981
next_cache = past_key_values if use_cache else None
if return_legacy_cache:
    next_cache = past_key_values.to_legacy_cache()
Member

Let's keep calling them past_key_values here; next_cache does not make sense anymore.

Contributor Author

done in 70b61c7

Comment on lines 1187 to +1290
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
cache_position: Optional[torch.LongTensor] = None,
Member

Did you make sure they will be passed along by generate's prepare_inputs_for_generation?

Also, you can safely remove all _reorder_cache() methods in all classes.

Contributor Author

@poedator poedator Feb 24, 2025

Tried to remove _reorder_cache; it was OK for GPT2, but it caused an error because of the consistency link with the CLVP model.

cache_position seems to pass through prepare_inputs_for_generation; observed that in tests like test_generate_with_static_cache.

Member

You can simply remove the Copied from line for reorder_cache in clvp model 😉

Contributor Author

done

Comment on lines +854 to +961
if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
Member

From what I can see, gpt2 used to never use cache for cross-attention, therefore never returning it. So this would need a special treatment as well when returning the kv (only return self_attention part, and in legacy format)

Contributor Author

Originally I borrowed the code from modeling_whisper.py; later I simplified it on Arthur's suggestion.
Glad to change it again, but please give more specific instructions on how you see it and which tests it should pass.

Member

It should basically pass all the tests haha 😉 Here I just meant it needs special treatment as well when returning the cache. So basically the return_legacy_cache flag should be active, and when returning do:

if return_legacy_cache:
    past_key_values = (
        past_key_values.self_attention_cache.to_legacy_cache()
        if add_cross_attentions
        else past_key_values.to_legacy_cache()
    )

Contributor Author

fixed it now, all GPT2 tests pass OK

@Cyrilvallez
Member

Also, did you run slow tests by any chance? 🤗

@poedator
Contributor Author

@Cyrilvallez

Also, did you run slow tests by any chance? 🤗

I did run the SLOW tests, and all of them passed before I asked for the review. There were a couple of exceptions in some other models' tests, which did not appear in CI upon submission.

Will run the SLOW tests again and attach the output link here before calling for another review.

I addressed most of your comments. Please see if you could resolve some or all of them, or give me more guidance.

@poedator
Contributor Author

Hey @poedator! Here is a long due review. [...]

Hi, @Cyrilvallez! Hope that the merge is close. You were correct about inconsistencies with the decision_transformer. I fixed those, updated deprecation, rebased to fresh main. Let's see if this is good enough

Member

@Cyrilvallez Cyrilvallez left a comment

Hey @poedator! Thanks a lot for updating! Here are some final thoughts! Also, I still see some logic differences between both models, most notably the mask handling! Let's make sure both models are functionally equivalent 🤗

@poedator
Contributor Author

Hey @poedator! Thanks a lot for updating! Here are some final thoughts! Also, I still see some logic differences between both models, most notably the mask handling! Let's make sure both models are functionally equivalent 🤗

I fixed what you directly suggested and also made the cache init sections between the 2 models more similar.

As for mask handling and the rest of the logic differences -- I am less familiar with this exotic decision transformer model, so I'd be glad if you could point me to the specific failing tests so I could debug the problem. Or feel free to edit this PR directly and take 50% of the glory.
The CI tests seem to fail for unrelated reasons.

Member

@Cyrilvallez Cyrilvallez left a comment

Oh sorry, I thought you wanted to add the same changes and functionalities to decision_transformer as well. If it's not the case, we can just remove the flag saying that it supports static cache (as we would need correct masking for it), then it should be good to go! 🤗

is_parallelizable = True
supports_gradient_checkpointing = True
_supports_cache_class = True
_supports_static_cache = True
Member

Alright, then it should not support static cache!

Contributor Author

@poedator poedator Apr 23, 2025

changed it this way:
_supports_cache_class = True
_supports_static_cache = False

@poedator
Contributor Author

Oh sorry, I thought you wanted to add the same changes and functionalities to decision_transformer as well. If it's not the case, we can just remove the flag saying that it supports static cache (as we would need correct masking for it), then it should be good to go! 🤗

Yes, I was focused on adding support of StaticCache to GPT2, as it is a rather popular model. I will gladly leave the last bit of the Decision Transformer challenge to whoever knows that model better. Please see if this is OK now.

Member

@Cyrilvallez Cyrilvallez left a comment

Alright, LGTM!! 🤗🚀 Thanks a lot for the addition, and once again super sorry for the time it took! Great work 🤗
All tests (slow as well) are passing so it's good to go!

@Cyrilvallez Cyrilvallez merged commit 7c62e69 into huggingface:main Apr 24, 2025
18 checks passed
@poedator
Contributor Author

Alright, LGTM!! 🤗🚀 ... good to go!

@Cyrilvallez @ArthurZucker @gante
Thank you for the valuable inputs and support along the way. Hopefully there are still quite a few people who use GPT2 and who may find this PR useful.

@ArthurZucker ArthurZucker changed the title GPT2Model StaticCache support 🚨🚨 GPT2Model StaticCache support Apr 28, 2025
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
* initial GPT2 changes

* causal_mask support

* return_legacy_cache

* cleanup

* fix1

* outputs shape fixes

* gpt2 return fix

* pkv, attn fixes

* fix dual_head

* is_causal arg fix

* decision transformer updated

* style fix

* batch_size from inputs_embeds

* DecisionTransformerModel fixes

* cross-attn support + cache warning

* x-attn @Decision

* EDCache proper init

* simplified logic in `if use_cache:` for GPT2Model

* @deprecate_kwarg for DecisionTr attn fwd

* @deprecate_kwarg in gpt2

* deprecation version updated to 4.51

* kwargs in gradient_checkpointing_fn

* rename next_cache to past_key_values

* attention_mask prep

* +cache_position in GPT2DoubleHeadsModel

* undo kwargs in gradient checkpointing

* moved up `if self.gradient_checkpointing`

* consistency in decision_transformer

* pastkv, cache_pos in grad_checkpt args

* rm _reorder_cache

* output_attentions streamlined

* decision_transformer consistency

* return_legacy_cache improved

* ClvpForCausalLM used for legacy cache test now

* is_causal fixed

* attn_output cleanup

* consistency @ decision_transformer

* Updated deprecation notice version to 4.52

* upd deprecation

* consistent legacy cache code in decision transformers\

* next_cache -> past_kv in decision_tr

* cache support flags in decision_transf

* rm legacy cache warning

* consistency in cache init for decision transf

* no Static Cache for Decision Transformer

---------

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>