Fix missing output_attentions in PT/Flax equivalence test #16271
ydshieh merged 10 commits into huggingface:main
Conversation
tests/test_modeling_flax_common.py
This is a bit too hacky for me here. Can't we just overwrite the test in test_modeling_flax_big_bird.py?
tests/test_modeling_flax_common.py
Don't like this too much here either. Can't we check if there is an `output_attentions` argument in the signature of the forward function and, if that's the case, set `config.output_attentions=True`? This way we have one dependency less.
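The signature-inspection idea suggested above can be sketched roughly as follows. This is a hedged, minimal illustration: `call` and `DummyConfig` are hypothetical stand-ins, not the actual transformers code.

```python
import inspect

def call(self, input_ids, attention_mask=None, output_attentions=False):
    """Stand-in for a Flax module's forward function (hypothetical)."""
    return input_ids

# Only enable output_attentions on the config if the forward
# signature actually accepts that argument.
accepts_output_attentions = "output_attentions" in inspect.signature(call).parameters

class DummyConfig:
    output_attentions = False

config = DummyConfig()
if accepts_output_attentions:
    config.output_attentions = True

print(config.output_attentions)  # → True
```

This keeps the common test free of model-specific knowledge: the decision is driven purely by the signature rather than by a hard-coded list of model classes.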
This has_attentions attribute was introduced in ModelTesterMixin (#15909) (and then in TFModelTesterMixin by me #16259).
Think it would be good to have the same approach for testing across the 3 frameworks. Let me know if you still prefer the other approach(es).
cc @NielsRogge @sgugger for further comments if any.
Yes, let's use existing attributes and make the three testers consistent with each other.
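The `has_attentions` pattern the reviewers settled on can be sketched like this. This is a simplified, hypothetical illustration of the idea (a class-level flag that model-specific tester subclasses override), not the actual mixin code from the repository.

```python
# Sketch of the has_attentions flag shared across the PT/TF/Flax testers:
# the common mixin defaults to True, and a model-specific tester flips it
# off so every attention-related check is skipped uniformly.

class ModelTesterMixinSketch:
    has_attentions = True  # default: most models return attention maps

    def check_attentions(self, outputs):
        if not self.has_attentions:
            return "skipped"
        # ... real test would compare outputs.attentions here ...
        return "checked"

class NoAttentionModelTesterSketch(ModelTesterMixinSketch):
    # hypothetical model whose attention outputs cannot be tested
    has_attentions = False

print(ModelTesterMixinSketch().check_attentions(None))      # checked
print(NoAttentionModelTesterSketch().check_attentions(None))  # skipped
```

Compared with per-model test overrides, a single flag keeps the three framework testers structurally identical, which is the consistency argued for above.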
Co-authored-by: Suraj Patil <surajp815@gmail.com>
    # send pytorch model to the correct device
    pt_model_loaded.to(torch_device)
    pt_model_loaded.eval()
Don't forget to set the re-loaded PT model to eval mode.
Think this (quite small) PR is ready. Nothing particular but adding the missing `output_attentions`. Will merge it today.
    dict_inputs = self._prepare_for_class(inputs_dict, model_class)
    check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
    # (Copied from tests.test_modeling_common.ModelTesterMixin.check_outputs)
The intention is only to add this information; it is not meant to work with the current version of `make fix-copies`.
@sgugger Are you OK with this comment? Otherwise I can just remove it.
What does this PR do?

In a previous PR #15841, `output_attentions` was not set (I accidentally removed the whole block containing it). This PR sets `output_attentions` to make the test more thorough. The test still runs successfully with `1e-5` on both CPU/GPU. However, see the 2nd point in the remarks below. It also adds the `has_attentions` attribute to `FlaxModelTesterMixin` (as done in PyTorch's `ModelTesterMixin`).

Remarks:
- `has_attentions` is used in some existing methods (to make sure the attentions are only tested if `has_attentions` is `True`), see [Tests] Add attentions_option to ModelTesterMixin #15909.
- `test_equivalence_pt_to_flax` and `test_equivalence_flax_to_pt`: `FlaxGPTJ` and `FlaxXGLM` will fail with `1e-5`. I need to debug them.
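The equivalence idea the PR tightens up can be sketched as below: recursively compare nested tuples of PT and Flax outputs within a `1e-5` tolerance. `check_outputs` here is a simplified stand-in for the real test helper, and the arrays are fabricated to illustrate the shapes, not taken from any model.

```python
import numpy as np

def check_outputs(pt_out, flax_out, atol=1e-5):
    """Recursively compare nested tuples/lists of arrays within atol."""
    if isinstance(pt_out, (tuple, list)):
        assert len(pt_out) == len(flax_out)
        for p, f in zip(pt_out, flax_out):
            check_outputs(p, f, atol=atol)
    else:
        # max absolute difference must stay within the tolerance
        assert np.max(np.abs(np.asarray(pt_out) - np.asarray(flax_out))) <= atol

# With output_attentions=True, the outputs gain an extra nested tuple of
# attention maps (one per layer) that must also match between frameworks.
hidden = np.ones((1, 4, 8), dtype=np.float32)
attn = np.full((1, 2, 4, 4), 0.25, dtype=np.float32)
pt_outputs = (hidden, (attn, attn))
flax_outputs = (hidden + 5e-6, (attn, attn - 5e-6))  # tiny numerical drift

check_outputs(pt_outputs, flax_outputs)  # passes: differences stay below 1e-5
print("equivalence ok")
```

This is why the accidentally dropped `output_attentions` block mattered: without it, the attention tuples were simply absent from the outputs, so the recursion never exercised them and the test was silently weaker.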