Remove ambiguous padding_mask and instead use a 2D->4D Attn Mask Mapper #26792
patrickvonplaten merged 28 commits into main
The documentation is not available anymore as the PR was closed or merged.
fxmarty left a comment
Thank you! I am happy with it, just wondering whether changing the attention_mask input from being 4D to 2D in LlamaDecoderLayer & LlamaAttention is considered a breaking change or not.
To me they are internal classes and thus changing the format of the attention_mask is ok. @LysandreJik @younesbelkada @ArthurZucker what do you think?
younesbelkada left a comment
Thanks a lot, this looks much cleaner indeed!
Regarding your comment: as those are internal classes I think it is fine, as long as the changes are explicitly detailed in the next release notes.
ArthurZucker left a comment
Makes a lot of sense indeed, we have had the past logic for quite a while and an update is welcome!
    from .configuration_llama import LlamaConfig

    ...

    class AttentionMask2DTo4D:
Do we plan to move this to the modeling utils or is this gonna be here for all models?
Either `# Copied from` or we move it to a utils file. Both would work for me.
    class LlamaDecoderLayer(nn.Module):
    -    def __init__(self, config: LlamaConfig):
    +    def __init__(self, config: LlamaConfig, mask_converter=None):
If we support passing the mask converter here but not in the parent classes it's kind of pointless no?
Wondering which one would be the best:
- Pass the mask converter class to all classes
- Only have it in the attention layer, controlled with a `MASK_CONVERTER = {"default": AttentionMask2DTo4D}`, and just in the attention layer do `self.mask_converter = MASK_CONVERTER[config.mask_converter]`, with the attribute added to the common config?

(naming can be improved for sure!)
Sorry I don't fully understand this
To begin with I would not make the "attention cache" a class that the user plays around with, but instead use it as an internal convenience class that doesn't sacrifice speed but helps readability.
Since the same instance of the class needs to be shared among the different layers, we need to instantiate it at a ...Model level and then let it trickle down to the respective attention classes.
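A minimal sketch of that instantiation pattern (all names here are hypothetical stand-ins, not the PR's actual classes):

```python
import torch.nn as nn


class MaskConverter:
    """Hypothetical stand-in for the 2D->4D converter shared across layers."""

    def __init__(self):
        self.cache = {}


class DecoderLayer(nn.Module):
    def __init__(self, mask_converter):
        super().__init__()
        # every layer holds a reference to the *same* converter instance
        self.mask_converter = mask_converter


class Model(nn.Module):
    def __init__(self, num_layers=4):
        super().__init__()
        # instantiated once at the Model level, then trickles down to the layers
        converter = MaskConverter()
        self.layers = nn.ModuleList(DecoderLayer(converter) for _ in range(num_layers))


model = Model()
# all layers share one converter, so a cached mask is computed only once
assert all(l.mask_converter is model.layers[0].mask_converter for l in model.layers)
```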
I see, sorry, I realised that if you want to share the same cached mask you have to pass it; ignore my comment.
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
        dropout_p=dropout,
        softmax_scale=softmax_scale,
    -    causal=True,
    +    causal=self.attention_mask_cache.is_causal,
This allows us to easily copy-paste this function to non-causal attention layers (BERT)
This PR should help make the following PRs nicer / cleaner:
        padding_mask: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
@patrickvonplaten Make sure to change the docstring line 728 (of this branch):
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
ArthurZucker left a comment
LGTM, we might have to keep some logic to pop the padding mask for one release for BC. Let's do a deprecation cycle, no?
        self.embed_tokens = value

    -    # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask
    +    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
Let's copy from original source
Update: I removed all the cache logic and instead just pass the attention_mask in the format that's needed. This is cleaner than caching tensors according to their shape, memory id, etc. All the benefits are kept, including much improved readability and a comprehensive attention mask class that can be copied / re-used by other models. All tests pass just like before!
examples/research_projects/jax-projects/big_bird/bigbird_flax.py
Then I am not sure what the point of the `AttnMaskConverter` class is? By the way, for SDPA, ideally we need both pieces of information: 1/ is padding used, 2/ is a custom Transformers attention mask used. This is because if custom masking is not used, we may dispatch on flash attention. So passing only a 4D mask for SDPA is suboptimal in my opinion. Or I could just always pass the 4D attention mask to SDPA, but that kind of defeats the point given that dispatch to FA is then impossible.
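For context, a small sketch (not from the PR) of why this matters for PyTorch's `scaled_dot_product_attention`: only when no explicit mask tensor is passed can SDPA pick the fused flash backend, while materializing a mask forces the generic path even though the numerics are the same:

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)

# With attn_mask=None and is_causal=True, SDPA is free to dispatch to a fused
# flash-attention kernel (on supported hardware).
out_fast = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# An explicit (causal, boolean) mask is numerically equivalent, but rules out
# the flash backend and falls back to the generic math path.
causal = torch.tril(torch.ones(16, 16, dtype=torch.bool))
out_slow = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)

assert torch.allclose(out_fast, out_slow, atol=1e-4)
```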
        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
    -        `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
    +        `(batch, sequence_length)` where padding elements are indicated by 0.
I think this docstring is incorrect in the latest version.
I can confirm it looks all good on the FA-2 end (benchmarks + tests)! Thanks a lot @patrickvonplaten!
ArthurZucker left a comment
Are we adding the sliding window as a new feature for these models? Otherwise I would just use two different classes for Mistral and the others.
        config: PersimmonConfig
        """

    # Copied from transformers.models.llama.modeling_llama.LlamaModel.__init__ with LLAMA->PERSIMMON,Llama->Persimmon,PersimmonRMSNorm->nn.LayerNorm,norm->final_layernorm,rms_final_layernorm_eps->layer_norm_eps
Persimmon = Llama in terms of architecture. It's alright to remove as it's also very long, but Persimmon (and thus Fuyu) will benefit from whatever happens in Llama, so maybe a TODO!
        sliding_window (`int`, *optional*):
            Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
In this case the sliding window seems specific to Mistral, so would maybe only include it in Mistral's case, no?
We lose the `# Copied from` then. I'd expect more Mistral-like models to pop up and think it's not worth removing it, see arguments here: #26792 (comment)
        if getattr(self.config, "_flash_attn_2_enabled", False):
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            key_value_length = seq_length + past_key_values_length
            # 4d mask is passed through the layers
            if attention_mask is not None:
                attention_mask = self.attn_mask_converter.to_4d(
                    attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype
                )
            else:
                attention_mask = self.attn_mask_converter.to_causal_4d(
                    batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
                )
not sure if this is cleaner than the previous version, passing a None attention mask. Seems like we could handle the None case in the class rather than here
We were also passing a None attention mask previously for padding_mask.
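For reference, the FA2 short-circuit discussed here boils down to forwarding the 2D mask only when it actually contains padding; a sketch with a hypothetical helper name:

```python
import torch


def maybe_mask(attention_mask):
    # Forward the 2D mask only if it contains at least one padded (0) position;
    # otherwise FA2 can run the faster non-padded code path with mask=None.
    return attention_mask if (attention_mask is not None and 0 in attention_mask) else None


assert maybe_mask(torch.ones(2, 5)) is None          # no padding -> skip mask
assert maybe_mask(None) is None                      # nothing to forward
assert maybe_mask(torch.tensor([[1, 1, 0]])) is not None  # padding present
```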
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
            )

            # overwrite attention_mask with padding_mask
            attention_mask = kwargs.pop("padding_mask")
        # add lower triangular sliding window mask if necessary
        if sliding_window is not None:
            diagonal = past_key_values_length - sliding_window + 1
It's not really possible to use the sliding window in Llama because it's hardcoded at initialization (`sliding_window=...`) for Mistral, so the user can't (and should not) use it.
But I do see how the sliding window is arguably a bit exotic for the mask converter, and if people feel strongly I can put it in Mistral's forward method instead. Overall, we do move away a bit from the "single-file" policy here, as the attention converter is a general class that has more features than needed for some models. But it does make sense here since there is really not much variation in attention masks across models, and it greatly helps with readability.
No problem for me to leave the sliding window in the mask converter class, I indeed think we'll get to see more models leveraging the sliding window (or users that want it supported) in other architectures.
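For reference, a self-contained sketch of how such a sliding-window causal mask can be built; the function name is hypothetical, and the `diagonal` computation mirrors the snippet quoted earlier in the thread:

```python
import torch


def make_sliding_window_causal_mask(tgt_len, past_key_values_length, sliding_window, dtype=torch.float32):
    # Hypothetical helper (not the PR's exact code): a causal additive mask
    # where each query may additionally only see the last `sliding_window`
    # key positions.
    min_val = torch.finfo(dtype).min
    mask = torch.full((tgt_len, tgt_len), min_val)
    mask_cond = torch.arange(tgt_len)
    # causal part: position i may attend to positions j <= i
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)
    if sliding_window is not None:
        # band part: re-mask keys more than `sliding_window - 1` steps back
        diagonal = past_key_values_length - sliding_window + 1
        context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
        mask.masked_fill_(context_mask.bool(), min_val)
    return mask.to(dtype)


mask = make_sliding_window_causal_mask(tgt_len=5, past_key_values_length=0, sliding_window=3)
# the last query (row 4) can see keys 2..4, but not keys 0..1
```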
…pper (huggingface#26792)

* [Attn Mask Converter] refactor attn mask
* up
* Apply suggestions from code review

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>

* improve
* rename
* better cache
* renaming
* improve more
* improve
* fix bug
* finalize
* make style & make fix-copies
* correct more
* start moving attention_mask
* fix llama
* improve falcon
* up
* improve more
* improve more
* Update src/transformers/models/owlv2/modeling_owlv2.py
* make style
* make style
* rename to converter
* Apply suggestions from code review

---------

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
What does this PR do?
For models that have Flash Attention 2 (FA2) implemented we currently pass both `padding_mask` and `attention_mask` to the respective vanilla attention class, e.g. `LlamaAttention`, and to the FA2 class, e.g. `LlamaFlashAttention2`.

However, `padding_mask` is not used for `LlamaAttention` and `attention_mask` is not used for `LlamaFlashAttention2`. Conceptually the two masks are the same, only that `attention_mask` is a 4D mask while `padding_mask` is a 2D mask.

Passing around both masks and having both masks as concepts in our codebase is ambiguous and hurts readability. In this PR, I propose to remove the concept of `padding_mask` completely and instead just pass either a 2D or 4D `attention_mask` depending on whether we use FA2 or not.

Note: An additional benefit of this PR is that it will improve the performance when using FA2 as we will not create a 4D attention mask anymore.
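To make the 2D-vs-4D distinction concrete, here is a hedged sketch of the kind of expansion involved (the helper below is hypothetical, not the PR's actual converter): a `(batch, seq_len)` mask of 1s/0s becomes a `(batch, 1, seq_len, seq_len)` additive mask combining causality and padding.

```python
import torch


def to_4d(attention_mask_2d, dtype=torch.float32):
    # Hypothetical 2D->4D expansion: 0 keeps a position, a very large negative
    # value masks it out (added to the attention scores before softmax).
    bsz, seq_len = attention_mask_2d.shape
    min_val = torch.finfo(dtype).min
    # causal part: queries may not attend to future keys
    causal = torch.triu(torch.full((seq_len, seq_len), min_val, dtype=dtype), diagonal=1)
    # padding part: keys at padded positions get a large negative bias
    padding = (1.0 - attention_mask_2d[:, None, None, :].to(dtype)) * min_val
    return causal[None, None, :, :] + padding


mask_2d = torch.tensor([[1, 1, 1, 0]])  # last token is padding
mask_4d = to_4d(mask_2d)                # shape (1, 1, 4, 4)
```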
Benchmarks:
The following script was used to benchmark the effect this mask implementation has on `forward` and `generate`.
This PR:
Current main:
=> We don't see any drop in performance at all.
I've verified that the following tests all pass on a single GPU (RTX4090):
FA2:
and all Llama fast tests: