[Flax] Adapt Flax models to new structure#9484
sgugger
left a comment
Thanks a lot for cleaning this up! I like the new style!
I left a lot of nits, mostly around naming and style.
LysandreJik
left a comment
I like that it's defined through setup and through __call__ instead of just through __call__ with nn.compact! It makes it clearer, imo.
Great job, I think it's much more readable now than it was before!
Will wait until #10775 is merged, then rebase and merge.
…into save_intermediate_flax_pr_just_in_case
I like the new structure, but it seems this PR broke the Flax example: https://github.com/huggingface/transformers/blob/master/examples/language-modeling/run_mlm_flax.py
I think we should add more test cases and make sure the examples are runnable.
I am very interested in the JAX/Flax integration. Could you also take a look at my PR? #10796
* Create modeling_flax_eletra with code copied from modeling_flax_bert
* Add ElectraForMaskedLM and ElectraForPretraining
* Add modeling test for Flax electra and fix naming and arg in Flax Electra model
* Add documentation
* Fix code style
* Create modeling_flax_eletra with code copied from modeling_flax_bert
* Add ElectraForMaskedLM and ElectraForPretraining
* Add modeling test for Flax electra and fix naming and arg in Flax Electra model
* Add documentation
* Fix code style
* Fix code quality
* Adjust tol in assert_almost_equal due to very small difference between model output, ranging 0.0010 - 0.0016
* Remove redundant ElectraPooler
* save intermediate
* adapt
* correct bert flax design
* adapt roberta as well
* finish roberta flax
* finish
* apply suggestions
* apply suggestions
Co-authored-by: Chris Nguyen <anhtu2687@gmail.com>
What does this PR do?
As discussed in #9172, Flax models should get a design that is as similar to PyTorch as possible and thus should use def setup(...) instead of nn.compact(...). This PR refactors the model architecture of Bert & Roberta accordingly.
The next step is now to add a general flax<>pytorch conversion method, which might require some follow-up changes to the naming of the weights.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.