
Added implementation for the LEAF audio frontend #1364

Merged
mravanelli merged 15 commits into speechbrain:develop from SarthakYadav:leaf on Jun 24, 2022

Conversation

@SarthakYadav
Contributor

This PR adds an implementation of the LEAF [1] audio frontend. Summary of changes:

  1. Added a GaborConv1d [1] layer to speechbrain.nnet.CNN
  2. Added a learnable Gaussian lowpass pooling layer [1] to speechbrain.nnet.pooling
  3. Added a per-channel energy normalization (PCEN) layer [1,2] to speechbrain.nnet.normalization, along with its dependency ExponentialMovingAverage in speechbrain.nnet.ema
  4. Added the full LEAF frontend [1] to speechbrain.nnet.CNN

References

[1] Neil Zeghidour, Olivier Teboul, Félix de Chaumont Quitry & Marco Tagliasacchi, "LEAF: A Learnable Frontend for Audio Classification", in Proc. of ICLR 2021.
[2] Yuxuan Wang, Pascal Getreuer, Thad Hughes, Richard F. Lyon, Rif A. Saurous, "Trainable Frontend for Robust and Far-Field Keyword Spotting", in Proc. of ICASSP 2017.
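For readers unfamiliar with item 3: PCEN, following Wang et al. [2], divides each filterbank energy by an exponential-moving-average smoother raised to a power, then applies a root compression. A minimal pure-Python sketch of that computation over one channel (names and default parameter values here are illustrative, not the actual SpeechBrain API):

```python
# Illustrative scalar PCEN over one frequency channel (not the SpeechBrain API).
# PCEN(t) = (E(t) / (eps + M(t))**alpha + delta)**r - delta**r,
# where M is an exponential moving average of the energies E.

def pcen(energies, s=0.04, alpha=0.96, delta=2.0, r=0.5, eps=1e-6):
    out = []
    m = energies[0]  # initialize the smoother with the first frame
    for e in energies:
        m = (1.0 - s) * m + s * e          # exponential moving average
        out.append((e / (eps + m) ** alpha + delta) ** r - delta ** r)
    return out

frames = [1.0, 4.0, 9.0, 1.0]
print(pcen(frames))
```

In the PR (as in the papers) the parameters s, alpha, delta, and r are learnable per channel rather than fixed constants.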

@SarthakYadav
Contributor Author

@TParcollet Take a look.

@TParcollet
Collaborator

Thank you so much!

@SarthakYadav
Contributor Author

SarthakYadav commented Apr 8, 2022

@TParcollet I just added a fix for the pre-commit failure.

@anautsch
Collaborator

anautsch commented Apr 13, 2022

@SarthakYadav black points to some (minor) formatting edits, please take a look.

@SarthakYadav
Contributor Author

@anautsch Done. I had to take upstream changes from speechbrain:develop to get black to work correctly. It should pass now.

@anautsch
Collaborator

Hi @SarthakYadav sorry for the delay on my end.

I needed to start the checks manually, and there are errors with the doctests:

FAILED speechbrain/nnet/CNN.py::speechbrain.nnet.CNN.GaborConv1d
FAILED speechbrain/nnet/CNN.py::speechbrain.nnet.CNN.Leaf

You can check all doctests with `pytest --doctest-modules speechbrain`, or this particular one with:
`pytest --doctest-modules speechbrain/nnet/CNN.py`

The CI testing runs these scripts, which you can also try on your machine:

./tests/.run-linters.sh
./tests/.run-doctests.sh
./tests/.run-unittests.sh

Fingers crossed it's only a little bug!

@SarthakYadav
Contributor Author

Hi @anautsch

> Sorry for the delay on my end.

No worries!
I've made the relevant changes in the latest commit and also added a unit test for Leaf. It should work fine now! Please take a look.

@SarthakYadav
Contributor Author

Hi @anautsch. I just resolved a tiny merge conflict. Could you approve the workflows again?

@anautsch
Collaborator

Hi @SarthakYadav, yes, no worries - your PR LGTM. @TParcollet suggested that we run the code on our side once more and then merge.

@SarthakYadav
Contributor Author

Great, sounds good!

@mravanelli
Collaborator

@anautsch, any news on that?

mravanelli added the "ready to review" label (Waiting on reviewer to provide feedback) on May 18, 2022
mravanelli requested a review from anautsch on May 18, 2022 19:24
@TParcollet
Collaborator

@anautsch and I reviewed it. Now, I must find the time to test it ...

@TParcollet
Collaborator

I will certainly have to review it again ...

| System | Result |
|----------------- | ------------ |
| xvector + augment v12 | 98.14% |
| xvector + augment v35 | 97.43% |
| xvector + augment + LEAF v35 | 96.79% |
Collaborator

Results are worse?

Contributor Author
@SarthakYadav May 26, 2022

Yes, that's what I got in the first and only experiment. LEAF was originally evaluated on EfficientNetB0 and CNN14 architectures, so I have no known xvector baselines to go by.

return denominator * sinusoid * gaussian


def gabor_impulse_response_legacy_complex(t, center, fwhm):
Collaborator

Legacy?

Contributor Author
@SarthakYadav May 26, 2022

Yes. LEAF internally has some complex-dtype operations, and I used to face problems with them in prior versions of torch (as well as in torch-xla, which to the best of my knowledge still doesn't support grad on those ops on TPUs). `_legacy_complex` basically performs these operations on two float tensors instead of a complex-dtype tensor.

This is also explained in the docs for LEAF/GaborConv1d.

They can be removed if you like. But people who need to use a previous torch version (say <=1.9) for various reasons might find this extremely helpful, and it's controlled here by a simple boolean flag. Your call!
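For readers curious what the `_legacy_complex` workaround amounts to: instead of relying on a complex dtype, the real and imaginary parts are carried as two separate float tensors and combined by hand with the identity (a+bi)(c+di) = (ac−bd) + (ad+bc)i. A hedged pure-Python sketch of the idea (the actual PR code operates on torch tensors, not scalars):

```python
# Sketch of the "legacy complex" trick: represent each complex value as a
# (real, imag) pair of floats and implement multiplication manually, so no
# complex dtype (and hence no complex autograd support) is ever needed.

def complex_mul(a, b):
    """(a_re + i*a_im) * (b_re + i*b_im), both given as (real, imag) tuples."""
    a_re, a_im = a
    b_re, b_im = b
    return (a_re * b_re - a_im * b_im,   # real part: ac - bd
            a_re * b_im + a_im * b_re)   # imag part: ad + bc

# Agrees with Python's native complex arithmetic:
print(complex_mul((1.0, 2.0), (3.0, 4.0)))   # -> (-5.0, 10.0), same as (1+2j)*(3+4j)
```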

Contributor Author

Any updates on this?

Collaborator

All good.

anautsch changed the base branch from develop to develop-v2 on June 1, 2022 15:39
return in_channels


class Leaf(nn.Module):
Collaborator

Hi @SarthakYadav, shouldn't this be a lobe instead? I see it as a "complex" composition rather than a building block.

Contributor Author
@SarthakYadav Jun 22, 2022

Sure, makes sense. I simply followed SincNet (which was a Module). I'll make it a lobe and move it to speechbrain.lobes.features.

return int(padding)


def gabor_impulse_response(t, center, fwhm):
Collaborator

Wondering if these functions shouldn't go somewhere else, as they are not related to NN stuff but more to DSP? What about speechbrain.processing.features or speechbrain.processing.signal_processing?

Contributor Author

Sure. I'll move them to speechbrain.processing.signal_processing.
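For context on what `gabor_impulse_response` computes: a Gabor filter is a complex sinusoid at a center frequency windowed by a Gaussian envelope whose width is set by the full width at half maximum (fwhm). A minimal pure-Python sketch under that reading (the parameter conventions here are illustrative and may differ from the PR's actual implementation, which is vectorized over torch tensors):

```python
import math

def gabor_impulse_response(t, center, fwhm):
    """Gaussian-windowed complex sinusoid, returned as (real, imag).

    t: time index, center: angular center frequency (rad/sample),
    fwhm: full width at half maximum of the Gaussian envelope.
    Conventions are illustrative, not necessarily those of the PR.
    """
    sigma = fwhm / (2.0 * math.sqrt(2.0 * math.log(2.0)))  # fwhm -> std dev
    gaussian = math.exp(-0.5 * (t / sigma) ** 2)
    return (gaussian * math.cos(center * t),
            gaussian * math.sin(center * t))

# At t = 0 the envelope is 1 and the response is purely real:
print(gabor_impulse_response(0.0, 0.5, 100.0))   # -> (1.0, 0.0)
```

In LEAF, `center` and `fwhm` are the learnable parameters of each filter, which is what makes the frontend trainable end to end.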

from torch import nn


class ExponentialMovingAverage(nn.Module):
Collaborator

Shouldn't this go to speechbrain.nnet.normalization? That is genuinely a question, haha. The idea is always to reduce the number of files.

Contributor Author

Well, the idea was that EMA might find other use cases. But I'll move it to speechbrain.nnet.normalization; it fits well there too.

@TParcollet
Collaborator

Once the comments have been addressed, I'll merge. I tested, and it works :-) Thanks for the huge work.

@mravanelli
Collaborator

Also @SarthakYadav, could you please merge the latest develop branch here? We recently added many consistency tests that help make sure the code is fine.

@SarthakYadav
Contributor Author

> Once the comments have been addressed, I'll merge. I tested, and it works :-) Thanks for the huge work.
>
> Also @SarthakYadav, could you please merge the latest develop branch here? We recently added many consistency tests that help make sure the code is fine.

Sure @mravanelli, will do.

@SarthakYadav
Contributor Author

SarthakYadav commented Jun 22, 2022

The latest commit incorporates all the suggestions. I've also updated the sample recipe; training is working.

@TParcollet Please take a look.

@TParcollet
Collaborator

@SarthakYadav some tests are failing; I will let you fix that and then we merge! I am fine with the code now :-)

@SarthakYadav
Contributor Author

> @SarthakYadav some tests are failing; I will let you fix that and then we merge! I am fine with the code now :-)

It was a failed recipe consistency test due to my .yaml not being listed in tests/recipes.csv. I've updated it.

@SarthakYadav
Contributor Author

@mravanelli thanks for fixing the recipe tests; I was about to ask how to do that.

It seems to be failing due to missing documentation for the forward methods in the modules I wrote. I'll fix that soon.

@mravanelli
Collaborator

mravanelli commented Jun 24, 2022 via email

@mravanelli
Collaborator

I finally wrote the missing docstrings (we are accelerating a bit because we will release the new version of speechbrain soon). If all the tests pass, I think we can merge it!

@mravanelli
Collaborator

Thank you @SarthakYadav for this great job! Of course you are welcome to keep contributing to speechbrain if you want.

mravanelli merged commit a512bb6 into speechbrain:develop on Jun 24, 2022
@SarthakYadav
Contributor Author

> I finally wrote the missing docstrings (we are accelerating a bit because we will release the new version of speechbrain soon). If all the tests pass, I think we can merge it!
>
> Thank you @SarthakYadav for this great job! Of course you are welcome to keep contributing to speechbrain if you want.

Thanks a lot @mravanelli!

SarthakYadav deleted the leaf branch on June 26, 2022 09:51