[BERT] Add support for sdpa#28802
Conversation
|
Hey @ArthurZucker @younesbelkada I was thinking SDPA (#28005) could be a good addition to BERT, so I drafted this change. It doesn't look too hairy so far. As @ArthurZucker mentioned, BERT doesn't have a lot of params so there might not be much of a speedup, but this didn't look too difficult to implement so I figured whatever little improvement might still be helpful (as an aside, there's been some benchmarking of Flash Attention on training other implementations of BERT, and it still shows decent improvements). Can you let me know if this is worth pursuing? If so, I'll add the tests and also fix the fix-copies dependencies. Thanks! |
There was a problem hiding this comment.
This is fixed in torch 2.2.0 I think, maybe I should check for it and skip the calls?
There was a problem hiding this comment.
I think it is fine to leave. We should probably bump the requirement for SDPA to torch>=2.2 in the future.
There was a problem hiding this comment.
This got me thinking, and I ran an additional set of benchmarking, given that FA2 is supported and the contiguous bug is fixed in 2.2.0: training and inference.
Both training and inference were ~5% faster with torch==2.2.0 (FA2 should be supported). I also tried out gating the .contiguous() requirement and saw an additional ~5-10% gain on top of that.
if version.parse(get_torch_version()) < version.parse("2.2.0")
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()
I'm leaning towards adding the if-statement to gate the call, so users who upgrade to torch=2.2.0 first can get the benefits right away (before we set the min torch version to 2.2.0). WDYT?
There was a problem hiding this comment.
I added the if-statement for 2.2.0 in there. If you don't think it's a good idea, let me know and I'll remove it.
|
I think a good way to se if it is worth the shot is to benchmark your code and check if you have speedups in different contexts! |
|
Sounds good, lemme look into that |
|
@ArthurZucker I did some training and inference benchmarking for my change and posted the results in the PR description. It looks like there are decent improvements across the board (percentage-wise, but I think the improvements would add up if we're doing a lot of training/inferencing). I think it could be a good addition. Thoughts? |
|
Sounds like a good addition then! I'll let @fxmarty review and will be doing the final pass! |
|
Yes, it's similar. SDPA is built into pytorch, and can support Flash Attention (1) depending on the environment. AFAIK Flash Attention 2 isn't supported in SDPA yet, but there is a possibility for it to be supported down the road (but that should be built into pytorch already, and shouldn't need many changes from our end). |
|
Thanks, I think it is now |
|
Oh nice, so I guess we could get FA2 for free eventually (when we upgrade pytorch). Thanks for the links to similar work. I think they could cause some merge conflicts, so I'll message them and try to resolve it before it goes in. |
There was a problem hiding this comment.
I would probably move the Copied from just to the __init__ and other methods, but not forward. For the forward, you can probably just add a comment that it is adapted from bert/roberta and once bridge_tower supports sdpa we can put back to copied from.
There was a problem hiding this comment.
There seems to be 8 methods that copy-from BertMode#forward() exactly and has this section of change.
I won't mind adding SDPA to them as well once this goes in and reinstating the copy-from. It shouldn't be that difficult (famous last words)
There was a problem hiding this comment.
I've removed the fix-copies from the instances, and so the logic for sdpa attention masks should only be in BertModel now.
src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
@ArthurZucker there are create_extended_attention_mask_for_decoder, invert_attention_mask, get_extended_attention_mask methods in modeling_utils.py that should probably be deprecated / redirect to modeling_attn_mask_utils.py.
There was a problem hiding this comment.
Yea, I agree.
It'd be great if we could mark those old methods as deprecated, and slowly update them once we verify that the old methods and the new methods are always returning the same results.
There was a problem hiding this comment.
For the updated_attention_mask for sdpa, why can't we keep the previous logic and just do:
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = causal_mask.mul(~torch.all(causal_mask == torch.finfo(embedding_output.dtype).min, dim=-1, keepdim=True)).to(
dtype
)(from Llama)?
Not super fan of the complexity of _prepare_4d_causal_attention_mask_for_sdpa, and we should not add it in our new code IMO.
There was a problem hiding this comment.
This code was changed to pass the fx tracing test (in common tests).
It would be good if you can help double check the logic here. I think the idea here is that we'll still have to use our own attention mask (rather than None) when tracing is active. The previous "pass" would cause the function to end without any return statements, which would have defaulted to None.
There was a problem hiding this comment.
It looks OK to me, cc @fxmarty to confirm.
AFAICT, the difference here is coming from the additional isinstance(mask, torch.fx.Proxy) in the is_tracing_check. I don't believe the reworking to remove pass should affect anything - the new code is equivalent.
There was a problem hiding this comment.
Yes it is fine, see
There was a problem hiding this comment.
This fix was added due to a test failure that uncovered an existing bug.
The head was initialized but the weights weren't retied as necessary. This was causing self.decoder.bias to be different from self.bias. When loading the pretrained model with low_cpu_mem_usage=True, the self.decoder.bias had uninitiated params (with device=meta) whereas self.bias was set properly (with device=cpu)
I'm slightly concerned this will affect the output some users see when using this model. Please let me know what you think about this.
There was a problem hiding this comment.
I pulled this out to its own PR here:
#28948
This issue is unrelated to SDPA, but was just uncovered by a SPDA test, so I just pulled it out to its own PR.
There was a problem hiding this comment.
Addition looks OK to me - thanks for digging into this.
I'm slightly concerned this will affect the output some users see when using this model. Please let me know what you think about this.
Could you expand on what you think might be an issue?
There was a problem hiding this comment.
I was initially concerned that users were loading and using the model with a wrong bias (ie. device=meta), and this fix to use the correct bias will cause the results to change between versions.
However, that seems unlikely after playing around with this a bit more - turns out it was quite difficult to run the model when the bias had device=meta, so I doubt anyone was actually running the model in this particular configuration before the fix.
tests/test_modeling_common.py
Outdated
There was a problem hiding this comment.
The self._prepare_for_class is necessary to support the BertForMultipleChoice model.
|
I've rebased off of head and marked as ready for review. I had to dig through a couple of issues to get the tests to pass, let me now if you want to chat about any of them. Thanks! |
|
The tests are passing now. I also verified that test_modeling_bert passes with RUN_SLOW=1 (which contains the tests to ensure model output is the same for eager and sdpa attentions) Please take another look when you get a chance. Thanks! |
amyeroberts
left a comment
There was a problem hiding this comment.
Thanks for all the work adding this @hackyon as well as the additional work to dig into weird errors and find solutions. Great work!
Some general comments:
- Let's wait for the merging of #28948 before merging this in
- It would be good to add the performance numbers in the PR description to BERT's model page, similar to what's done for Flash Attention e.g. [here](https://huggingface.co/docs/transformers/v4.37.2/en/model_doc/gpt_neox#using-flash-attention-2.
test_eager_matches_sdpa_inferenceshould be run for all existing models with SDPA implemented to confirm compatibility with the change inprocessed_inputs- We shouldn't be setting
self._use_sdpathat don't have an SDPA attention class. We can just about get away with it for the models which have an attention dict, but not for the other models.
There was a problem hiding this comment.
It looks OK to me, cc @fxmarty to confirm.
AFAICT, the difference here is coming from the additional isinstance(mask, torch.fx.Proxy) in the is_tracing_check. I don't believe the reworking to remove pass should affect anything - the new code is equivalent.
There was a problem hiding this comment.
Addition looks OK to me - thanks for digging into this.
I'm slightly concerned this will affect the output some users see when using this model. Please let me know what you think about this.
Could you expand on what you think might be an issue?
src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
Outdated
Show resolved
Hide resolved
|
Oh wow |
ArthurZucker
left a comment
There was a problem hiding this comment.
LGTM, let's rebase on main!
|
Thanks! I merged with main/HEAD, and re-ran the RUN_SLOW tests for both bert and also for test_eager_matches_sdpa_inference and they work as expected. There were existing failures for test_eager_matches_sdpa_inference with RUN_SLOW on main/HEAD, but nothing new introduced by this change. I'm not sure about this test_pipelines_tf failure. I haven't touched any code with tf, and I was able to get the failing test test_stop_sequence_stopping_criteria to pass locally, so I'm thinking it's a flake or unrelated to this change. |
|
Hi @hackyon - great to see this ready to merge! The generation tests aren't related to this diff and are failing on other PRs. We're working to push a fix to main - will let you know when resolved, you can rebase and hopefully we have full 🟢 for merging 🤗 |
|
Thanks @amyeroberts @ArthurZucker Just remerged with main/HEAD, and the unrelated failing TF pipeline test now passes. I checked the bert tests again with RUN_SLOW for good measure, and they continue to pass. Let me know if there's anything else I could do here. Thanks! |
|
@ArthurZucker Please let me know if there's anything else you'd like me to do for this PR. Thanks! |
|
Remerged with the latest main, and fixed a test. @ArthurZucker @amyeroberts @fxmarty Please let me know if there's anything I can do here. |
|
@hackyon Everything's green and two approvals, so we're good to merge. Thanks for all the effort in adding this and iterating with us. It's great to have this added to one of the most popular models ❤️ |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
Thanks @amyeroberts for the merge! 🎉 I appreciate all the help from @fxmarty, @ArthurZucker, and you in getting this PR merged 🙏 I see you've submitted #30506 as a follow-up, and thank you for covering that. Please let me know if there's any other follow-up work, and I'd be happy to look into it. |
|
As I mentioned previously, I've also drafted a PR for adding SDPA support to RoBERTa-based models at #30510. Almost all of the changes are "Copied from" BERT, and so there is a little less room for error. |
|
I appreciate your job! As Esm is a Bert-base model, I think sdpa can be add to Esm with little modification. |
What does this PR do?
Adding support for SDPA (scaled dot product attention) for Bert. More context in #28005.
Benchmarking Results on A100-80GB, CPUx12, RAM 96.6GB, OS Ubuntu 22.04, using BertLMHeadModel
Training benchmark based on fxmarty's script:
Inference benchmark based on fxmarty's script:
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker @younesbelkada
(cc @fxmarty)