Skip to content

[Flax] Adapt Flax models to new structure#9484

Merged
patrickvonplaten merged 24 commits intomasterfrom
save_intermediate_flax_pr_just_in_case
Mar 18, 2021
Merged

[Flax] Adapt Flax models to new structure#9484
patrickvonplaten merged 24 commits intomasterfrom
save_intermediate_flax_pr_just_in_case

Conversation

@patrickvonplaten
Copy link
Contributor

@patrickvonplaten patrickvonplaten commented Jan 8, 2021

What does this PR do?

As discussed in #9172, Flax model should get a design that is most similar to PyTorch and thus should use def setup(...) instead of nn.compact(...). This PR refactors the model architecture of Bert & Roberta accordingly.

The next step is now to add a general conversion method flax<>pytorch which might require some more follow-up changes to the naming of the weights.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors which may be interested in your PR.

@patrickvonplaten patrickvonplaten marked this pull request as ready for review March 16, 2021 18:10
@patrickvonplaten patrickvonplaten changed the title [WIP] Save intermediate flax pr just in case [WIP] Adapt Flax models to new structure Mar 16, 2021
@patrickvonplaten patrickvonplaten changed the title [WIP] Adapt Flax models to new structure [Flax] Adapt Flax models to new structure Mar 16, 2021
Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for cleaning this up! I like the new style!
I left a lot of nits, mostly around naming and style.

Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like that it's defined through setup and through __call__ instead of just through __call__ with nn.compact! It makes it clearer, imo.

Great job, I think it's much more readable now than it was before!

@patrickvonplaten
Copy link
Contributor Author

Will wait until #10775 is merged, then rebase and then merge.

…into save_intermediate_flax_pr_just_in_case
@patrickvonplaten patrickvonplaten merged commit 0b98ca3 into master Mar 18, 2021
@patrickvonplaten patrickvonplaten deleted the save_intermediate_flax_pr_just_in_case branch March 18, 2021 06:44
@merrymercy
Copy link
Contributor

merrymercy commented Apr 13, 2021

@patrickvonplaten

I like the new structure but it seems this PR broke the flax example: https://github.com/huggingface/transformers/blob/master/examples/language-modeling/run_mlm_flax.py

TypeError: __init__() got an unexpected keyword argument 'dropout_rate'

I think we should make more test cases and make sure the examples are runnable.

@merrymercy
Copy link
Contributor

merrymercy commented Apr 13, 2021

I am very interested in the jax/flax integration. Could you also take a look at my PR? #10796
If you are collaborative and welcome contributions from me, I can contribute more and improve the flax examples.

Iwontbecreative pushed a commit to Iwontbecreative/transformers that referenced this pull request Jul 15, 2021
* Create modeling_flax_eletra with code copied from modeling_flax_bert

* Add ElectraForMaskedLM and ElectraForPretraining

* Add modeling test for Flax electra and fix naming and arg in Flax Electra model

* Add documentation

* Fix code style

* Create modeling_flax_eletra with code copied from modeling_flax_bert

* Add ElectraForMaskedLM and ElectraForPretraining

* Add modeling test for Flax electra and fix naming and arg in Flax Electra model

* Add documentation

* Fix code style

* Fix code quality

* Adjust tol in assert_almost_equal due to very small difference between model output, ranging 0.0010 - 0.0016

* Remove redundant ElectraPooler

* save intermediate

* adapt

* correct bert flax design

* adapt roberta as well

* finish roberta flax

* finish

* apply suggestions

* apply suggestions

Co-authored-by: Chris Nguyen <anhtu2687@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants