[Flax] Implement FlaxElectraModel, FlaxElectraForMaskedLM, FlaxElectraForPreTraining #9172
chris-tng wants to merge 14 commits into huggingface:main from
Conversation
…n model output, ranging from 0.0010 to 0.0016
```python
    Returns:
        Normalized inputs (the same shape as inputs).
    """
    features = x.shape[-1]
```
Can we replace `x` and `y` with `hidden_states`?
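For reference, the normalization under discussion can be sketched standalone in NumPy. This is a hypothetical helper using the suggested `hidden_states` naming, not the PR's actual module:

```python
import numpy as np

def layer_norm(hidden_states, eps=1e-12):
    """Normalize the last axis to zero mean and unit variance.

    Returns an array with the same shape as `hidden_states`.
    """
    mean = hidden_states.mean(axis=-1, keepdims=True)
    var = hidden_states.var(axis=-1, keepdims=True)
    return (hidden_states - mean) / np.sqrt(var + eps)

# The normalized output has (approximately) zero mean and unit variance
out = layer_norm(np.array([[1.0, 2.0, 3.0, 4.0]]))
```

Naming the argument `hidden_states` would keep the signature consistent with the other Flax models in the codebase.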
The PR looks great to me!
I think your argument for using `setup` instead of `nn.compact` is a good one, and we should probably replace all usage of `nn.compact` with `setup`.
In addition to @chris-tng arguments for using setup instead of nn.compact (easier to test and makes a shorter, more concise forward function), I think a third argument is also that it'll make the class' signature more similar to PyTorch and TF. Transformers users would probably have an easier time understanding what is happening when setup is implemented vs. nn.compact.
I'd be in support of replacing all nn.compacts with the setup function.
What do you think @sgugger @mfuntowicz? I don't really see an advantage in using nn.compact over setup.
sgugger left a comment
I'm no Flax expert (yet) but this looks good to me! For the difference between nn.compact and setup I really don't know enough to be able to weigh in.
```python
    Args:
        x: the inputs

    Returns:
        Normalized inputs (the same shape as inputs).
```
Please use 4 spaces for indentation :-)
```python
class FlaxElectraPooler(nn.Module):
    kernel_init_scale: float = 0.2
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(self, hidden_states):
        cls_token = hidden_states[:, 0]
        out = nn.Dense(
            hidden_states.shape[-1],
            kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
            name="dense",
            dtype=self.dtype,
        )(cls_token)
        return nn.tanh(out)
```
I don't think we have this in the PyTorch/TF versions? And looking at the file it doesn't seem to be used anywhere.
Hi @chris-tng -- thanks for trying out Flax in HF Transformers! A quick comment on Indeed (Please do let us know any other thoughts or questions on Flax on our discussion board: https://github.com/google/flax/discussions) Happy holidays and new year!
Hey @avital, thanks a lot for your input here! That's very useful. Most of the main contributors to Transformers are on holiday at the moment, and this is a rather big design decision to make going forward with Flax, so I think we'll have to wait until early January, when everybody is back (@sgugger, @LysandreJik, @mfuntowicz). Happy holidays to you as well :-)
Hi @avital, apologies for my delayed response. I appreciate your great work on Flax. Regarding the use of `setup`:

```python
class Dummy(nn.Module):
    def setup(self):
        self.submodule1 = nn.Dense(10)
        self.submodule2 = MyLayerNorm()

    def __call__(self):
        # do something here
        ...
```

After loading model weights from a dict, I can access/debug a submodule by simply accessing the attribute. Shameless plug: I wrote a blog post about porting a huggingface pytorch model to flax, here. I'm a new Flax user, so please correct me if I'm missing anything. Happy holiday and happy new year to everyone 🎄 🍾
Hey @chris-tng, sorry to have kept you waiting for this long. I'll solve the merge conflicts in your PR and then use your PR to change the
Intermediate state is saved here: #9484. Will push to this PR on Monday at the latest.
Hey @chris-tng, I noticed that we will probably have to wait for google/flax#683 to be merged before we can continue this PR. Will keep you up-to-date :-)
Hi folks, sorry for the delay with the new-year shuffle and school shutdown. google/flax#683 required a bit more conversation and updating some other codebases, but now it's merged! If you have a moment, please take a look and see if it helps unblock progress. We'll release Flax 0.4.0 soon, but installing from GitHub now is the way to go.
Hey, sorry for barging in
Hey @CoderPat, It would be great if you could wait until #11364 is merged (should be done in the next 2 days). The PR fixes a couple of bugs :-) |
No problem @patrickvonplaten! Also regarding git logistics, is it better to ask @chris-tng for permission to push directly to his branch? |
I think it's alright to copy-paste the code that is still useful and open a new branch, if you'd like to add Electra :-). On the new branch we should give credit to @chris-tng, but since this PR is quite old now I think he would be fine if we close this one and open a new one (please let me know if this is not the case, @chris-tng :-)). #11364 should be the last refactor before the "fundamental" Flax design is finished.
Just to confirm @patrickvonplaten: the flax refactor is merged and the structure should be stable enough that I can work on implementing Electra, right?
Exactly @CoderPat - very much looking forward to your PR :-)
Closing as this PR is super old and partly fixed by #11426 |
What does this PR do?
- Implements `FlaxElectraModel`, `FlaxElectraForMaskedLM`, `FlaxElectraForPreTraining`. Most of the code is taken from the FlaxBert version, with changes in parameters and the forward pass.
- `convert_to_pytorch` to load weights for Electra.
- `FlaxElectraGeneratorPredictions`, `FlaxElectraDiscriminatorPredictions` for the generator and discriminator prediction heads.
- Tests in `tests/test_modeling_flax_electra.py`.

Forward pass works by running
Hi @patrickvonplaten, @mfuntowicz, I've seen your work on FlaxBert, so I'm tagging you in case you want to review. Please note that I use flax `setup` instead of the decorator `@nn.compact`; I'm happy to revert this change to make the code style consistent.
Let me know if you have any questions or feedback.
Thanks.
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? See the documentation guidelines, and here are tips on formatting docstrings.
- Did you write any new necessary tests? See tests/test_modeling_flax_electra.py.

Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors which may be interested in your PR.