Fix some Flax models' hidden_states #16167
patil-suraj left a comment
Great catch! Thanks a lot for fixing this.
Just left a nit :)
    if not output_hidden_states:
        return (last_hidden_states,) + outputs[1:]
    else:
        return (last_hidden_states, hidden_states) + outputs[2:]
(nit) Not a big fan of nested ifs; maybe simplify this a bit.
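One way to flatten the branch the nit refers to, as a minimal sketch only. The helper name `to_tuple` and its argument list are hypothetical and not taken from the PR; the sketch assumes `outputs[0]` is the pre-layer-norm last state and that `outputs[1]` holds the hidden states only when `output_hidden_states` is set.

```python
def to_tuple(last_hidden_states, hidden_states, outputs, output_hidden_states):
    """Build the tuple output without an if/else around the return.

    Hypothetical helper: it mirrors the two branches in the snippet above.
    """
    # Number of leading entries of `outputs` that get replaced:
    # just the last hidden state, or the last hidden state plus hidden_states.
    n_replaced = 2 if output_hidden_states else 1
    head = (last_hidden_states, hidden_states)[:n_replaced]
    return head + tuple(outputs[n_replaced:])


# Toy values stand in for arrays.
without_hs = to_tuple("normed", None, ("raw", "attn"), False)
with_hs = to_tuple("normed", ("h0", "normed"), ("raw", ("h0", "raw"), "attn"), True)
```

Slicing the two-element head keeps the behaviour of both branches while leaving a single return statement.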
Hi @patil-suraj, is this better?
transformers/src/transformers/models/blenderbot/modeling_flax_blenderbot.py
Lines 720 to 722 in 415eb3c
(When a module returns a tuple and needs extra processing, as in this PR, I always have trouble making it cleaner. Things get much easier if the internal components return a dict or named tuple and only the top-level components change the format for users, but I don't think we are going to do that, at least not soon.)
If this looks good, I will change other places.
This looks good!
> Things get much easier if the internal components return dict or named tuple, and only change the format at the top level components for the users
This is a good idea. I think we can change the internal modules to return only a dict or a tuple.
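A rough sketch of what that idea could look like, under stated assumptions: `InternalOutput` and `finalize` are hypothetical names invented for illustration, not existing transformers classes or functions, and plain lists stand in for arrays. Internal components pass around a named tuple, and only the top-level step converts to the user-facing format.

```python
from typing import NamedTuple, Optional, Tuple


class InternalOutput(NamedTuple):
    """Hypothetical container used only between internal modules."""
    last_hidden_state: list
    hidden_states: Optional[Tuple[list, ...]] = None
    attentions: Optional[Tuple[list, ...]] = None


def finalize(internal, normed_last, return_dict=True):
    """Hypothetical top-level step: apply the fix by field name, then pick the format."""
    hidden_states = internal.hidden_states
    if hidden_states is not None:
        # Swap the pre-norm last entry for the post-norm one.
        hidden_states = hidden_states[:-1] + (normed_last,)
    result = internal._replace(last_hidden_state=normed_last, hidden_states=hidden_states)
    if return_dict:
        return result._asdict()
    # Tuple output for users: drop fields that were not requested.
    return tuple(v for v in result if v is not None)


as_dict = finalize(InternalOutput([1.0], ([0.5], [1.0])), [9.0])
as_tuple = finalize(InternalOutput([1.0]), [9.0], return_dict=False)
```

Because fields are addressed by name, the extra processing no longer depends on index arithmetic such as `outputs[1:]` versus `outputs[2:]`.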
Thanks for the feedback :-) @patil-suraj. I will apply the same change to other places to finish this PR.
About changing internal components (in general), let's have a discussion later with other members.
What does this PR do?
Fix some Flax models where the last element of hidden_states differs between the PyTorch and Flax versions.

More context

Some models have a final layer norm that produces last_hidden_state.
In the PyTorch version, the returned hidden_states have this last_hidden_state (after layer norm) as the last element.
In the Flax version, the last element of the returned hidden_states is the one before the layer norm.
This PR fixes this inconsistency (by using the PyTorch logic).
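The inconsistency can be illustrated with a small, self-contained sketch. It is not the PR's code: `layer_norm` here is a hypothetical stand-in without learned scale or bias, `align_hidden_states` is an invented helper name, and plain Python lists stand in for arrays.

```python
import math


def layer_norm(vec, eps=1e-5):
    """Hypothetical stand-in for a model's final layer norm (no learned parameters)."""
    mean = sum(vec) / len(vec)
    var = sum((x - mean) ** 2 for x in vec) / len(vec)
    return [(x - mean) / math.sqrt(var + eps) for x in vec]


def align_hidden_states(hidden_states, last_hidden_state):
    """Replace the last (pre-norm) entry with the post-norm last_hidden_state."""
    if hidden_states is None:
        return None
    return hidden_states[:-1] + (last_hidden_state,)


# Toy per-layer outputs; the last entry is the state before the final layer norm.
pre_norm_states = ([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
last_hidden_state = layer_norm(pre_norm_states[-1])
hidden_states = align_hidden_states(pre_norm_states, last_hidden_state)
```

After this step the last element of `hidden_states` equals `last_hidden_state`, which is the behaviour the description attributes to the PyTorch version.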