[FlaxBert] Add ForCausalLM #16995
Conversation
The documentation is not available anymore as the PR was closed or merged.
Force-pushed from 9c9e49b to 62173c7
      self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
      for fx_output, pt_output in zip(fx_outputs, pt_outputs):
-         self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
+         self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
@ydshieh is 1e-5 now the default testing precision?
Found a bug in the FlaxBertModelTester and fixed it! Thresholds are now back to 1e-5 and passing (even with the randomly initialised decoder attention mask) :-)
Yes. So far, for anything higher than 1e-5 I was able to find some issues, either in the models or in the model tests.
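For context, here is a minimal sketch of the kind of Flax-to-PyTorch equivalence check discussed above. The helper names (`check_pt_fx_equivalence`, `assert_almost_equals`) and the way the two models are loaded are illustrative assumptions, not the actual test-suite code:

```python
import numpy as np
import torch

# Minimal sketch (not the actual test code): compare Flax and PyTorch outputs
# element-wise at a given tolerance, as in the pt-fx equivalence tests.
def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float):
    diff = np.abs(a - b).max()
    assert diff <= tol, f"Flax/PyTorch difference is {diff:.2e} (tolerance {tol:.0e})."

def check_pt_fx_equivalence(fx_model, pt_model, inputs, tol=1e-5):
    # `fx_model` / `pt_model` are assumed to be a Flax model and its PyTorch
    # counterpart loaded from the same weights; `inputs` is a dict of numpy arrays.
    fx_outputs = fx_model(**inputs).to_tuple()
    with torch.no_grad():
        pt_outputs = pt_model(**{k: torch.tensor(v) for k, v in inputs.items()}).to_tuple()

    assert len(fx_outputs) == len(pt_outputs), "Output lengths differ between Flax and PyTorch"
    for fx_output, pt_output in zip(fx_outputs, pt_outputs):
        assert_almost_equals(np.asarray(fx_output), pt_output.numpy(), tol)
```

With a genuinely equivalent pair of models, the 1e-5 tolerance should hold; a looser threshold such as 4e-2 is usually a sign of a bug in the model or the tester, as it turned out to be here.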
patrickvonplaten left a comment
Looks good to me - @sanchit-gandhi, could you check which models don't pass with 1e-5, and ideally why? Overall, 4e-2 is fine for me though. cc @ydshieh, what do you think?
patrickvonplaten left a comment
Cool, feel free to merge @sanchit-gandhi
* [FlaxBert] Add ForCausalLM
* make style
* fix output attentions
* Add RobertaForCausalLM
* remove comment
* fix fx-to-pt model loading
* remove comment
* add modeling tests
* add enc-dec model tests
* add big_bird
* add electra
* make style
* make repo-consitency
* add to docs
* remove roberta test
* quality
* amend cookiecutter
* fix attention_mask bug in flax bert model tester
* tighten pt-fx thresholds to 1e-5
* add 'copied from' statements
* amend 'copied from' statements
* amend 'copied from' statements
* quality
What does this PR do?
Adds cross-attention blocks to the following module classes:
Adds the following ForCausalLM model classes: FlaxBertForCausalLM, FlaxRobertaForCausalLM, FlaxBigBirdForCausalLM, and FlaxElectraForCausalLM.
Adds the following model tests:
Note: FlaxBertForCausalLM is excluded from these tests due to the name mismatch with its PyTorch equivalent, BertLMHeadModel. It is implicitly tested through the FlaxRobertaForCausalLM model tests, as well as in the following encoder-decoder model tests:
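For illustration, here is a minimal usage sketch of the kind of model this PR adds. The checkpoint names are placeholders, the snippet assumes the standard `from_pretrained` API, and the LM head weights would be newly initialised for a plain BERT checkpoint:

```python
from transformers import AutoTokenizer, FlaxBertForCausalLM, FlaxEncoderDecoderModel

# Sketch: the new Flax causal-LM head on a BERT checkpoint ("bert-base-uncased"
# is only an example; the LM head weights are randomly initialised here).
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = FlaxBertForCausalLM.from_pretrained("bert-base-uncased")

inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
outputs = model(**inputs)
next_token_logits = outputs.logits[:, -1, :]  # logits for the next token

# The cross-attention blocks added in this PR also let the model act as the
# decoder of a Flax encoder-decoder model (sketch):
enc_dec = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased"
)
```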