🚨🚨 GPT2Model StaticCache support#35761
Conversation
278bcf7 to
dedb154
Compare
ArthurZucker
left a comment
There was a problem hiding this comment.
Thanks for the PR!
Not entirely sure it's worth adding, as GPT2 is a super small model, not super optimized anymore, and fairly old, so the amount of work is a bit high...
Let's make sure we test the cross-attention path with the KV cache, as I am not even sure it was supported before
| past_key_value: Optional[Cache] = None, | ||
| cache_position: Optional[torch.LongTensor] = None, |
There was a problem hiding this comment.
the big issue with this is that we are breaking backward compatibility for people who use `layer_past`. We need to deprecate `layer_past`!
There was a problem hiding this comment.
Added the @deprecate_kwarg decorator to this forward and 3 more (incl. in GPT2Model). Note that this is an inner model class (attention or inner block), so it does not affect the external model interface.
There was a problem hiding this comment.
Noted that it is an inner model class for attention or inner block, and not affecting the external model interface.
We agree in theory, but in practice we've broken many projects by changing the name of internal attributes and arguments. Better safe than sorry 🤗
There was a problem hiding this comment.
I assume that @deprecate_kwarg is sufficient here.
@ArthurZucker , Let me know if anything else is required here.
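For readers unfamiliar with the decorator, here is a minimal, self-contained sketch of what `@deprecate_kwarg` does for the `layer_past` → `past_key_value` rename. The real implementation lives in `transformers`; the body and warning text below are illustrative only:

```python
import functools
import warnings


def deprecate_kwarg(old_name, new_name, version, raise_if_both_names=False):
    """Sketch: forward a deprecated kwarg to its new name, with a warning."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                if raise_if_both_names and new_name in kwargs:
                    raise ValueError(
                        f"Pass only `{new_name}`, not both `{old_name}` and `{new_name}`."
                    )
                warnings.warn(
                    f"`{old_name}` is deprecated and will be removed in v{version}; "
                    f"use `{new_name}` instead.",
                    FutureWarning,
                )
                # rename in place so the wrapped function only sees the new name
                kwargs[new_name] = kwargs.pop(old_name)
            return func(*args, **kwargs)

        return wrapper

    return decorator


@deprecate_kwarg("layer_past", new_name="past_key_value", version="4.51.0", raise_if_both_names=True)
def forward(hidden_states, past_key_value=None):
    # stand-in for the attention forward; just echoes its inputs
    return hidden_states, past_key_value
```

Callers that still pass `layer_past=` keep working (with a `FutureWarning`), while passing both names raises immediately.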
| # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder | ||
| return_legacy_cache = False | ||
| if use_cache: | ||
| if past_key_values is not None: | ||
| if isinstance(past_key_values, Cache): | ||
| if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache): | ||
| past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) | ||
| elif not isinstance(past_key_values, Cache): | ||
| return_legacy_cache = True | ||
| logger.warning_once( | ||
| "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.49.0. " | ||
| "You should pass an instance of `Cache` instead, e.g. " | ||
| "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." | ||
| ) | ||
| if self.config.add_cross_attention: | ||
| past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) | ||
| else: | ||
| past_key_values = DynamicCache.from_legacy_cache(past_key_values) | ||
| elif past_key_values is None: | ||
| return_legacy_cache = True | ||
| logger.warning_once( | ||
| "Passing `use_cache=True` and `past_key_values=None` will is produce cache output in legacy format. " | ||
| "This behavior is deprecated and will be changed in Transformers v4.49.0. " | ||
| "To obtain output past_key_values as `Cache` instance you should pass an instance of `Cache` instead, e.g. " | ||
| "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." | ||
| ) | ||
| if self.config.add_cross_attention: | ||
| past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) | ||
| else: | ||
| past_key_values = DynamicCache() |
There was a problem hiding this comment.
From the look of it, we are adding quite complex code, which I am not a super fan of.
Let's go with this for now, but it would be nice to have a single warning just saying that one or the other is deprecated. As is, this is not very readable and has too many code paths, when you should have:
past_key_values is None -> create DynamicCache
past_key_values is not None -> convert to DynamicCache (not even sure that a cross-attention cache was even supported)
add_cross_attention -> create EncoderDecoderCache with past_key_values and a new DynamicCache
There was a problem hiding this comment.
I simplified this logic according to our outline
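The simplified flow outlined above can be sketched roughly as follows. Stub classes stand in for the real `transformers` cache types so the example stays self-contained; the names mirror the real API but the bodies are illustrative:

```python
class Cache:
    """Stand-in for transformers.Cache."""


class DynamicCache(Cache):
    """Stand-in for transformers.DynamicCache."""

    @classmethod
    def from_legacy_cache(cls, legacy):
        cache = cls()
        cache.legacy = legacy  # real code unpacks the tuple layer by layer
        return cache


class EncoderDecoderCache(Cache):
    """Stand-in holding separate self- and cross-attention caches."""

    def __init__(self, self_attention_cache, cross_attention_cache):
        self.self_attention_cache = self_attention_cache
        self.cross_attention_cache = cross_attention_cache


def init_cache(past_key_values, add_cross_attention):
    """Simplified init: None -> fresh DynamicCache, legacy tuple -> converted
    (flagging legacy output), then cross-attention wraps in EncoderDecoderCache."""
    return_legacy_cache = False
    if past_key_values is None:
        past_key_values = DynamicCache()
    elif not isinstance(past_key_values, Cache):
        return_legacy_cache = True  # caller used the deprecated tuple format
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    if add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
        past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
    return past_key_values, return_legacy_cache
```

The single `return_legacy_cache` flag replaces the nested warnings from the earlier version, and the cross-attention wrapping happens in exactly one place.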
5cd37a9 to
79821cd
Compare
I agree in principle, but my friends use it for Tortoise text-to-speech and intend to compile it (with modifications for static shapes) to accelerate it.
I made an effort to patch the cross-attention parts of the code as well. The relevant tests seem to pass.
|
@Rocketknight1 @ArthurZucker , could you please approve the remaining check workflows?
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
GPT2Model StaticCache support
|
@Rocketknight1, could you kindly give feedback on this PR please |
|
I'm not confident in reviewing this one, so cc @gante for final review! |
|
Hi @poedator . Nice patch! I am trying to enable it on the OPT model by following your code, if you don't mind!
Hi @jiqing-feng, good luck updating OPT! It should be pretty straightforward. Also look at how the new Cache class-related code is done for Llama: (1) that is where I borrowed from, and (2) it may be closer to OPT than GPT2.
|
sorry @poedator having a look! |
gante
left a comment
There was a problem hiding this comment.
If the slow tests for BERT are green, LGTM 🤗
(otherwise, feel free to ping me on slack so we can quickly sort them)
|
run slow: bert |
|
(the comment above should have triggered slow tests in our CI, testing this workflow :) ) EDIT: https://github.com/huggingface/transformers/actions/runs/13306553429 |
|
This comment contains run-slow, running the specified jobs: models: ['models/bert']
|
this is a friendly ping to @Rocketknight1 |
|
Hey! I'm not actually the core maintainer, but if @gante is happy then pinging @ArthurZucker @Cyrilvallez for core maintainer review |
|
I'll have a look this afternoon! |
Cyrilvallez
left a comment
There was a problem hiding this comment.
Ok, here is my review! I only reviewed gpt2, as the other model seems to be only a copy-paste of the same changes (minus some parts), but of course the relevant comments apply there as well! 🤗
Very nice work overall, congrats! Just a few nits, and an issue about the 4d masking but otherwise all good 👌
|
|
||
| return attn_output, attn_weights | ||
|
|
||
| @deprecate_kwarg("layer_past", new_name="past_key_value", version="4.49.0", raise_if_both_names=True) |
There was a problem hiding this comment.
version should be 4.51 here! (4.50 for this to be released, 4.51 so that people can adjust until next version)
|
|
||
| is_cross_attention = encoder_hidden_states is not None | ||
| is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention | ||
| self.is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention |
There was a problem hiding this comment.
Why make it an attribute of the class?
There was a problem hiding this comment.
this was an error, undoing
|
|
||
| self.mlp = GPT2MLP(inner_dim, config) | ||
|
|
||
| @deprecate_kwarg("layer_past", new_name="past_key_value", version="4.49.0", raise_if_both_names=True) |
|
|
||
| outputs = (attn_output, present) | ||
| outputs = (attn_output,) | ||
| if output_attentions: | ||
| outputs += (attn_weights,) |
There was a problem hiding this comment.
Here, as we are breaking it anyway, we can drop the output_attentions keyword explicitly and always return the weights, as we do for other models we refactored. This makes the code much easier to read (it is hard to understand when we do output[1], output[2] later on based on the keywords in GPT2Block).
Then let's use the keyword only in the Block (see Llama if this is unclear)
There was a problem hiding this comment.
we may not be able to drop output_attentions here entirely, since it drives the choice between sdpa and eager.
I'd much prefer to keep this PR focused on Cache class support, and leave additional refactoring for later.
There was a problem hiding this comment.
Ha yes sorry, I meant to grab the arg from kwargs, e.g. kwargs.get("output_attentions", False). But anyway, let's always return both, independently of the arg in the Attention: return attn_output, attn_weights. Then use the arg for the return only in Block
There was a problem hiding this comment.
Got you now. Moved all output_attentions logic to the Block
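The agreed split can be sketched like this: the attention layer always returns both tensors, and only the block consults `output_attentions` when assembling its outputs. These are toy classes, not the real modules; the tensor math is elided:

```python
class ToyAttention:
    """Always returns both attn_output and attn_weights, unconditionally."""

    def forward(self, hidden_states):
        attn_output = hidden_states  # placeholder for the real attention computation
        attn_weights = "weights"     # placeholder for the attention probabilities
        return attn_output, attn_weights


class ToyBlock:
    """Consumes both outputs explicitly; the flag matters only here."""

    def __init__(self):
        self.attn = ToyAttention()

    def forward(self, hidden_states, output_attentions=False):
        attn_output, attn_weights = self.attn.forward(hidden_states)
        outputs = (attn_output,)
        if output_attentions:  # the keyword is consulted only in the Block
            outputs += (attn_weights,)
        return outputs
```

With this shape there is never any `output[1]`/`output[2]` index guessing downstream of the attention layer.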
| attn_outputs = self.attn( | ||
| hidden_states, | ||
| layer_past=layer_past, | ||
| past_key_value=past_key_value, | ||
| cache_position=cache_position, | ||
| attention_mask=attention_mask, | ||
| head_mask=head_mask, | ||
| use_cache=use_cache, |
There was a problem hiding this comment.
Then here I'd like to see both outputs explicitly, e.g.: attn_output, attn_weights = self.attn(...), and remove the comment about output shapes 🤗
There was a problem hiding this comment.
same as above, output_attentions drives choice between sdpa and eager
There was a problem hiding this comment.
Sure but let's return both anyway!
| attention_mask = ( | ||
| attention_mask.view(batch_size, -1) if attention_mask is not None and attention_mask.ndim < 4 else None | ||
| ) |
There was a problem hiding this comment.
Here, if someone passes a 4d mask it's going to evaluate to None and later be overwritten!! If you want to allow passing 4d mask as well in this PR, this is wrong!
There was a problem hiding this comment.
Agreed. Rewrote this part.
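A sketch of mask preprocessing that keeps 4d masks intact instead of silently evaluating them to None. A tiny stand-in class mimics the `torch.Tensor` attributes the logic needs (`ndim`, `view`), so the example is dependency-free:

```python
class FakeTensor:
    """Stand-in for torch.Tensor, carrying only what this sketch needs."""

    def __init__(self, ndim):
        self.ndim = ndim
        self.reshaped = False

    def view(self, *shape):
        self.reshaped = True
        return self


def prepare_attention_mask(attention_mask, batch_size):
    """2d masks are flattened to (batch_size, -1) for downstream expansion;
    a 4d mask is assumed fully prepared and passed through untouched."""
    if attention_mask is None:
        return None
    if attention_mask.ndim < 4:
        return attention_mask.view(batch_size, -1)
    return attention_mask  # do NOT collapse a user-provided 4d mask to None
```

The bug in the original snippet was that the `else None` branch caught the 4d case too, so the mask was later rebuilt from scratch and the user's mask was ignored.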
| None, | ||
| causal_mask, |
There was a problem hiding this comment.
Here I'd rather you pass along the cache_positions even if they are not used, for clarity
There was a problem hiding this comment.
converted to kwargs in a890899 for clarity
There was a problem hiding this comment.
Nope, unfortunately keyword arguments will not work with checkpointing! So they should still be passed as simple positional args as before, but send (past_key_values, cache_positions) instead of (None, None) 🤗
| next_cache = past_key_values if use_cache else None | ||
| if return_legacy_cache: | ||
| next_cache = past_key_values.to_legacy_cache() |
There was a problem hiding this comment.
let's keep calling them past_key_values here; next_cache does not make sense anymore
| past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, | ||
| cache_position: Optional[torch.LongTensor] = None, |
There was a problem hiding this comment.
Did you make sure they will be passed along by generate's prepare_inputs_for_generation?
Also, you can safely remove all _reorder_cache() methods in all classes
There was a problem hiding this comment.
Tried to remove _reorder_cache; it was OK for GPT2, but it caused an error because of the copy-consistency link with the CLVP model.
cache_position seems to pass through prepare_inputs_for_generation; observed that in tests like test_generate_with_static_cache
There was a problem hiding this comment.
You can simply remove the `# Copied from` line for _reorder_cache in the CLVP model 😉
| if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache): | ||
| past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) |
There was a problem hiding this comment.
From what I can see, gpt2 used to never use cache for cross-attention, therefore never returning it. So this would need a special treatment as well when returning the kv (only return self_attention part, and in legacy format)
There was a problem hiding this comment.
Originally I borrowed code from modeling_whisper.py;
later I simplified it on Arthur's suggestion.
Glad to change it again, but please give more specific instructions on how you see it and which tests it should pass
There was a problem hiding this comment.
It should basically pass all the tests haha 😉 Here I just meant it needs special treatment as well when returning the cache. So basically the return_legacy_cache flag should be active, and when returning do:
if return_legacy_cache:
    past_key_values = past_key_values.self_attention_cache.to_legacy_cache() if add_cross_attentions else past_key_values.to_legacy_cache()
There was a problem hiding this comment.
Fixed it now; all GPT2 tests pass OK
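The suggested return path can be sketched as follows, with stub cache classes standing in for the real ones (the stubs only model `to_legacy_cache()` and the self/cross split; everything else is elided):

```python
class StubCache:
    """Stand-in for a DynamicCache that can export the legacy tuple format."""

    def __init__(self, name):
        self.name = name

    def to_legacy_cache(self):
        return ("legacy", self.name)


class StubEncoderDecoderCache:
    """Stand-in for EncoderDecoderCache: holds self- and cross-attention caches."""

    def __init__(self, self_attention_cache, cross_attention_cache):
        self.self_attention_cache = self_attention_cache
        self.cross_attention_cache = cross_attention_cache

    def to_legacy_cache(self):
        return ("legacy", "both")


def finalize_cache(past_key_values, return_legacy_cache, add_cross_attentions):
    """When a legacy tuple came in, convert back; with cross-attention enabled,
    export only the self-attention part, since old GPT2 never cached cross-attn."""
    if return_legacy_cache:
        if add_cross_attentions:
            return past_key_values.self_attention_cache.to_legacy_cache()
        return past_key_values.to_legacy_cache()
    return past_key_values
```

This preserves the old contract: legacy-tuple callers with cross-attention still receive only the self-attention KV in tuple form.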
|
Also, did you run slow tests by any chance? 🤗 |
I did run the SLOW tests, and all of them passed before I asked for the review. There were a couple of exceptions in some other models' tests which did not appear in CI upon submission. I will run the SLOW tests again and attach the output link here before calling for another review. I addressed most of your comments. Please see if you could resolve some or all of them, or give me more guidance.
Hi, @Cyrilvallez! Hope that the merge is close. You were correct about inconsistencies with the |
Cyrilvallez
left a comment
There was a problem hiding this comment.
Hey @poedator! Thanks a lot for updating! Here are some final thoughts! Also, I still see some logic differences between both models, most notably the mask handling! Let's make sure both models are functionally equivalent 🤗
I fixed what you directly suggested and also made the cache init sections between the 2 models more similar. As for mask handling and the rest of the logic differences, I am less familiar with this exotic |
Cyrilvallez
left a comment
There was a problem hiding this comment.
Oh sorry, I thought you wanted to add the same changes and functionalities to decision_transformer as well. If it's not the case, we can just remove the flag saying that it supports static cache (as we would need correct masking for it), then it should be good to go! 🤗
| is_parallelizable = True | ||
| supports_gradient_checkpointing = True | ||
| _supports_cache_class = True | ||
| _supports_static_cache = True |
There was a problem hiding this comment.
Alright, then it should not support static cache!
There was a problem hiding this comment.
changed it this way:
_supports_cache_class = True
_supports_static_cache = False
Yes, I was focused on adding support for StaticCache in GPT2, as it is a rather popular model. I will gladly leave the last bit of the Decision Transformer challenge to whoever knows that model better. Please see if this is OK now.
Cyrilvallez
left a comment
There was a problem hiding this comment.
Alright, LGTM!! 🤗🚀 Thanks a lot for the addition, and once again super sorry for the time it took! Great work 🤗
All tests (slow as well) are passing so it's good to go!
@Cyrilvallez @ArthurZucker @gante |
GPT2Model StaticCache support
* initial GPT2 changes
* causal_mask support
* return_legacy_cache
* cleanup
* fix1
* outputs shape fixes
* gpt2 return fix
* pkv, attn fixes
* fix dual_head
* is_causal arg fix
* decision transformer updated
* style fix
* batch_size from inputs_embeds
* DecisionTransformerModel fixes
* cross-attn support + cache warning
* x-attn @Decision
* EDCache proper init
* simplified logic in `if use_cache:` for GPT2Model
* @deprecate_kwarg for DecisionTr attn fwd
* @deprecate_kwarg in gpt2
* deprecation version updated to 4.51
* kwargs in gradient_checkpointing_fn
* rename next_cache to past_key_values
* attention_mask prep
* +cache_position in GPT2DoubleHeadsModel
* undo kwargs in gradient checkpointing
* moved up `if self.gradient_checkpointing`
* consistency in decision_transformer
* pastkv, cache_pos in grad_checkpt args
* rm _reorder_cache
* output_attentions streamlined
* decision_transformer consistency
* return_legacy_cache improved
* ClvpForCausalLM used for legacy cache test now
* is_causal fixed
* attn_output cleanup
* consistency @ decision_transformer
* Updated deprecation notice version to 4.52
* upd deprecation
* consistent legacy cache code in decision transformers
* next_cache -> past_kv in decision_tr
* cache support flags in decision_transf
* rm legacy cache warning
* consistency in cache init for decision transf
* no Static Cache for Decision Transformer

---------

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
I copied _update_causal_mask() and _prepare_4d_causal_attention_mask_with_cache_position() from LlamaModel. Some tests are still failing;
both may be linked to attention implementations. So far I was unable to figure out the reasons for the failures. I'd appreciate advice or help from the maintainers.
cc: @gante