Skip to content

Fix some Flax models' hidden_states#16167

Merged
ydshieh merged 9 commits intohuggingface:masterfrom
ydshieh:fix_flax_blenderbot_hidden_outputs
Mar 15, 2022
Merged

Fix some Flax models' hidden_states#16167
ydshieh merged 9 commits intohuggingface:masterfrom
ydshieh:fix_flax_blenderbot_hidden_outputs

Conversation

@ydshieh
Copy link
Collaborator

@ydshieh ydshieh commented Mar 15, 2022

What does this PR do?

Fix some Flax models where the last element in hidden_states is different between PT/Flax version.

More context

Some models have

last_hidden_state = self.layer_norm(last_hidden_state)

In Pytorch version, the returned hidden_states have this last_hidden_state (after layer norm) as the last element.
In Flax version, the last element of the returned hidden_states is the one before the layer norm.

This PR fixes this inconsistency (by using the PyTorch logic).

@ydshieh ydshieh changed the title Fix flax blenderbot hidden outputs Fix some Flax models' hidden_states Mar 15, 2022
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Mar 15, 2022

The documentation is not available anymore as the PR was closed or merged.

Copy link
Contributor

@patil-suraj patil-suraj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch! Thanks a lot for fixing this.

Just left a nit :)

Comment on lines +721 to +724
if not output_hidden_states:
return (last_hidden_states,) + outputs[1:]
else:
return (last_hidden_states, hidden_states) + outputs[2:]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) Not a big fan of nested ifs, maybe simplify this a bit

Comment on lines +801 to +804
if not output_hidden_states:
return (last_hidden_states,) + outputs[1:]
else:
return (last_hidden_states, hidden_states) + outputs[2:]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment as above

Copy link
Collaborator Author

@ydshieh ydshieh Mar 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @patil-suraj , is this better?

if not return_dict:
outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
return tuple(v for v in outputs if v is not None)

(when it returns tuple + when it needs extra processing as in this PR, I always have trouble to make it cleaner. Things get much easier if the internal components return dict or named tuple, and only change the format at the top level components for the users - but I don't think we are going to do so, at least not soon)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this looks good, I will change other places.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good!

Things get much easier if the internal components return dict or named tuple, and only change the format at the top level components for the users

This is a good idea, I think we can change the internal modules to only return either dict or Tuple

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback :-) @patil-suraj . I will apply the same change to other places to finish this PR.
About changing internal components (in general), let's have a discussion later with other members.

Comment on lines +778 to +781
if not output_hidden_states:
return (last_hidden_states,) + outputs[1:]
else:
return (last_hidden_states, hidden_states) + outputs[2:]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment as above

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@ydshieh ydshieh merged commit ea05d67 into huggingface:master Mar 15, 2022
@ydshieh ydshieh deleted the fix_flax_blenderbot_hidden_outputs branch March 15, 2022 18:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants