[Wav2Vec2] Fix normalization for non-padded tensors #13512
patrickvonplaten merged 7 commits into huggingface:master
Conversation
    @staticmethod
    def zero_mean_unit_var_norm(input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]:
    def zero_mean_unit_var_norm(
The responsibility of retrieving the correct length from the attention mask should lie in this method, since `input_values` and `attention_mask` are the well-known inputs to functions in transformers.
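A minimal sketch of what such a method could look like once it derives the per-example lengths from the attention mask itself (this is an illustration, not the merged implementation; the `padding_value` argument and the variance epsilon are assumptions):

```python
from typing import List, Optional

import numpy as np


def zero_mean_unit_var_norm(
    input_values: List[np.ndarray],
    attention_mask: Optional[np.ndarray] = None,
    padding_value: float = 0.0,
) -> List[np.ndarray]:
    """Normalize each array to zero mean and unit variance, using the
    attention mask (if given) to compute statistics over real samples only."""
    if attention_mask is not None:
        normed = []
        # length of each example = number of non-padded positions in its mask row
        for vector, length in zip(input_values, attention_mask.sum(-1)):
            normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
            if length < normed_slice.shape[0]:
                # overwrite the padded tail so padding stays at padding_value
                normed_slice[length:] = padding_value
            normed.append(normed_slice)
        return normed
    # no mask: each array is fully real, normalize it as a whole
    return [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
```

With this signature the caller no longer has to precompute `input_lengths`; it just forwards the batch's attention mask.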
    return_attention_mask=return_attention_mask,
    )

    if "attention_mask" in padded_inputs:
This part is removed/cleaned up.
    and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
    )

    # make sure input is in list format
Currently all the padding happens in pure Python, not in NumPy, so let's move the numpification further down.
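To illustrate the idea (a hypothetical helper, not the library code): padding can operate on plain Python lists, and the conversion to NumPy arrays is deferred to the very end:

```python
import numpy as np


def pad_and_numpify(sequences, padding_value=0.0):
    """Pad variable-length sequences in pure Python, then numpify last."""
    max_len = max(len(seq) for seq in sequences)
    # padding step: plain list operations, no numpy involved yet
    padded = [list(seq) + [padding_value] * (max_len - len(seq)) for seq in sequences]
    # numpification happens only here, after all lengths are equal
    return np.array(padded, dtype=np.float32)
```

Keeping the inputs as lists until this point means per-example operations (like masked normalization) can run on the original, un-padded values.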
    _check_zero_mean_unit_variance(input_values[2])

    def test_zero_mean_unit_variance_normalization_trunc(self):
    def test_zero_mean_unit_variance_normalization(self):
Added a test to make sure normalization always works as expected.
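A sketch of what such a test can check (the per-array normalization is inlined here for self-containedness; the tolerances are assumptions, not the values used in the repository's test suite):

```python
import numpy as np


def test_zero_mean_unit_variance_normalization():
    # inputs of different lengths, kept as a plain (un-numpified) list
    speech = [np.random.randn(n) * 3.0 + 2.0 for n in (800, 1000, 1200)]
    # per-array zero-mean unit-variance normalization
    normed = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in speech]
    for arr in normed:
        # each example should individually be zero-mean and unit-variance
        assert abs(arr.mean()) < 1e-3
        assert abs(arr.var() - 1.0) < 1e-3
```

The key property is that every example is normalized against its own statistics, regardless of whether the batch was padded first.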
patil-suraj left a comment
Great catch!
This looks good to me.
LysandreJik left a comment
Seems to look good, but I'll delegate to @patil-suraj's and @anton-l's w2v2 knowledge.
Let me know once this is merged so that I may release a patch.
anton-l left a comment
LGTM other than the small issues already pointed out, thanks for fixing it!
anton-l left a comment
All slow tests now pass for Wav2Vec and Hubert, nice!
patil-suraj left a comment
LGTM! Thanks for adding all those tests :)
* finalize
* Apply suggestions from code review
* finish cleaner implementation
* more tests
* small fix
* finish
* up
What does this PR do?
This PR fixes a problem with normalization when the input is a list of arrays of different lengths that has not been numpified - see: #13504
Just noticed that this bug is pretty severe actually as it affects all large-Wav2Vec2 fine-tuning :-/.
It was introduced by me in this PR: https://github.com/huggingface/transformers/pull/12804/files - I should have written more and better tests for this.
=> This means that from transformers 4.9.0 until this PR is merged, the normalization for all large Wav2Vec2 models was way off when fine-tuning the model.
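To illustrate why this matters (a constructed example, not the actual failing code path): if an utterance's normalization statistics include padded values instead of only the real samples, the normalized speech is no longer zero-mean and unit-variance:

```python
import numpy as np

rng = np.random.default_rng(0)
speech = rng.standard_normal(100) + 4.0           # one short utterance
padded = np.concatenate([speech, np.zeros(900)])  # padded out to batch length

# wrong: statistics computed over the 900 padding zeros as well
wrong = (padded - padded.mean()) / np.sqrt(padded.var() + 1e-7)
# right: statistics computed over the real samples only
right = (speech - speech.mean()) / np.sqrt(speech.var() + 1e-7)

print(abs(wrong[:100].mean()))  # far from 0: the utterance is badly shifted
print(abs(right.mean()))        # ~0
```

Since do_normalize is the default preprocessing for the large checkpoints, feeding such mis-normalized features silently degrades fine-tuning rather than raising an error.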
@LysandreJik - do you think it might be possible to do a patched release for this?