Make Whisper Encoder's sinusoidal PE non-trainable by default #26032
sanchit-gandhi merged 27 commits into huggingface:main from gau-nernst:whisper_encoder_pe
Conversation
Hey @gau-nernst - thanks very much for opening this PR! Looks like a great start already. I pushed the Flax changes in the latest commit. In short, the simplest way of setting the parameters to un-trainable in Flax is by stopping the back-prop through the layers. Otherwise, we need to explicitly pass a dict to the optimiser that defines which parameters are trainable/non-trainable (see https://colab.research.google.com/drive/1K-5bz6R6kt9GAvaUHvzYvvA-IOAO2PhL#scrollTo=BrF6Dtb8GlkJ). There isn't currently a test that checks the embed params are non-trainable, but you could certainly add one. It could follow the style of the test we use to check that the encoder is correctly frozen when we do decoder-only fine-tuning. Regarding initialising the weights with sinusoidal embeddings - I agree that this should be the default case! In 99% of cases users will just use the model from pre-trained, in which case the embeddings will be initialised with the sinusoids, but if a user were to randomly initialise the model, the embeddings would be initialised incorrectly.
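A test along those lines could look roughly like the following. This is a minimal PyTorch sketch with a toy module standing in for the real `WhisperEncoder`; the class and parameter names here are made up for illustration:

```python
import torch.nn as nn

class ToyEncoder(nn.Module):
    # stand-in for WhisperEncoder: a single positional-embedding table that we freeze
    def __init__(self, max_positions=8, d_model=4):
        super().__init__()
        self.embed_positions = nn.Embedding(max_positions, d_model)
        self.embed_positions.requires_grad_(False)  # freeze the PE

enc = ToyEncoder()
frozen = [name for name, p in enc.named_parameters() if not p.requires_grad]
print(frozen)  # ['embed_positions.weight']
```

A real test would build the full model and assert that exactly the encoder's `embed_positions.weight` appears in this frozen list.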
That's a great solution! However, from what I understand, it means that in the Flax implementation it is not possible (or at least not easy) to re-enable training for the positional encodings (something we discussed previously)?
It's possible (but a bit involved) to add functionality to toggle whether we train the PEs in Flax. However, to me this PR is a bug fix, rather than a feature addition. I agree with what you said in the issue that we should not train the embeddings, since the original implementation used fixed sinusoidal embeddings, so I think it's fine if we do a straight fix and always freeze the embeddings here, since this is the correct behaviour.
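A tiny illustration of the stop-gradient approach: wrapping a parameter in `jax.lax.stop_gradient` makes back-prop treat it as a constant, so its gradients come back as zeros. This is a sketch, not the actual modelling code; the parameter names are made up:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x):
    # stop_gradient freezes the positional embeddings: back-prop
    # treats them as constants, so their gradients come back as zeros
    pe = jax.lax.stop_gradient(params["embed_positions"])
    return jnp.sum((x + pe) * params["proj"])

params = {"embed_positions": jnp.ones(4), "proj": jnp.full(4, 2.0)}
grads = jax.grad(loss_fn)(params, jnp.arange(4.0))
print(grads["embed_positions"])  # [0. 0. 0. 0.]
print(grads["proj"])             # [1. 2. 3. 4.]
```

Any optimizer applied to these gradients will simply never update the frozen leaf, which is why no trainable/non-trainable dict is needed.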
I added sinusoids weight init for the PyTorch implementation. Looking at TF and Flax, I'm not sure where to put the weight init. It seems like there is no weight-init code in TF? For Flax, I see this but don't really understand what's going on:

transformers/src/transformers/models/whisper/modeling_flax_whisper.py, lines 865 to 895 in 0a55d9f

From other TF Keras and Flax code I have seen, I think the typical pattern is to pass a weight-init function to a module when it is created? I'm not sure what pattern HF is using here.
The init function should be an instance of a JAX initialiser. That is, it should take the PRNG key as the first argument, as well as the shape and target dtype of the module: https://jax.readthedocs.io/en/latest/jax.nn.initializers.html
sanchit-gandhi
left a comment
Thanks for your follow-up work on this issue @gau-nernst! Nice job on getting the Flax and TF parts working as well 👏 I've left some suggestions below on how we could potentially re-factor the code a bit to make it as clear as possible for the final PR, let me know if you have any questions!
    # Copied from transformers.models.whisper.modeling_whisper.sinusoids
    def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> np.ndarray:
Rather than defining this first function and then wrapping it with a very shallow second function `embedding_init`, I think it would be cleaner to define just one new function (`sinusoidal_embedding_init`) that takes three arguments:

- `key`: JAX PRNGKey (unused, but required to match the signature of the init function)
- `shape`: tuple of `(length, channels, max_timescale)`
- `dtype`: dtype of the computation

And returns the sinusoidal weights. To me, this would make the code a bit cleaner and easier to follow. How does this sound to you?
Sure, that works as well. I used numpy initially to generate the sinusoids so that we can copy it across the 3 files and avoid errors. But having separate functions is fine by me too.
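For concreteness, a single init function with the `(key, shape, dtype)` signature might look like the sketch below. It is written in plain numpy here for portability (in the Flax file it would return a `jnp` array), and follows the standard Whisper sinusoid construction; the function name matches the one proposed above, but the body is only an illustration:

```python
import numpy as np

def sinusoidal_embedding_init(key, shape, dtype=np.float32):
    """Sinusoidal init with a JAX-initializer-style signature (sketch).

    `key` is unused; it exists only to match the (key, shape, dtype) signature.
    """
    length, channels = shape
    if channels % 2 != 0:
        raise ValueError(f"Number of channels has to be divisible by 2, got {channels}.")
    log_timescale_increment = np.log(10000) / (channels // 2 - 1)
    inv_timescales = np.exp(-log_timescale_increment * np.arange(channels // 2))
    scaled_time = np.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1).astype(dtype)

# Whisper-base-like shape: 1500 positions, 384 channels
emb = sinusoidal_embedding_init(None, (1500, 384))
print(emb.shape)  # (1500, 384)
```

Position 0 comes out as `sin(0)=0` in the first half of each row and `cos(0)=1` in the second half, which is an easy sanity check for the tests discussed later.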
    self.conv1 = tf.keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=1, padding="valid", name="conv1")
    self.conv2 = tf.keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=2, padding="valid", name="conv2")

    def embedding_init(shape, dtype=None):
Probably same here for TF too?
    def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> np.ndarray:
        """Returns sinusoids for positional embedding"""
        assert channels % 2 == 0
Let's try to avoid assert statements in favour of ValueErrors:
    - assert channels % 2 == 0
    + if channels % 2 != 0:
    +     raise ValueError(f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels.")
    module.weight.data.normal_(mean=0.0, std=std)
    if not module.weight.requires_grad:
        # sinusoidal positional encodings used in WhisperEncoder
        with torch.no_grad():
I'm not sure this is safe - if we freeze the decoder embeddings then they'll get incorrectly initialised (since they'll be detected as requires_grad=False). Ideally, we need a way of just isolating the Encoder embeddings for this weight init. Do you think you could have a go at this?
I'm not 100% sure when `_init_weights()` is called. If it is always called at the end of `__init__()` (in `post_init()`?), and we don't freeze anything else in `__init__()`, it should still work as intended. However, I agree that relying on this behaviour is error-prone and not exactly clean.
I don't think there is a clean way for `_init_weights()` to know that an `nn.Embedding` layer is from the encoder. If that is the case, I think the best solution is to initialize the sinusoids after `_init_weights()` is called, within `__init__()`?
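One way to isolate the encoder embeddings without relying on `requires_grad` is to dispatch on the *parent* module's type inside the init function. A framework-free sketch of the idea (all class names here are toy stand-ins, not the real modelling classes):

```python
import numpy as np

def sinusoids(length, channels, max_timescale=10000.0):
    # same construction as Whisper's sinusoidal positional embeddings
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = np.exp(-log_timescale_increment * np.arange(channels // 2))
    scaled_time = np.arange(length)[:, None] * inv_timescales[None, :]
    return np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)

class Embedding:
    def __init__(self, num, dim):
        self.weight = np.zeros((num, dim))

class ToyEncoder:
    def __init__(self):
        self.embed_positions = Embedding(8, 4)

class ToyDecoder:
    def __init__(self):
        self.embed_positions = Embedding(8, 4)

def _init_weights(module):
    # dispatch on the module *type*, not on requires_grad, so a frozen
    # decoder embedding can never be mistaken for the encoder PE
    if isinstance(module, ToyEncoder):
        w = module.embed_positions.weight
        w[...] = sinusoids(*w.shape)

enc, dec = ToyEncoder(), ToyDecoder()
for m in (enc, dec):
    _init_weights(m)
print(enc.embed_positions.weight[0])     # [0. 0. 1. 1.]
print(dec.embed_positions.weight.sum())  # 0.0
```

Since PyTorch's `post_init()` machinery visits every submodule, an `isinstance` check like this sees the encoder itself, not just an anonymous `nn.Embedding`.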
    ]

    def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> np.ndarray:
IMO cleaner to do this entire function in PyTorch for the PyTorch modelling file
sanchit-gandhi
left a comment
Thanks for updating the sinusoids functions to their respective libraries! They look a lot better. Just a few small comments regarding the initialisation :)
    remat = nn_partitioning.remat

    def sinusoidal_embedding_init(max_timescale: float = 10000):
Sorry @gau-nernst, can we not just define one function here? We should move away from defining two functions, where the outer one just calls the inner one directly
The reason I do it like this is to allow changing the default max_timescale value (a "parameterized" initializer), so that it has feature parity with the PyTorch version. If we define just one function, there is no way to change the max_timescale value, since Jax/Keras will only call the function with (shape, dtype) (plus an RNG key for Jax); technically we could still bypass this by using functools.partial(). I followed Jax for this design (https://github.com/google/jax/blob/3247db774ea387098bd9d9049886030dc666cb39/jax/_src/nn/initializers.py#L133-L157). Another way is to make it a class (like Keras does).
Realistically the users won't be able to specify max_timescale to the model anyway since we don't expose it. So it would also be fine for me to make max_timescale a hard-coded constant.
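The "parameterized initializer" pattern under discussion, sketched in plain numpy: the outer call pins `max_timescale`, and the returned closure has the `(key, shape, dtype)` shape the framework expects. `functools.partial` on a flat function achieves the same thing; both forms below are illustrations, not the merged code:

```python
import functools
import numpy as np

def sinusoidal_embedding_init(max_timescale=10000.0):
    # outer call fixes max_timescale; the returned closure matches (key, shape, dtype)
    def init(key, shape, dtype=np.float32):
        length, channels = shape
        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
        inv_timescales = np.exp(-log_timescale_increment * np.arange(channels // 2))
        scaled_time = np.arange(length)[:, None] * inv_timescales[None, :]
        return np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1).astype(dtype)
    return init

# closure style
emb_a = sinusoidal_embedding_init(max_timescale=5000.0)(None, (10, 6))

# equivalent functools.partial style on a flat function
def flat_init(key, shape, dtype=np.float32, max_timescale=10000.0):
    return sinusoidal_embedding_init(max_timescale)(key, shape, dtype)

emb_b = functools.partial(flat_init, max_timescale=5000.0)(None, (10, 6))
print(np.allclose(emb_a, emb_b))  # True
```

If `max_timescale` is hard-coded instead, the outer function disappears and only the inner `(key, shape, dtype)` function remains, which is what was merged in the end.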
Yeah, if max_timescale is not a reachable argument for the user, let's just hardcode it. We tend to do this anyway for sinusoidal embeddings:

transformers/src/transformers/models/gptj/modeling_flax_gptj.py, lines 109 to 110 in 3911774
    # Initialize weights and apply final processing
    self.post_init()
    with torch.no_grad():
        self.embed_positions.weight.copy_(sinusoids(self.max_source_positions, embed_dim))
This was more appropriate in _init_weights! Let's keep it there
This still holds - can this go in _init_weights if possible?
@sanchit-gandhi I fixed the embedding init for TF and Flax as you requested. I also added tests for TF and Flax. For Flax, I didn't add a test for the non-trainable sinusoidal embedding, because I don't know how to do it cleanly. For checking the weight init in Flax, I don't know Flax semantics so well, so I added a rather "crude" solution to get the encoder position embeddings.
sanchit-gandhi
left a comment
Very nice @gau-nernst - especially the Flax init which is really clean now 👌 Could the PT init go in _init_weights? Otherwise it all looks good to me!
    remat = nn_partitioning.remat

    def sinusoidal_embedding_init(key, shape, dtype=jnp.float_) -> jax.Array:
    hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)

    embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions))
    # freeze the sinusoidal embeddings by stopping the back-prop
Note to reviewer: by default we freeze the embeddings in Flax, and don't provide an override. See this explanation for detail: #26032 (comment)
    max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

    def test_encoder_sinusoidal_embed_positions(self):
To test that the Flax embeddings are non-trainable (frozen), you can follow this Flax Wav2Vec2 test:
To me this is optional: we know that the embeddings are initialised correctly through your test, and that grads are set to zero by action of jax.lax.stop_gradient, so up to you if you want to add this!
    # Initialize weights and apply final processing
    self.post_init()
    with torch.no_grad():
        self.embed_positions.weight.copy_(sinusoids(self.max_source_positions, embed_dim))
This still holds - can this go in _init_weights if possible?
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Could we change the
I put sinusoidal init in

    def apply(self: T, fn: Callable[['Module'], None]) -> T:
        ...
        for module in self.children():
            module.apply(fn)
        fn(self)
        return self
sanchit-gandhi
left a comment
Thanks for iterating here @gau-nernst and for the fruitful comment discussions! Requesting a TF review from @Rocketknight1 and maintainer review from @ArthurZucker. Thanks both!
    LARGE_NEGATIVE = -1e8

    def sinusoidal_embedding_init(shape, dtype=tf.float32) -> tf.Tensor:
Would appreciate a TF review here!
Rocketknight1
left a comment
TF code looks good to me! Either doing it this way or creating a `tf.constant` should work.
ArthurZucker
left a comment
Thanks for your contribution! 😉
What does this PR do?
Fixes #25989
I'm not too familiar with Jax/Flax and can't find a simple way to set a variable as non-trainable in Flax. Please advise on how I should approach this.
Should we have a test for this behaviour also? i.e. a test that the Whisper Encoder PE is non-trainable by default.
Another note: should the Encoder's positional encodings be initialized with sinusoids, just like in the official repo?
https://github.com/openai/whisper/blob/main/whisper/model.py#L150
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sanchit-gandhi