Make Flax pt-flax equivalence test more aggressive #15841
patil-suraj merged 14 commits into huggingface:master
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Force-pushed 068a659 to eefded6
tests/test_modeling_flax_common.py
(nit) could we call these `fx_outputs` and `pt_outputs` instead of `fxo` and `pto`?
Sure, you and Sylvain told me the same thing :)
tests/test_modeling_flax_common.py
(nit) maybe also document `fxo` and `pto`
tests/test_modeling_flax_common.py
I'm not sure if 1e-5 will work for all models, especially on TPU/GPU, since JAX does some approximations on TPU, so the outputs can diverge. cf. #15754
What do you think @patrickvonplaten
I will check on a GPU VM - currently I am doing this for PT/TF.
Yes, I think a precision of 1e-3 would be better.
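To make the tolerance discussion concrete, here is a minimal sketch (not the test suite's actual code; `max_abs_diff` is a hypothetical helper) of how a PT/Flax output comparison behaves under the two tolerances being debated:

```python
# Illustrative sketch: comparing a PyTorch tensor against a Flax/JAX array
# under a given tolerance. Names here are assumptions, not the real test code.
import numpy as np

def max_abs_diff(pt_tensor, fx_array):
    """Return the largest element-wise absolute difference between two outputs."""
    pt = np.asarray(pt_tensor)  # for a torch tensor: pt_tensor.detach().numpy()
    fx = np.asarray(fx_array)
    return np.max(np.abs(pt - fx))

# On CPU, float32 ops usually agree to ~1e-5; on TPU, JAX may use lower-precision
# matmul approximations, so a looser 1e-3 tolerance can be needed.
a = np.float32([1.0, 2.0, 3.0])
b = a + np.float32(5e-4)                 # simulate a small numerical divergence
assert max_abs_diff(a, b) <= 1e-3        # passes with the looser tolerance
assert not (max_abs_diff(a, b) <= 1e-5)  # would fail with the strict one
```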
tests/test_modeling_flax_common.py
ConvNext is not available in flax (yet!)
(I know :), I am just copying from PT/TF. Do you want me to remove it for now?)
tests/test_modeling_flax_common.py
I don't think we need docstrings for test functions
tests/test_modeling_flax_common.py
This we can delete
Also very surprised that the tests are all passing. @ydshieh - can you double-check that the tests are actually passing for most models? Some edge-case models that could be tested locally:
@patil-suraj @patrickvonplaten There are a few things to fix in this PR (same for the PT/TF one) - some tests are just ignored by my mistake.
Force-pushed eefded6 to 4603f5b
[Updated Info.]
Sorry, but please ignore the above claim. The tests ran the Flax models on CPU, because the GPU version of JAX/Flax was not installed.
The documentation is not available anymore as the PR was closed or merged.
Force-pushed 37b3720 to 27909ca
Update: After rebasing on a more recent commit on master, I ran this new test on GPU (inside the Docker container that we use for CI GPU testing).

I think this PR is ready! @patil-suraj @patrickvonplaten
patil-suraj
left a comment
Thanks a lot for working on this!
Could you just rebase and push again? This will make the CI green :)
Force-pushed 27909ca to 53605f9
It's all green! I will approve my own PR too: LGTM!
* Make test_equivalence_pt_to_flax more aggressive
* Make test_equivalence_flax_to_pt more aggressive
* don't use to_tuple
* clean-up
* fix missing test cases + testing on GPU
* fix conversion
* fix `ValueError: assignment destination is read-only`
* Add type checking
* commit to revert later
* Fix
* fix
* fix device
* better naming
* clean-up

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
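One commit above mentions fixing `ValueError: assignment destination is read-only`. As a hedged illustration (my assumption about the cause, not stated in the thread): converting a JAX array to NumPy can yield a read-only buffer, and mutating it raises exactly this error; copying first gives a writable array. The sketch below simulates that with a NumPy flag:

```python
# Hypothetical illustration of the read-only error; simulating the read-only
# view that converting a JAX DeviceArray to NumPy can produce.
import numpy as np

x = np.zeros((2, 3))
x.flags.writeable = False  # simulate the read-only converted buffer
try:
    x[0, 0] = 1.0          # raises "ValueError: assignment destination is read-only"
except ValueError:
    x = np.array(x)        # np.array() copies, giving a writable buffer
    x[0, 0] = 1.0
assert x[0, 0] == 1.0
```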
What does this PR do?
Make Flax pt-flax equivalence test more aggressive. (Similar to #15839 for PT/TF).
It uses `output_hidden_states=True` and `output_attentions=True` to test all output tensors (in a recursive way).

Also, it lowers the tolerance from `4e-2` to `1e-5`. (From the experience I gained in the PT/TF test, if an error is > `1e-5`, I always found a bug to fix.)

(A bit) surprisingly, but very good news: unlike PT/TF, there is no PT/Flax inconsistency found by this more aggressive test! (@patil-suraj must have done a great job on the Flax models :-) )
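The "recursive" check described above can be sketched roughly like this (a hypothetical simplification, not the repository's actual implementation; the names `fx_outputs`/`pt_outputs` follow the review comments):

```python
# Sketch of a recursive PT/Flax output comparison: walk nested model outputs
# (tuples of tensors, possibly containing None) and compare every leaf.
import numpy as np

def check_outputs(fx_outputs, pt_outputs, tol=1e-5):
    """Recursively assert that Flax and PyTorch outputs match leaf by leaf."""
    if isinstance(fx_outputs, (tuple, list)):
        assert len(fx_outputs) == len(pt_outputs)
        for fx, pt in zip(fx_outputs, pt_outputs):
            check_outputs(fx, pt, tol=tol)
    elif fx_outputs is not None:
        diff = np.max(np.abs(np.asarray(fx_outputs) - np.asarray(pt_outputs)))
        assert diff <= tol, f"outputs differ by {diff} (> {tol})"
```

With `output_hidden_states=True` and `output_attentions=True`, the output tuples contain every hidden state and attention map, so a walk like this covers far more tensors than comparing only the final logits.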
Flax: @patil-suraj @patrickvonplaten