Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Verified that FA2 works by checking Whisper. Bart's attention is exactly the same as Whisper's, so it should work as well. I will run some better benchmarks later. @ArthurZucker @younesbelkada, could you do a first review here, just for the following files:
It would be nice to agree on these files first. Some comments:
ArthurZucker
left a comment
Looks good yeah!
Regarding your comments, I totally agree about the padding mask; this was my initial concern here. Llama needed less tolerance, but let's update it. Otherwise looks good. Let's make sure the attention is as clean as possible, as it will be the reference for cross-attention.
    def _flash_attention_forward(

Suggested change:

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
    def _flash_attention_forward(
This is copied from Llama as well, no?
Actually, we need a new causal argument to this function to differentiate between non-causal (encoder) and causal (decoder) attention.
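The distinction above can be sketched with a toy additive-mask builder. This is a pure-NumPy illustration with made-up names, not the actual transformers helper: the point is that a single is_causal flag lets one forward helper serve both the bidirectional encoder (is_causal=False) and the causal decoder self-attention (is_causal=True).

```python
import numpy as np

def build_attention_bias(seq_len: int, is_causal: bool) -> np.ndarray:
    """Return an additive attention bias: 0 where attending is allowed,
    -inf where it is masked. (Illustrative sketch, not transformers code.)"""
    bias = np.zeros((seq_len, seq_len))
    if is_causal:
        # Decoder self-attention: position i may only attend to positions <= i.
        bias[np.triu_indices(seq_len, k=1)] = -np.inf
    return bias
```

With is_causal=False the bias is all zeros (full bidirectional attention, as in the encoder); with is_causal=True everything above the diagonal is masked.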
    @@ -2797,16 +2797,35 @@ def test_flash_attn_2_inference(self):
        dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device)
We might need decoder input ids as well, for cross-attention testing?
Update: the PR now works for Bart. @ArthurZucker @younesbelkada @fxmarty @LysandreJik, could you give the design chosen for BART a look here? If it's OK, I'll apply it to all other Bart-like models. Please only review
fxmarty
left a comment
Looks good! There are probably some docstrings (e.g. BartDecoderLayer) whose attention_mask documentation should be updated accordingly.
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        is_causal: bool = False,
Personal taste, but I would add this argument after bias, in case somebody is using positional arguments.
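The concern can be made concrete with a small sketch (function names are invented for illustration, not the actual transformers signatures): inserting a new parameter before an existing one silently shifts what positional callers pass.

```python
# Original signature: `bias` is the fifth positional parameter.
def attention_init_old(embed_dim, num_heads, dropout=0.0, is_decoder=False, bias=True):
    return (embed_dim, num_heads, dropout, is_decoder, bias)

# New signature with `is_causal` inserted *before* `bias`: a caller that
# used to pass bias positionally now feeds that value into is_causal,
# and bias silently falls back to its default.
def attention_init_new(embed_dim, num_heads, dropout=0.0, is_decoder=False,
                       is_causal=False, bias=True):
    return (embed_dim, num_heads, dropout, is_decoder, is_causal, bias)
```

Calling both with the same positional arguments `(512, 8, 0.1, True, False)` gives bias=False in the old signature but is_causal=False, bias=True in the new one, which is exactly the breakage adding the argument after bias avoids.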
        # BartFlashAttention2 attention does not support output_attentions
        output_attentions = False
I don't know how it was for Llama, but I would raise an error here in case output_attentions is True.
Yes, fair. The problem, though, is that by now it would be backwards-breaking.
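The two options under discussion can be sketched side by side (illustrative only; the strict flag and function are invented here, not the merged transformers code):

```python
import warnings

def resolve_output_attentions(output_attentions: bool, strict: bool = False) -> bool:
    """Sketch of the trade-off: error out (explicit, but backwards-breaking)
    vs. silently disable (backwards-compatible)."""
    if output_attentions:
        if strict:
            # Option 1: hard error. Explicit, but breaks existing callers
            # that already pass output_attentions=True with FA2 enabled.
            raise ValueError(
                "Flash Attention 2 does not support output_attentions=True."
            )
        # Option 2 (backwards-compatible): warn and silently disable.
        warnings.warn(
            "Flash Attention 2 does not support output_attentions; ignoring."
        )
    return False
```

The snippet in the diff corresponds to option 2 without even the warning, which is what makes switching to a hard error a breaking change at this point.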
        # TODO: Bart does not have dropout in the config??
        # It is recommended to use dropout with FA according to the docs
        # when training.
        dropout_rate = 0.0  # if not self.training else self.attn_dropout
I think Bart has some dropout:
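A sketch of what the TODO above could resolve to, assuming the attention module stores the configured attention dropout on itself (attribute names here are illustrative, not the actual transformers code): apply attention dropout only during training, as the FA docs recommend.

```python
class AttentionDropoutSketch:
    """Minimal sketch: pick the dropout rate based on training mode."""

    def __init__(self, attention_dropout: float):
        self.dropout = attention_dropout  # e.g. config.attention_dropout
        self.training = True

    def dropout_rate(self) -> float:
        # Dropout is a regularizer: active in training, disabled at inference.
        return self.dropout if self.training else 0.0
```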
ArthurZucker
left a comment
Only reviewed the Bart modeling file, looks good overall!
It might be slightly better if the attention-mask logic lived in the attention class, but otherwise it's nice that we expose more clearly how attention masks need to be processed. Thanks!
        if attention_mask is not None:
            attention_mask = self.causal_attn_mask_converter.to_4d(
                attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype
            )
        else:
            attention_mask = self.causal_attn_mask_converter.to_causal_4d(
                input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
            )
It would be nice if to_causal_4d supported being fed a mask and took care of this if/else itself, no?
Hmm, but the mask is None here, and then I need to pass all these shapes anyway.
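The suggestion amounts to a single entry point that accepts an optional 2D padding mask and falls back to a pure causal mask when none is given. Here is a pure-NumPy sketch with invented names (the real transformers AttentionMaskConverter API differs), using boolean masks where True means "may attend":

```python
import numpy as np

def to_4d_or_causal(mask_2d, batch_size: int, query_len: int, key_len: int):
    """Sketch: one call site instead of the if/else in the diff above."""
    # Causal constraint; the key/query length offset handles cached past keys.
    causal = np.tril(np.ones((query_len, key_len), dtype=bool), k=key_len - query_len)
    if mask_2d is None:
        # No padding mask: plain causal mask broadcast over the batch.
        return np.broadcast_to(causal, (batch_size, 1, query_len, key_len))
    # Padding mask given: expand [bsz, key_len] -> [bsz, 1, query_len, key_len]
    # and intersect with the causal constraint.
    expanded = mask_2d.astype(bool)[:, None, None, :]
    return expanded & causal
```

The counterpoint in the reply still holds: when the mask is None, all the shape arguments must be passed explicitly either way, so the merged entry point mostly moves the if/else rather than removing the plumbing.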
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
        if getattr(self.config, "_flash_attn_2_enabled", False):
            encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask is not None and 0 in encoder_attention_mask) else None
I think a comment would be nice, to explain why we don't pass the mask to FA if there are no 0 values in it.
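The rationale the comment would state: a mask with no 0 entries masks nothing, so passing None instead lets Flash Attention skip the variable-length unpad/repad bookkeeping and take its faster dense path. A minimal sketch of the check (illustrative helper name, not the actual transformers code):

```python
import numpy as np

def maybe_drop_mask(attention_mask):
    """Return the mask only if it actually masks something."""
    if attention_mask is not None and (attention_mask == 0).any():
        return attention_mask  # real padding present: keep the mask
    # An all-ones (or absent) mask is a no-op, so drop it and let FA
    # use its no-padding fast path.
    return None
```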
        if getattr(config, "_flash_attn_2_enabled", False):
            self.encoder_attn = BartFlashAttention2(
                embed_dim=self.embed_dim,
                num_heads=config.decoder_attention_heads,
                dropout=config.attention_dropout,
                is_decoder=True,
                config=config,
            )
        else:
            self.encoder_attn = BartAttention(
                embed_dim=self.embed_dim,
                num_heads=config.decoder_attention_heads,
                dropout=config.attention_dropout,
                is_decoder=True,
                config=config,
            )
I think a mapping like BERT_ATTENTIONS["attention_class"] will be cleaner long-term if we add SDPA, flash decoding, etc., especially given that the init arguments are consistent (and we want this to always be the case).
Sure that makes sense!
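The mapping idea can be sketched as follows. Class and dict names here are illustrative (the merged transformers code may differ): because all attention variants share the same __init__ arguments, one dict lookup replaces the if/else at every construction site and scales to new backends without touching call sites.

```python
class BartAttention:
    """Stand-in for the eager attention class (sketch only)."""
    def __init__(self, embed_dim, num_heads, dropout=0.0, is_decoder=False, config=None):
        self.embed_dim = embed_dim
        self.num_heads = num_heads

class BartFlashAttention2(BartAttention):
    """Stand-in for the FA2 variant; same init signature by design."""

# One mapping, extensible to "sdpa", "flash_decoding", ... later.
BART_ATTENTION_CLASSES = {
    "eager": BartAttention,
    "flash_attention_2": BartFlashAttention2,
}

def make_encoder_attn(config):
    impl = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "eager"
    return BART_ATTENTION_CLASSES[impl](
        embed_dim=config.d_model,
        num_heads=config.decoder_attention_heads,
        dropout=config.attention_dropout,
        is_decoder=True,
        config=config,
    )
```

The consistent init signature is what makes this work: if one variant ever needed a different argument, the lookup-and-construct pattern would break, which is why the reviewer stresses keeping the arguments aligned.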
What does this PR do?
Adds Flash Attention 2 (FA2) support to all Bart-like models.