Fix missing output_attentions in PT/Flax equivalence test#16271

Merged
ydshieh merged 10 commits into huggingface:main from ydshieh:fix_pt_flax_equivalence_tests
Mar 29, 2022

Conversation

@ydshieh
Collaborator

@ydshieh ydshieh commented Mar 19, 2022

What does this PR do?

In a previous PR #15841, output_attentions was not set (I accidentally removed the whole block containing it).
This PR sets output_attentions to make the test more thorough.

The test still passes with tolerance 1e-5 on both CPU and GPU. However, see the 2nd point in the remarks below.

It also adds a has_attentions attribute to FlaxModelTesterMixin (as done in PyTorch's ModelTesterMixin).

Remarks:

  • In a follow-up PR, we might use has_attentions in some existing methods (to make sure attentions are only tested if has_attentions is True), see [Tests] Add attentions_option to ModelTesterMixin #15909
  • There are 4 Flax model testers that overwrite the common Flax tests test_equivalence_pt_to_flax and test_equivalence_flax_to_pt.
    • I will update them in a follow-up PR.
    • These include FlaxGPTJ and FlaxXGLM, which fail at tolerance 1e-5. I need to debug them.
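The core of the fix can be sketched roughly as follows. This is an illustrative mock, not the actual transformers code: DummyConfig and prepare_config are made-up stand-ins, and only the has_attentions flag and the restored config.output_attentions assignment reflect the PR.

```python
class DummyConfig:
    """Stand-in for a model config; only output_attentions matters here."""
    output_attentions = False


class FlaxModelTesterMixin:
    # Mirrors the has_attentions flag from PyTorch's ModelTesterMixin;
    # testers for attention-free architectures can override this with False.
    has_attentions = True

    def prepare_config(self):
        config = DummyConfig()
        # The line this PR restores: without it, attention outputs were
        # never compared between the PT and Flax models.
        config.output_attentions = self.has_attentions
        return config


config = FlaxModelTesterMixin().prepare_config()
print(config.output_attentions)  # True, so the equivalence test also checks attentions
```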

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@ydshieh ydshieh force-pushed the fix_pt_flax_equivalence_tests branch from 3d07016 to cb6459b Compare March 20, 2022 11:33
@ydshieh ydshieh marked this pull request as ready for review March 21, 2022 07:47
@ydshieh ydshieh changed the title [WIP] Fix missing output_attentions Fix missing output_attentions in PT/Flax equivalence test Mar 21, 2022
Contributor

@patil-suraj patil-suraj left a comment

LGTM, thanks a lot!

Contributor

This is a bit too hacky for me here. Can't we just overwrite the test in test_modeling_flax_big_bird.py?

Contributor

Don't like this too much here either. Can't we check if there is an output_attentions argument in the signature of the forward function and, if that's the case, set config.output_attentions=True? This way we have one dependency fewer.
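The alternative suggested here could look roughly like this (the forward functions and the helper name are made up for illustration; only the inspect.signature check reflects the suggestion):

```python
import inspect


def forward(input_ids, attention_mask=None, output_attentions=False):
    """Stand-in for a model's __call__; only its signature matters here."""
    return input_ids


def legacy_forward(input_ids):
    """Stand-in for a model whose forward does not expose attentions."""
    return input_ids


def accepts_output_attentions(fn):
    # Enable attention comparison only when the forward signature accepts it,
    # instead of relying on a separate has_attentions attribute.
    return "output_attentions" in inspect.signature(fn).parameters


print(accepts_output_attentions(forward))        # True
print(accepts_output_attentions(legacy_forward)) # False
```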

Collaborator Author

This has_attentions attribute was introduced in ModelTesterMixin (#15909) (and then in TFModelTesterMixin by me #16259).

Think it would be good to have the same approach for testing across the 3 frameworks. Let me know if you still prefer the other approach(es).

cc @NielsRogge @sgugger for further comments if any.

Collaborator

Yes, let's use existing attributes and make the three testers consistent with each other.

@ydshieh ydshieh force-pushed the fix_pt_flax_equivalence_tests branch from ede2bea to 151860f Compare March 25, 2022 16:19

# send pytorch model to the correct device
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
Collaborator Author

don't forget to set the re-loaded PT model to eval mode
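The reason this matters: a freshly loaded PyTorch model is in train mode, where layers like dropout are stochastic, so forward passes would not match the Flax outputs exactly. A toy illustration with a made-up ToyDropout class (not torch code, but the train/eval behavior it mimics is the same):

```python
import random


class ToyDropout:
    """Minimal stand-in for a dropout layer to show train vs eval behavior."""

    def __init__(self, p=0.5):
        self.p = p
        self.training = True  # models start in train mode after construction/loading

    def eval(self):
        self.training = False
        return self

    def __call__(self, x):
        if not self.training:
            # eval mode: deterministic identity, safe for equivalence tests
            return x
        # train mode: randomly zero elements and rescale the rest
        return [0.0 if random.random() < self.p else v / (1 - self.p) for v in x]


layer = ToyDropout().eval()
print(layer([1.0, 2.0, 3.0]))  # deterministic in eval mode: [1.0, 2.0, 3.0]
```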

@ydshieh
Collaborator Author

ydshieh commented Mar 25, 2022

Think this (quite small) PR is ready. Nothing in particular beyond adding the missing config.output_attentions = self.has_attentions.
The super() thing was discussed in #16280.

Will merge it today.

dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

# (Copied from tests.test_modeling_common.ModelTesterMixin.check_outputs)
Collaborator Author

@ydshieh ydshieh Mar 25, 2022

The intention is only to add this information; it is not meant to work with the current version of make fix-copies.

@sgugger Are you OK with this comment? Otherwise I can just remove it.

@ydshieh ydshieh merged commit aebca69 into huggingface:main Mar 29, 2022
@ydshieh ydshieh deleted the fix_pt_flax_equivalence_tests branch March 29, 2022 15:51
