[Wav2Vec2] Fix normalization for non-padded tensors #13512

Merged
patrickvonplaten merged 7 commits into huggingface:master from patrickvonplaten:fix_normalization_non_padded
Sep 10, 2021

Conversation

@patrickvonplaten
Contributor

@patrickvonplaten patrickvonplaten commented Sep 10, 2021

What does this PR do?

This PR fixes a normalization bug that occurs when the input is a list of sequences of different lengths that have not been converted to NumPy arrays (see #13504).

I just noticed that this bug is actually pretty severe, as it affects all large Wav2Vec2 fine-tuning :-/.
It was introduced by me in this PR: https://github.com/huggingface/transformers/pull/12804/files - I should have written more and better tests for this.

=> This means that from transformers 4.9.0 until this PR is merged, the normalization for all large Wav2Vec2 models was way off when fine-tuning the model.

@LysandreJik - do you think it might be possible to do a patch release for this?

Comment thread src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py Outdated

@staticmethod
def zero_mean_unit_var_norm(input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]:
def zero_mean_unit_var_norm(
Contributor Author

Retrieving the correct length from the attention mask should be this method's responsibility, since `input_values` and `attention_mask` are the well-known inputs to functions in transformers.
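
The idea in this comment can be sketched roughly as follows. This is a minimal illustration and not the merged code: it assumes right-padded NumPy inputs of equal length, and the function name `normalize_with_mask`, the `padding_value` parameter, and the `1e-7` epsilon are all hypothetical choices made for the example.

```python
from typing import List, Optional

import numpy as np


def normalize_with_mask(
    input_values: List[np.ndarray],
    attention_mask: Optional[List[np.ndarray]] = None,
    padding_value: float = 0.0,
) -> List[np.ndarray]:
    """Zero-mean, unit-variance normalization that ignores padded positions.

    Hypothetical helper: assumes right padding, so the number of valid
    samples in each sequence equals the sum of its attention mask.
    """
    if attention_mask is None:
        # No mask given: every sample contributes to the statistics.
        return [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]

    # The valid lengths are derived here, inside the method, from the mask.
    lengths = np.asarray(attention_mask, dtype=np.int32).sum(-1)

    normed = []
    for vector, length in zip(input_values, lengths):
        vector = np.asarray(vector, dtype=np.float32)
        valid = vector[:length]
        # Mean and variance come from the unpadded part only.
        out = (vector - valid.mean()) / np.sqrt(valid.var() + 1e-7)
        # Reset padded positions so they do not keep shifted values.
        out[length:] = padding_value
        normed.append(out)
    return normed
```

With this shape, callers pass only `input_values` and `attention_mask` and never have to compute the lengths themselves.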

return_attention_mask=return_attention_mask,
)

if "attention_mask" in padded_inputs:
Contributor Author

This part is removed/cleaned-up

and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
)

# make sure input is in list format
Contributor Author

Currently all the padding happens in pure Python rather than in NumPy, so let's move the conversion to NumPy arrays further down.

_check_zero_mean_unit_variance(input_values[2])

def test_zero_mean_unit_variance_normalization_trunc(self):
def test_zero_mean_unit_variance_normalization(self):
Contributor Author

Add test to make sure normalization always works as expected
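
A check of this kind might look like the sketch below. The helper echoes the name `_check_zero_mean_unit_variance` visible in the diff context, but its body is not shown on this page, so the tolerance, the random test data, and the inline normalization are assumptions made for illustration only.

```python
import numpy as np


def check_zero_mean_unit_variance(values: np.ndarray, tol: float = 1e-3) -> None:
    """Assert that `values` has roughly zero mean and unit variance (assumed tolerance)."""
    assert abs(values.mean()) < tol
    assert abs(values.var() - 1.0) < tol


# Hypothetical inputs: three sequences of different lengths, deliberately
# given a non-zero mean and a non-unit variance.
rng = np.random.default_rng(0)
raw_inputs = [rng.normal(loc=3.0, scale=5.0, size=n) for n in (800, 1000, 1200)]

# Normalize each sequence over its own samples only.
normed_inputs = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in raw_inputs]

for normed in normed_inputs:
    check_zero_mean_unit_variance(normed)
```

Using sequences of different lengths matters here, since that is the case the PR description says was broken.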

Contributor

@patil-suraj patil-suraj left a comment

Great catch!

This looks good to me.

Comment thread src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py Outdated
Comment thread tests/test_feature_extraction_wav2vec2.py Outdated
Comment thread src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py Outdated
Member

@LysandreJik LysandreJik left a comment

Seems to look good but will delegate to @patil-suraj and @anton-l's w2v2 knowledge.

Let me know once this is merged so that I may release a patch.

Comment thread src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py
Member

@anton-l anton-l left a comment

LGTM other than the small issues already pointed out, thanks for fixing it!

Comment thread src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py
Comment thread tests/test_feature_extraction_wav2vec2.py
Comment thread tests/test_feature_extraction_speech_to_text.py
Member

@anton-l anton-l left a comment

All slow tests now pass for Wav2Vec and Hubert, nice!

Contributor

@patil-suraj patil-suraj left a comment

LGTM! Thanks for adding all those tests :)

@patrickvonplaten patrickvonplaten merged commit d7b3b70 into huggingface:master Sep 10, 2021
@patrickvonplaten patrickvonplaten deleted the fix_normalization_non_padded branch September 10, 2021 13:27
patrickvonplaten added a commit that referenced this pull request Sep 10, 2021
* finalize

* Apply suggestions from code review

* finish cleaner implementation

* more tests

* small fix

* finish

* up

Development

Successfully merging this pull request may close these issues.

Wav2vec2Processor normalization issues on transformers 4.10.0

4 participants