
Make Whisper Encoder's sinusoidal PE non-trainable by default#26032

Merged
sanchit-gandhi merged 27 commits into huggingface:main from gau-nernst:whisper_encoder_pe
Oct 11, 2023

Conversation

@gau-nernst
Contributor

What does this PR do?

Fixes #25989

I'm not too familiar with Jax/Flax and can't find a simple way to set a variable as non-trainable in Flax. Do advise on how I should approach this.

Should we have a test for this behavior also? i.e. test that Whisper Encoder PE is non-trainable by default.

Another note: should the Encoder's positional encodings be initialized with sinusoids, like in the official repo?

https://github.com/openai/whisper/blob/main/whisper/model.py#L150

def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
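For reference, the same computation can be sketched in plain NumPy (not part of the PR; the function name here is illustrative):

```python
import numpy as np

def sinusoids_np(length: int, channels: int, max_timescale: float = 10000) -> np.ndarray:
    """NumPy sketch of the sinusoidal positional-embedding table."""
    if channels % 2 != 0:
        raise ValueError(f"channels must be even, got {channels}")
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = np.exp(-log_timescale_increment * np.arange(channels // 2))
    scaled_time = np.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    # first half of the channels are sines, the second half cosines
    return np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)

table = sinusoids_np(1500, 8)
print(table.shape)  # (1500, 8)
```

Note that row 0 is all zeros in the sine half and all ones in the cosine half, since sin(0) = 0 and cos(0) = 1.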

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sanchit-gandhi

@sanchit-gandhi
Contributor

Hey @gau-nernst - thanks very much for opening this PR! Looks like a great start already. I pushed the Flax changes in the latest commit. In short, the simplest way of setting the parameters to un-trainable in Flax is by stopping the back-prop through the layers. Otherwise, we need to explicitly pass a dict to the optimiser that defines which parameters are trainable/non-trainable (see https://colab.research.google.com/drive/1K-5bz6R6kt9GAvaUHvzYvvA-IOAO2PhL#scrollTo=BrF6Dtb8GlkJ)

There's not a test to check that the embed params are non-trainable, but you could certainly add one. This could follow the style of test that we use to check that we correctly freeze the encoder when we do decoder-only fine-tuning:

def test_requires_grad_with_frozen_encoder(self):

Regarding initialising the weights with sinusoidal embeddings - I agree that this should be the default case! In 99% of cases users will just use the model from pre-trained, in which case the embeddings will be initialised with the sinusoids, but if a user were to randomly initialise the model, the embeddings would be initialised incorrectly.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@gau-nernst
Contributor Author

That's a great solution! However, from what I understand, it means that in the Flax implementation, it is not possible (or easy) to re-enable training for positional encodings? (something that we discussed previously)

@sanchit-gandhi
Copy link
Contributor

It's possible (but a bit involved) to add functionality to toggle whether we train the PE's in Flax. However, to me this PR is a bug fix, rather than a feature addition. I agree with what you said in the issue that we should not train the embeddings, since they used fixed sinusoidal embeddings in the original implementation, so I think it's fine if we do a straight fix and always freeze the embeddings here, since this is the correct behaviour.

@gau-nernst
Contributor Author

I added sinusoids weight init for PyTorch implementation. Looking at TF and Flax, I'm not sure where to put weight init. It seems like there is no weight init code in TF? For Flax, I see this but don't really understand what's going on.

def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    # init input tensors
    input_features = jnp.zeros(input_shape, dtype="f4")
    input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)
    decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
    decoder_attention_mask = jnp.ones_like(decoder_input_ids)
    batch_size, sequence_length = decoder_input_ids.shape
    decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}
    random_params = self.module.init(
        rngs,
        input_features=input_features,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        decoder_position_ids=decoder_position_ids,
    )["params"]
    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params

From other TF Keras and Flax code I have seen, I think the typical pattern is to pass weight init function to a module when it is created? I'm not sure what is the pattern HF is using here.

@sanchit-gandhi
Contributor

The init_weights function in Flax is used to initialise the Flax model's parameters by passing a set of dummy inputs (zeros and ones). Flax traces out the shapes of the weights that you get when you pass these dummy inputs, and initialises weights with the right shapes accordingly (see https://flax.readthedocs.io/en/latest/guides/flax_basics.html#model-parameters-initialization).

This init_weights function won't actually change the values of the weights, just their shapes and dtypes. To change the initialising function, we can pass an embedding_init function to the init of the embedding layer: https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.Embed.html#flax.linen.Embed.embedding_init

The init function should be an instance of a JAX initialiser. That is, it should take the PRNG Key as the first argument, as well as the shape and target dtype of the module: https://jax.readthedocs.io/en/latest/jax.nn.initializers.html
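A minimal sketch of what such an initializer could look like (hypothetical name; the PRNG key is accepted only to match the JAX convention and is unused, since the sinusoids are deterministic; written in NumPy here to stay self-contained):

```python
import numpy as np

def sinusoidal_embedding_init(key, shape, dtype=np.float32):
    """Initializer following the JAX (key, shape, dtype) convention.

    `key` is ignored: sinusoidal embeddings are deterministic.
    """
    length, channels = shape
    log_timescale_increment = np.log(10000) / (channels // 2 - 1)
    inv_timescales = np.exp(-log_timescale_increment * np.arange(channels // 2))
    scaled_time = np.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1).astype(dtype)

weights = sinusoidal_embedding_init(key=None, shape=(1500, 384))
print(weights.shape)  # (1500, 384)
```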

@gau-nernst gau-nernst marked this pull request as ready for review September 23, 2023 03:38
Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Thanks for your follow-up work on this issue @gau-nernst! Nice job on getting the Flax and TF parts working as well 👏 I've left some suggestions below on how we could potentially re-factor the code a bit to make it as clear as possible for the final PR, let me know if you have any questions!



# Copied from transformers.models.whisper.modeling_whisper.sinusoids
def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> np.ndarray:
Contributor

Rather than defining this first function and then wrapping it with a very shallow second function embedding_init, I think it would be cleaner to define just one new function (sinusoidal_embedding_init) that takes three arguments:

  1. key: JAX PRNGKey (un-used, required to match the signature of the init function)
  2. shape: tuple of (length, channels, max_timescale)
  3. dtype: dtype of the computation

And returns the sinusoidal weights. To me, this would make the code a bit cleaner and easier to follow. How does this sound to you?

Contributor Author

Sure, that works as well. I used numpy initially to generate the sinusoids so that we can copy it across the 3 files and avoid errors. But having separate functions is fine by me too.

self.conv1 = tf.keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=1, padding="valid", name="conv1")
self.conv2 = tf.keras.layers.Conv1D(self.embed_dim, kernel_size=3, strides=2, padding="valid", name="conv2")

def embedding_init(shape, dtype=None):
Contributor

Probably same here for TF too?


def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> np.ndarray:
"""Returns sinusoids for positional embedding"""
assert channels % 2 == 0
Contributor

Let's try to avoid assert statements in favour of ValueErrors:

Suggested change
assert channels % 2 == 0
if channels % 2 != 0:
    raise ValueError(f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels.")

module.weight.data.normal_(mean=0.0, std=std)
if not module.weight.requires_grad:
    # sinusoidal positional encodings used in WhisperEncoder
    with torch.no_grad():
Contributor

I'm not sure this is safe - if we freeze the decoder embeddings then they'll get incorrectly initialised (since they'll be detected as requires_grad=False). Ideally, we need a way of just isolating the Encoder embeddings for this weight init. Do you think you could have a go at this?

Contributor Author

I'm not 100% sure when _init_weights() is called. If it is always called at the end of __init__() (in post_init()?), and we don't freeze anything else in __init__(), it should still work as intended. However, I agree that relying on this behavior is error-prone and not exactly clean.

I don't think there is a clean way for _init_weights() to know an nn.Embedding() layer is from the Encoder? If that is the case, I think the best solution is to initialize sinusoids after _init_weights() is called within __init__()?

]


def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> np.ndarray:
Contributor

IMO cleaner to do this entire function in PyTorch for the PyTorch modelling file

Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Thanks for updating the sinusoids functions to their respective libraries! They look a lot better. Just a few small comments regarding the initialisation :)

remat = nn_partitioning.remat


def sinusoidal_embedding_init(max_timescale: float = 10000):
Contributor

Sorry @gau-nernst, can we not just define one function here? We should move away from defining two functions, where the outer one just calls the inner one directly

Contributor Author

The reason I do it like this is to allow changing the default max_timescale value (a "parameterized" initializer), so that it has feature parity with the PyTorch version. If we define just one function, there is no way to change the max_timescale value, since Jax/Keras will only call the function with (shape, dtype) (plus an RNG key for Jax); technically we could still bypass this with functools.partial(). I followed Jax's design here (https://github.com/google/jax/blob/3247db774ea387098bd9d9049886030dc666cb39/jax/_src/nn/initializers.py#L133-L157). Another way is to make it a class (like Keras does).

Realistically the users won't be able to specify max_timescale to the model anyway since we don't expose it. So it would be fine for me to make max_timescale as a hard-coded constant also.
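The two patterns being compared can be sketched as follows (illustrative names only, in NumPy): a closure-style "parameterized" initializer with max_timescale baked in, versus a flat function specialised with functools.partial. Both yield a callable the framework can invoke with just (shape, dtype):

```python
import functools
import numpy as np

def make_sinusoidal_init(max_timescale: float = 10000):
    """Closure style: returns an initializer with max_timescale baked in."""
    def init(shape, dtype=np.float32):
        length, channels = shape
        log_inc = np.log(max_timescale) / (channels // 2 - 1)
        inv_timescales = np.exp(-log_inc * np.arange(channels // 2))
        scaled_time = np.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
        return np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1).astype(dtype)
    return init

def flat_sinusoidal_init(shape, dtype=np.float32, max_timescale=10000):
    """Flat style: the framework calls init(shape, dtype); extra args need partial."""
    return make_sinusoidal_init(max_timescale)(shape, dtype)

a = make_sinusoidal_init(5000)((100, 64))
b = functools.partial(flat_sinusoidal_init, max_timescale=5000)((100, 64))
print(np.allclose(a, b))  # True
```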

Contributor

Yeah, if max_timescale is not an argument reachable by the user, let's just hardcode it. We tend to do this anyway for sinusoidal embeddings:

def create_sinusoidal_positions(num_pos, dim):
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))

# Initialize weights and apply final processing
self.post_init()
with torch.no_grad():
    self.embed_positions.weight.copy_(sinusoids(self.max_source_positions, embed_dim))
Contributor

This was more appropriate in _init_weights! Let's keep it there

Contributor

This still holds - can this go in _init_weights if possible?

@gau-nernst
Contributor Author

@sanchit-gandhi I fixed the embedding init for TF and Flax as you requested, and added tests for both. For Flax, I didn't add a test for the non-trainable sinusoidal embedding, because I don't know how to do it cleanly. For checking the weight init in Flax, since I don't know the Flax semantics well, I added a rather "crude" solution to get the encoder position embeddings.

Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Very nice @gau-nernst - especially the Flax init which is really clean now 👌 Could the PT init go in _init_weights? Otherwise it all looks good to me!

remat = nn_partitioning.remat


def sinusoidal_embedding_init(key, shape, dtype=jnp.float_) -> jax.Array:
Contributor

Really nice!

hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)

embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions))
# freeze the sinusoidal embeddings by stopping the back-prop
Contributor

Note to reviewer: by default we freeze the embeddings in Flax, and don't provide an override. See this explanation for detail: #26032 (comment)

max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

def test_encoder_sinusoidal_embed_positions(self):
Contributor

To test that the Flax embeddings are non-trainable (frozen), you can follow this Flax Wav2Vec2 test:

def test_freeze_feature_encoder(self):

To me this is optional: we know that the embeddings are initialised correctly through your test, and that grads are set to zero by action of jax.lax.stop_gradient, so up to you if you want to add this!

# Initialize weights and apply final processing
self.post_init()
with torch.no_grad():
    self.embed_positions.weight.copy_(sinusoids(self.max_source_positions, embed_dim))
Contributor

This still holds - can this go in _init_weights if possible?

gau-nernst and others added 2 commits October 3, 2023 09:00
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
@gau-nernst
Contributor Author

Very nice @gau-nernst - especially the Flax init which is really clean now 👌 Could the PT init go in _init_weights? Otherwise it all looks good to me!

_init_weights() does not see the module name, it only sees the module itself. To make _init_weights() recognize the encoder positional embedding, we probably need to set a private attribute on the module.

@sanchit-gandhi
Contributor

Could we change the _init_weights logic to:

  1. Loop through all modules
  2. Check if module is encoder. If yes: loop through all the sub-modules. When we get to the pos embeddings, do the sinusoidal init
  3. Check if module is decoder. If yes: loop through all the sub-modules. When we get to the pos embeddings, do the normal init

@gau-nernst
Contributor Author

I put the sinusoidal init in _init_weights(), but in a different way: relying on the fact that nn.Module.apply() traverses the children depth-first (leaf modules are applied first), checking for the Whisper encoder in _init_weights() lets it override the default initialization of the positional embeddings.

    def apply(self: T, fn: Callable[['Module'], None]) -> T:
        ...
        for module in self.children():
            module.apply(fn)
        fn(self)
        return self
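A toy illustration of that ordering (plain Python, no torch; the class here only mimics nn.Module.apply), showing why a check on the parent can override what was already done to its children:

```python
class Module:
    def __init__(self, name, children=()):
        self.name = name
        self._children = list(children)

    def apply(self, fn):
        # depth-first: children first, then self (mirrors nn.Module.apply)
        for child in self._children:
            child.apply(fn)
        fn(self)
        return self

order = []
pe = Module("embed_positions")
encoder = Module("encoder", [pe])
model = Module("model", [encoder])
model.apply(lambda m: order.append(m.name))
print(order)  # ['embed_positions', 'encoder', 'model']
```

Because the encoder is visited after its embed_positions child, re-initialising the positional embeddings from the encoder branch overwrites the default init the child already received.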

Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Thanks for iterating here @gau-nernst and for the fruitful comment discussions! Requesting a TF review from @Rocketknight1 and maintainer review from @ArthurZucker. Thanks both!

LARGE_NEGATIVE = -1e8


def sinusoidal_embedding_init(shape, dtype=tf.float32) -> tf.Tensor:
Contributor

Would appreciate a TF review here!

Member

@Rocketknight1 Rocketknight1 left a comment

TF code looks good to me! Doing it either this way or creating a tf.constant should both work.

Collaborator

@ArthurZucker ArthurZucker left a comment

Thanks for your contribution! 😉

@sanchit-gandhi sanchit-gandhi merged commit 1e3c9dd into huggingface:main Oct 11, 2023


Development

Successfully merging this pull request may close these issues.

Whisper Encoder's positional encodings shouldn't be trainable

5 participants