[FEAT] Add Neftune into transformers Trainer #27141
younesbelkada merged 11 commits into huggingface:main
Conversation
muellerzr
left a comment
Thanks! Overall this looks very good and handy to use. I left a few comments for an initial review :)
Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Added a test and a relevant documentation section; this PR is ready for final review!
amyeroberts
left a comment
Nice work! 💪
Just some small comments. Main one is to add a check for the deactivation logic.
src/transformers/trainer.py (Outdated)

```python
# After training we make sure to retrieve back the original forward pass method
# for the embedding layer by removing the forward post hook.
if self.neftune_noise_alpha is not None:
    if is_peft_available() and isinstance(self.model, PeftModel):
        embeddings = unwrap_model(self.model.base_model).get_input_embeddings()
    else:
        embeddings = unwrap_model(self.model).get_input_embeddings()

    self.neftune_hook_handle.remove()
    del embeddings.neftune_noise_alpha
```
Let's make this into an equivalent method `_deactivate_neftune`
src/transformers/trainer.py (Outdated)

```python
if is_peft_available() and isinstance(self.model, PeftModel):
    embeddings = unwrap_model(self.model.base_model).get_input_embeddings()
else:
    embeddings = unwrap_model(self.model).get_input_embeddings()
```
Is this logic used anywhere else? It looks general enough that we could have a _get_model_input_embeddings function (not necessarily to be done in this PR)
Happy to refactor this in a follow-up PR!
```python
# Make sure forward pass works fine
_ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(torch_device))
self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)
```
A check should be made that it's correctly disabled after training has finished
The line `self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)` should check whether the forward hook has been correctly removed, so I think all should be good here.
Note also that the line is called right after training, so it should check that NEFTune is correctly disabled after training.
Added a slightly more elaborate test in ca8f8c4
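The contents of that commit aren't shown in this thread; one possible shape for such a post-training check is sketched below (hypothetical helper name and assertions):

```python
import torch

def assert_neftune_disabled(model):
    # Hypothetical check: after training, the embedding layer should carry
    # no forward hooks and no leftover `neftune_noise_alpha` attribute.
    embeddings = model.get_input_embeddings()
    assert len(embeddings._forward_hooks) == 0
    assert not hasattr(embeddings, "neftune_noise_alpha")
```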
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
amyeroberts
left a comment
Awesome - thanks for iterating!
```python
if not hasattr(self, "neftune_hook_handle"):
    raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")
```
* add v1 neftune
* use `unwrap_model` instead
* add test + docs
* Apply suggestions from code review

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* more details
* fixup
* Update docs/source/en/main_classes/trainer.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* refactor a bit
* more elaborated test
* fix unwrap issue

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
What does this PR do?
As per title
Fixes: huggingface/trl#923
Fixes: #26899
This PR adds NEFTune, a new technique for enhancing supervised fine-tuning results, proposed in: https://arxiv.org/abs/2310.05914
I propose a very simple API: passing a valid `neftune_noise_alpha` argument when initializing the `TrainingArguments`. To avoid any surprising behaviour, we revert to the original forward method at the end of training. This is handled inside the inner training loop, which attaches the correct forward hook before training begins and makes sure to remove it right after the model is trained.
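The hook-based mechanism described above can be sketched as follows. This is a minimal illustration of the paper's noise injection applied to a toy embedding layer, not necessarily identical to the code merged in this PR:

```python
import math

import torch

def neftune_post_forward_hook(module, inputs, output):
    # During training only, add uniform noise in [-mag, mag] with
    # mag = alpha / sqrt(seq_len * hidden_dim), following the NEFTune
    # paper (https://arxiv.org/abs/2310.05914). In eval mode the
    # embedding output passes through unchanged.
    if module.training:
        dims = output.size(1) * output.size(2)
        mag_norm = module.neftune_noise_alpha / math.sqrt(dims)
        output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
    return output

# Attach the hook before training and detach it right after,
# mirroring the activate/deactivate flow described above.
embeddings = torch.nn.Embedding(32, 8)
embeddings.neftune_noise_alpha = 5.0
hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
# ... training would happen here ...
hook_handle.remove()  # restore the original forward behaviour
del embeddings.neftune_noise_alpha
```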