Adding Flash Attention 2 Support for GPT2 #29226
amyeroberts merged 33 commits into huggingface:main
Conversation
younesbelkada
left a comment
Wow, thanks for the great work! At a quick glance it looks like you handled the copy mechanism very well, which is quite a challenge for GPT2!
Please find the benchmarking script here: https://gist.github.com/younesbelkada/02f35734da906cc0f2389ae4f665c58f I suggest trying it out for prefill only, on large sequence lengths. Let us know with @ArthurZucker @fxmarty how it goes.
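The gist above is the authoritative script; for context, prefill benchmarks of this kind usually follow a warmup-then-time pattern. A minimal framework-free sketch (`fn` stands in for a prefill forward pass over a long sequence; for CUDA models you would also synchronize around the timed region):

```python
import time

def benchmark(fn, warmup=3, iters=10):
    """Average wall-clock time of fn: discard warmup runs, then time iters runs."""
    for _ in range(warmup):
        fn()  # warmup: JIT/caches/allocator settle before timing
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

# Toy usage: time a cheap callable standing in for a model forward pass.
avg = benchmark(lambda: sum(range(10_000)))
```

With a real model, `fn` would wrap a forward pass over the prefill inputs, and the eager and `flash_attention_2` timings would be compared at several sequence lengths.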
Hey, I don't have a GPU, so I've been renting an RTX 3090 on RunPod to work on this PR. Is it a problem to use the 3090 to benchmark, or should I switch to an A100 (which I believe was the GPU used for the other benchmarks, at least the ones I've seen)?
Thanks for getting back, @EduardoPach; I think using a 3090 is fine!
@ArthurZucker I believe it should be ready for review.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker
left a comment
LGTM but we need to add a test 😉
@EduardoPach thanks again. What @ArthurZucker meant is an integration test, similar to the ones for other models, for GPT2 only. Would you be happy to work on that? 🙏
Yeah, I will add the test in the following hours |
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
The following hours became more like the following days haha, but it should be good now @ArthurZucker
ArthurZucker
left a comment
Almost good, left a few nits
```python
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
DECISIONTRANSFORMERGPT2_ATTENTION_CLASSES = {
    "eager": DecisionTransformerGPT2Attention,
}
```

Suggested change:

```diff
-DECISIONTRANSFORMERGPT2_ATTENTION_CLASSES = {
+DECISION_TRANSFORMER_GPT2_ATTENTION_CLASSES = {
```
Where is DecisionTransformerGPT2FlashAttention2?
Hadn't added it there, but added it now here: 74fb9bd. However, DecisionTransformer does not support flash attention yet; I just had to make these modifications to make sure nothing would break with the `# Copied from` statements.
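For context, the `*_ATTENTION_CLASSES` dict being discussed maps the configured `attn_implementation` string to an attention class at block init time. A toy sketch of the dispatch pattern (these are stand-in classes, not the real Transformers ones):

```python
class GPT2Attention:
    """Stand-in for the eager attention implementation."""
    impl = "eager"

class GPT2FlashAttention2(GPT2Attention):
    """Stand-in for the Flash Attention 2 subclass."""
    impl = "flash_attention_2"

GPT2_ATTENTION_CLASSES = {
    "eager": GPT2Attention,
    "flash_attention_2": GPT2FlashAttention2,
}

def build_attention(attn_implementation):
    # Each block looks up its attention class from the mapping,
    # so adding a backend is a one-line dict entry.
    return GPT2_ATTENTION_CLASSES[attn_implementation]()
```

The review point above is that a model which does not support FA2 should only register the `"eager"` entry (or skip the dispatch entirely).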
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
ArthurZucker
left a comment
LGTM, one final nit: the test should use explicit values.
younesbelkada
left a comment
Thanks again! We just merged some fixes on main. Could you rebase again 🙏 then we should finally merge :D Sorry for all the iterations!
No worries! Done
Thanks! Hmm, I can't see the rebase commit in the history; can you try again?
I've done it.
amyeroberts
left a comment
Thanks for adding this and making our models go brrr 🔥
Just a few small comments. The diffs in the READMEs will need to be resolved before we can merge
There shouldn't be README changes here. Can you make sure to rebase on main to include the most recent changes?
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
…transformers into add-flash-attn-gpt2
amyeroberts
left a comment
Thanks for the continued work on this!
Only thing left to do is make sure decision transformer has the updated documentation and tests
```python
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_left(self):
```
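For context, this test exercises batched generation with left-padded prompts, which is what generation (and FA2 in particular) expects. A framework-free sketch of left padding and the matching attention mask (the pad id of 0 is illustrative):

```python
def left_pad(batch, pad_id=0):
    """Pad variable-length token lists on the left and build an attention mask
    (1 = real token, 0 = padding)."""
    width = max(len(seq) for seq in batch)
    padded, mask = [], []
    for seq in batch:
        n_pad = width - len(seq)
        padded.append([pad_id] * n_pad + seq)  # pads go on the LEFT
        mask.append([0] * n_pad + [1] * len(seq))
    return padded, mask
```

Left padding keeps the last position of every row a real token, so the next-token logits for each sequence line up at the end of the batch.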
The equivalent test for decision transformer should also be added
Doesn't the test in GPT2 already cover Decision Transformer? Flash Attention usage in Decision Transformer comes precisely from GPT2Model being embedded in its architecture.
Both models should be tested. This makes sure that if anything changes upstream they remain correct, for example the input preparation in DecisionTransformerModel.
While adding the test for DecisionTransformer I realized that the model has two distinct xxxPreTrainedModels, and adding support for flash_attention_2 would be a bit more complicated, so I believe it would be better to add support in a dedicated PR.
In this case, flash attention shouldn't be added at all for the model. You can use `# Ignore copy` on the init so the previous attention class is used.
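For reference, the `# Ignore copy` marker tells the repository's copy-consistency check (utils/check_copies.py) to skip a block inside a `# Copied from` region. A rough, runnable stand-in for the pattern (class body is illustrative, not the real file):

```python
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2Block:
    # Ignore copy
    def __init__(self, config=None, layer_idx=None):
        # Keep instantiating the eager attention class directly here,
        # instead of the FA2-aware dispatch the GPT2 original now uses.
        self.attn_implementation = "eager"
```

Everything outside the ignored init still has to match the GPT2 source, so upstream fixes keep propagating to the copied class.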
```md
- Enabling the *scale_attn_by_inverse_layer_idx* and *reorder_and_upcast_attn* flags will apply the training stability
  improvements from [Mistral](https://github.com/stanford-crfm/mistral/) (for PyTorch only).

## Usage example
```
We should have the equivalent added for decision transformer too.
See message above.
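The kind of usage example being requested typically mirrors the FA2 snippet used elsewhere in the docs; a sketch (it requires a CUDA GPU and the `flash-attn` package installed, so it is illustrative rather than runnable here):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    torch_dtype=torch.float16,          # FA2 requires fp16 or bf16
    attn_implementation="flash_attention_2",
).to("cuda")

inputs = tokenizer("Hello, my dog is", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```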
Thank you for your hard work! Many of us are excited about GPT-2 supporting flash attention. May I ask when the PR is expected to be merged?
Hey, I believe if @amyeroberts agrees with my latest message it should get merged right away 🤞
amyeroberts
left a comment
Thanks for iterating; a few final places to tidy up.

In `docs/source/en/perf_infer_gpu_one.md` (outdated):

```md
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [DecisionTransformer](https://huggingface.co/docs/transformers/en/model_doc/decision_transformer)
```
This should be removed.
```diff
 position_ids = position_ids.unsqueeze(0)

-# GPT2Attention mask.
+# Attention mask.
```
Here I would use `# Ignore copy`; the model shouldn't have FA2 logic.
```python
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
    encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
```
Same here, above this line.
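For context, `invert_attention_mask` converts a 1/0 keep mask into an additive bias that is added to the attention scores before softmax. A framework-free sketch of the idea (the real implementation uses the dtype's minimum value rather than a fixed -1e9):

```python
def invert_attention_mask(mask, neg=-1e9):
    # 1 -> 0.0 (position is attended), 0 -> large negative (masked out,
    # so softmax assigns it ~zero probability)
    return [[0.0 if keep == 1 else neg for keep in row] for row in mask]
```

FA2 instead consumes the 1/0 mask directly (via unpadding), which is why the eager-only path keeps this inversion step.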
amyeroberts
left a comment
Thanks for adding this for GPT2 and iterating on a solution!
It seems everything is okay. May I kindly request to merge this PR? I am really looking forward to speeding up my GPT-2. If my request has added to your workload, I apologize for any inconvenience.
cc @amyeroberts
What does this PR do?
Fixes #26350
Who can review?
Hey @younesbelkada, I added Flash Attention 2 support for GPT2. The only thing missing is the expected speedups section; could you share the code you used for the other models you added support for, to keep consistency?