
Improve BERT-like models performance with better self attention #9124

Merged
jplu merged 8 commits into huggingface:master from jplu:tf-einsumdense on Dec 21, 2020

Conversation

@jplu (Contributor) commented Dec 15, 2020

What does this PR do?

This PR updates the way we implement the self-attention layers in order to align with the original BERT implementation's performance. Small breaking change: this improvement requires at least TF 2.3. This change has already been discussed with @thomwolf, and he agreed, but it still needs the approval of @LysandreJik, @patrickvonplaten and @sgugger.

@patrickvonplaten I have removed the check_copies comment in the Longformer model because I don't know this model well enough to apply the proper changes. I will apply this update model by model for the ones I know, but can you take this one?

@jlei2 as I'm on Windows and GPU profiling is unfortunately not yet available in WSL, can you clone this branch and make sure that everything works as expected with your benchmark? Thanks!!

Fixes #6771
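
As a rough illustration of the idea (a sketch with placeholder sizes, not the exact diff): the per-head query/key/value projections can be expressed with tf.keras.layers.experimental.EinsumDense instead of a Dense layer followed by reshape/transpose, which is why TF >= 2.3 is required:

```python
import tensorflow as tf

# Illustrative sizes (not tied to any particular checkpoint).
batch, seq_len, hidden_size, num_heads = 2, 16, 768, 12
head_size = hidden_size // num_heads

# Old approach: Dense projection followed by reshape + transpose into heads.
dense_query = tf.keras.layers.Dense(hidden_size)

# New approach: EinsumDense (TF >= 2.3) produces the per-head layout directly,
# so the attention scores can later be computed with a single einsum.
einsum_query = tf.keras.layers.experimental.EinsumDense(
    equation="abc,cde->abde",  # (batch, seq, hidden) x (hidden, heads, head_size)
    output_shape=(None, num_heads, head_size),
    bias_axes="de",
)

hidden_states = tf.random.normal((batch, seq_len, hidden_size))

# Old path: (batch, seq, hidden) -> (batch, heads, seq, head_size)
q_old = tf.transpose(
    tf.reshape(dense_query(hidden_states), (batch, seq_len, num_heads, head_size)),
    perm=[0, 2, 1, 3],
)

# New path: the einsum already yields (batch, seq, heads, head_size).
q_new = einsum_query(hidden_states)
print(q_old.shape, q_new.shape)  # (2, 12, 16, 64) (2, 16, 12, 64)
```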

@jplu jplu marked this pull request as draft December 15, 2020 13:14
@jplu (Contributor, Author) commented Dec 15, 2020

A Python profiling run gives the following improvement:

```python
import cProfile
from transformers import TFBertModel

model = TFBertModel.from_pretrained("bert-base-cased")

# With the improvements
cProfile.run("model(model.dummy_inputs)")
# 54591 function calls (53774 primitive calls) in 0.064 seconds

# Currently on master
cProfile.run("model(model.dummy_inputs)")
# 76166 function calls (75204 primitive calls) in 0.095 seconds
```
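
Beyond cProfile, a wall-clock comparison could also be done with the library's benchmarking utilities (a sketch; the model name and sizes below are placeholders):

```python
from transformers import TensorFlowBenchmark, TensorFlowBenchmarkArguments

# Measure inference speed; run the same script on master and on this branch
# and compare the reported times.
args = TensorFlowBenchmarkArguments(
    models=["bert-base-cased"],
    batch_sizes=[8],
    sequence_lengths=[128],
)
benchmark = TensorFlowBenchmark(args)
results = benchmark.run()
print(results)
```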

@jplu jplu mentioned this pull request Dec 15, 2020
@sgugger (Collaborator) left a comment

Thanks for working on this! I just have a few cosmetic nits, since I'm annoying and can't see things that are longer than 119 chars ;-)

Comment thread on setup.py (outdated)

Collaborator

If you rebase after merging #9120, we can clean up the version test in tf_optimization.

Contributor Author

Can we do that in a separate PR dedicated only to TF >= 2.3 compliance instead?

Collaborator

Why is the assert better than raising the ValueError? I liked the previous version better, but mostly because it fits within the 119-char limit. If we keep the assert, could you just split the line to respect that char limit?

Contributor Author

I don't mind removing the assert and putting back the ValueError. Will do this!

Contributor Author

Done in the last commit!

Collaborator

Same comment as for the other assert.

Contributor Author

Done in the last commit!

Collaborator

Same comment as for the other assert.

Contributor Author

Done in the last commit!

Collaborator

Same here.

Contributor Author

Done in the last commit!

Collaborator

Why is Longformer not included in the change?

Contributor Author

Because I don't know this model well enough, I prefer to let @patrickvonplaten handle it.

Contributor

The intermediate layers of Longformer are 1-to-1 the same as BERT's, so there should be no problem keeping those lines. I'd be surprised if leaving them in would throw an error, tbh. Did you try just leaving them? The only difference in Longformer is the self-attention layer, and none of those copy statements concern the self-attention layer, so IMO we should leave the statements and run make fix-copies.
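
(For context, the copy statements discussed here are comments of roughly the following form, which make fix-copies uses to keep the derived class in sync with the BERT implementation; the exact path is illustrative.)

```python
import tensorflow as tf

# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Longformer
class TFLongformerIntermediate(tf.keras.layers.Layer):
    # The body is kept identical to the BERT version by `make fix-copies`.
    ...
```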

Contributor Author

Oh indeed it works for the Intermediate layer! Only the Self attention still needs to be updated accordingly :)

Contributor

Yes, I'm happy to do that in another PR; it would be amazing if you could open a one-liner issue about it and tag me :-)

@jplu jplu marked this pull request as ready for review December 15, 2020 15:17
@patrickvonplaten (Contributor) left a comment

Awesome! I am sure you already did this, but before merging we should be sure of two things:

  1. All the slow TF BERT and BERT-like model tests are passing.
  2. "Old" pre-trained tf_model.h5 files that were saved with TF < 2.3 can be loaded into the new layer design, and TF1 checkpoint (.ckpt) files can be loaded into the new layers as well.

I don't really see a reason why either 1) or 2) should not work, but just to be sure, it'd be great to test those quickly :-)

And I think we can leave all the Longformer copy statements -> there shouldn't be a problem :-)
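
A quick way to sanity-check point 2 could be to load an existing tf_model.h5 checkpoint and dump the outputs, once on master and once on this branch, and then compare the two files (a sketch; the model name and file names are placeholders):

```python
import numpy as np
from transformers import TFBertModel

# Load an existing tf_model.h5 checkpoint (saved before TF 2.3) into the new layer design.
model = TFBertModel.from_pretrained("bert-base-cased")

# Run the model's dummy inputs and dump the last hidden state for comparison.
outputs = model(model.dummy_inputs)
np.save("last_hidden_state_branch.npy", outputs[0].numpy())

# After running the same script on master (saving to a different file):
# assert np.allclose(np.load("last_hidden_state_branch.npy"),
#                    np.load("last_hidden_state_master.npy"), atol=1e-5)
```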

@jplu (Contributor, Author) commented Dec 16, 2020

Thanks @patrickvonplaten !!

  1. Slow tests are passing for these models
  2. I confirm that "old" pre-trained tf_model.h5 files that were saved with TF < 2.3 can be loaded into the new layer design.

I haven't tested the TF1 models; do you mean testing load_tf_weights_in_bert in the modeling_bert.py file?

@jplu (Contributor, Author) commented Dec 16, 2020

@jlei2 has confirmed that everything now works as expected in the profiler and benchmark 👍 #6771 (comment)

@patrickvonplaten (Contributor)

2. "Old" pre-trained models tf_model.h5 files that were saved with tf < 2.3 can be loaded into the new layer des

Yeah, I mean loading a TF .ckpt file using the from_pretrained(...) method. The from_pretrained(...) method automatically uses the correct functions to load .ckpt files. I think the easiest way would be to download one of the zips of the official Google BERT (https://github.com/google-research/bert#bert) and quickly check that it can be loaded and that the output on this branch and on master is the same.

@patrickvonplaten (Contributor)

OK, as discussed offline, TF1 checkpoints cannot even be loaded into TF2 at the moment (only if one goes through PT), so this PR is good to go for me!

@LysandreJik (Member) left a comment

This is very clean, and the performance improvements are amazing! Thanks for checking that the slow tests pass and that the previous checkpoints can still be loaded.

Great job, thank you for working on this!

Member

Why not use the tf.keras.layers.experimental.EinsumDense and keep the copy?

Contributor Author

This one is not possible because the input/output shapes won't be compatible anymore.

@jplu jplu merged commit 5a8a4eb into huggingface:master Dec 21, 2020
@jplu jplu deleted the tf-einsumdense branch December 21, 2020 13:16