Add torch.compile for Whisper #30949
zhenglongjiepheonix wants to merge 6 commits into huggingface:main
Conversation
if (
    is_cross_attention
    and past_key_value is not None
    and past_key_value[0].shape[2] == key_value_states.shape[1]
):
    # reuse k,v, cross_attentions
    key_states = past_key_value[0]
    value_states = past_key_value[1]
elif is_cross_attention:
Here, according to my understanding, key_states and value_states are computed once and for all from the encoder hidden states when we are doing cross-attention, so ideally we can cache them. However, they are a little different from the existing caches because we don't need to update the cache at every generation step; we only have to do it once, in the first step. So we could either create another Cache class, in which case we need to pass in two caches (one for self-attention and one for cross-attention); even if we create a cache class that wraps both, we would still need to modify the current get_cache logic, because it will definitely need more parameters to initialize the cache, and I don't know whether the use case is general enough to justify this new cache class. Or we can just initialize and update the cross-attention kv cache within every layer, like RecurrentGemma, but this would require manually resetting the cache between generations, because the current generation logic does not seem to consider the case where the model owns an inherent cache. I am personally for the latter solution; however, I think both will need some changes in the generation utils file, and I am not sure which way might make dynamo or cudagraphs unhappy. What do you think is best @ArthurZucker
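For illustration, here is a minimal sketch of the "fill once, then reuse" cross-attention cache idea described above. The class and method names are placeholders for discussion, not the API introduced by this PR.

```python
# Sketch only: a per-layer cross-attention cache that is written exactly once.
import torch
from typing import List, Optional


class FillOnceCrossAttnCache:
    """Caches cross-attention k/v per layer on the first decoding step only."""

    def __init__(self, num_layers: int):
        self.key_cache: List[Optional[torch.Tensor]] = [None] * num_layers
        self.value_cache: List[Optional[torch.Tensor]] = [None] * num_layers

    def is_filled(self, layer_idx: int) -> bool:
        return self.key_cache[layer_idx] is not None

    def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
        # Unlike a decoder self-attention cache, nothing is appended on later
        # steps: the encoder sequence never grows, so we write exactly once.
        if not self.is_filled(layer_idx):
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


# Schematic usage inside an attention layer: k/v are projected from the
# encoder hidden states only while the layer's slot is still empty.
cache = FillOnceCrossAttnCache(num_layers=2)
k = v = torch.randn(1, 8, 1500, 64)  # (batch, heads, encoder_len, head_dim)
k0, v0 = cache.update(k, v, layer_idx=0)          # first step: writes
k1, v1 = cache.update(k * 0, v * 0, layer_idx=0)  # later steps: ignored, reuses step-1 tensors
assert torch.equal(k0, k1)
```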
Your understanding is correct: we compute the k/v states for the cross-attention once in the first forward pass and then save them to the cache. We would indeed want to cache the k/v states to avoid re-computing them at every step.
What do you think about the proposed design in this PR? #28931 (comment) Similar to your solution, it uses two caches: one for the self-attention and one for the cross-attention. Note that the cache specifics are out-of-date given the recent changes to the cache API, but the high-level design remains valid!
Yeah, it would be great if two StaticCaches would work, but the issue is that in that case we cannot tell whether we are in the first generation step or not, because the shape of the cache is now always fixed to hold the maximum generation size, and using get_seq_length (to check whether we have processed any tokens yet during tracing) in the branch condition will indeed cause graph breaks.
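As a toy illustration of the problem (not Whisper code): dynamo specializes on the branch taken, so a condition whose value differs between the first and later generation steps forces guard failures and recompilation, or graph breaks when the value is derived from a tensor, as get_seq_length's result is.

```python
import torch


def step(cache_len: int, x: torch.Tensor) -> torch.Tensor:
    # A Python-level branch on a value that changes between generation steps.
    if cache_len == 0:      # "first decoding step": compute cross-attn k/v
        return x * 2.0
    return x + 1.0          # "later steps": reuse the cached k/v


compiled_step = torch.compile(step)
x = torch.randn(4)
compiled_step(0, x)  # traces and compiles the first branch
compiled_step(5, x)  # different branch taken -> new guard, second compilation
```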
My current solution uses two StaticCaches and adds another flag in the attention layer to mark whether we are doing the first generation step. This causes a recompile in the second step, just like in Llama and Mistral, but should work fine for the subsequent steps. It's still nasty, though, because it breaks the current stand-alone cache design; maybe I should indeed create a new OneShotCache class which hides the flag inside? cc @ArthurZucker
@gante @sanchit-gandhi
The current solution uses a new OneShotCache; for conditional generation, a tuple of two caches is expected.
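A hedged sketch of that tuple convention: element 0 holds the decoder self-attention cache and element 1 the one-shot cross-attention cache, and each attention layer picks the one it needs. The names below are illustrative stand-ins, not the API added by this PR.

```python
from typing import Any, Tuple


def select_cache(past_key_values: Tuple[Any, Any], is_cross_attention: bool) -> Any:
    """Pick the cache a given attention layer should read from / write to."""
    self_attn_cache, cross_attn_cache = past_key_values
    return cross_attn_cache if is_cross_attention else self_attn_cache


# Placeholders standing in for a StaticCache and a OneShotCache instance.
past_key_values = ("self-attn StaticCache", "cross-attn OneShotCache")
assert select_cache(past_key_values, is_cross_attention=True) == "cross-attn OneShotCache"
```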
if self.config.is_encoder_decoder:
    # manually set another cache for cross attention
    encoder_outputs = model_kwargs["encoder_outputs"][0]
    model_kwargs["past_key_values"] = (
        model_kwargs["past_key_values"],
        self._get_cache("one_shot", encoder_outputs.shape[0], encoder_outputs.shape[1], "_cross_attn_cache"),
    )
I wonder if there is a better way to do this: in the encoder-decoder scenario we need a tuple of two caches here according to the current design, but this feels hardcoded and easy to break.
if is_cross_attention and (
    (isinstance(past_key_value, DynamicCache) and self.layer_idx < len(past_key_value.key_cache))
    or (isinstance(past_key_value, OneShotStaticCache) and past_key_value.query_cache_filled_status(self.layer_idx))
):
    # cross-attention k/v were already computed and cached on the first decoding step, so reuse them
    key_states = past_key_value.key_cache[self.layer_idx]
    value_states = past_key_value.value_cache[self.layer_idx]
    need_update_cache = False
Actually we can always use OneShotStaticCache for the kv cache here for simplicity, because using DynamicCache won't give us any memory benefit for the cross-attention kv cache: the cross-attention k/v always have the fixed encoder sequence length, so there is nothing to grow dynamically.
This PR adds torch.compile support for the Whisper model, which uses an encoder-decoder architecture.
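As a hedged end-to-end sketch of the intended usage: the checkpoint name is just a common public one, and whether `cache_implementation="static"` is the exact switch for the encoder-decoder caches added here is an assumption, not something confirmed by this PR.

```python
import numpy as np
import torch
from transformers import AutoProcessor, WhisperForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").to(device)

# Compile the forward pass; with fixed-shape (static) caches the decoder can be
# captured without shape-dependent recompilations after warmup.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
model.generation_config.cache_implementation = "static"  # assumption, see lead-in above

audio = np.zeros(16000, dtype=np.float32)  # 1 second of silence as dummy input
inputs = processor(audio, sampling_rate=16000, return_tensors="pt").to(device)
generated_ids = model.generate(**inputs, max_new_tokens=32)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```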