[SFTTrainer] Adds NEFTune into SFTTrainer #871
Conversation
```python
# After training we make sure to retrieve back the original forward pass method
# for the embedding layer
if self.neftune_noise_alpha is not None:
    if isinstance(self.model, PreTrainedModel):
        embeddings = self.model.get_input_embeddings()
    elif isinstance(self.model, PeftModel):
        embeddings = self.model.base_model.get_input_embeddings()

    if hasattr(embeddings, "_trl_old_forward"):
        embeddings.forward = embeddings._trl_old_forward
        del embeddings._trl_old_forward
```
Here we make sure to restore the original behaviour after training.
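The save-and-restore pattern in the snippet above can be sketched in plain Python, independent of TRL. All names here (`attach_patch`, `_old_forward`, etc.) are illustrative, mirroring the `_trl_old_forward` trick from this PR:

```python
class Embedding:
    """Stand-in for an embedding module with a forward method."""
    def forward(self, x):
        return x * 2

def attach_patch(module):
    # Save the original forward so it can be restored after training.
    module._old_forward = module.forward
    def noisy_forward(x):
        # Pretend this adds training-time noise to the output.
        return module._old_forward(x) + 1
    module.forward = noisy_forward

def detach_patch(module):
    # Restore the original forward and drop the saved reference.
    if hasattr(module, "_old_forward"):
        module.forward = module._old_forward
        del module._old_forward

emb = Embedding()
attach_patch(emb)
print(emb.forward(3))   # patched behaviour: 7
detach_patch(emb)
print(emb.forward(3))   # original behaviour restored: 6
```

The `hasattr` guard makes `detach_patch` safe to call even if the patch was never applied, which is why the PR checks for `_trl_old_forward` before restoring.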
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
I got this error.
I wonder if the implementation would be cleaner by using a post forward hook for the embedding layer instead of replacing the `forward` method.
@imrankh46 that feature is not merged yet in the TRL main branch; to use it, please run: `pip install -U git+https://github.com/huggingface/trl.git@add-neftune`

@BenjaminBossan, yes, this is possible indeed and I think it is cleaner. However, I found a standalone forward method easier for future users to understand. Would it also hurt existing hooks that accelerate attaches, in case we manipulate forward post hooks?
My thinking was that with a forward hook, you don't need to monkey patch (which can break stuff sometimes) and there is no need to restore the original `forward` afterwards. Existing hooks shouldn't be affected. When registering the hook, you get back a handle for this specific hook, which allows you to remove the hook once you don't need it anymore.
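The hook-with-handle idea can be sketched without depending on torch. The plain-Python classes below mimic the semantics of PyTorch's `Module.register_forward_hook`, which returns a removable handle; this is an illustration of the mechanism, not the actual torch implementation:

```python
class RemovableHandle:
    """Handle returned at registration; removes exactly one hook."""
    def __init__(self, hooks, hook):
        self._hooks, self._hook = hooks, hook
    def remove(self):
        self._hooks.remove(self._hook)

class Module:
    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)
        return RemovableHandle(self._forward_hooks, hook)

    def forward(self, x):
        return x * 2

    def __call__(self, x):
        out = self.forward(x)
        # Hooks run after forward; a non-None return replaces the output.
        for hook in self._forward_hooks:
            result = hook(self, x, out)
            if result is not None:
                out = result
        return out

m = Module()
handle = m.register_forward_hook(lambda mod, inp, out: out + 1)
print(m(3))      # hook active: 7
handle.remove()  # no monkey patching to undo; just drop the hook
print(m(3))      # original behaviour: 6
```

Because `forward` itself is never replaced, cleanup reduces to `handle.remove()`, and other hooks on the same module keep working untouched.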
lewtun left a comment
Thanks for adding this amazing trick to boost SFT performance @younesbelkada 🔥 !
I'll leave the decision about hooks vs monkey patching to @lvwerra but otherwise this looks great to me.
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
I tried to update TRL with NEFTune using your command, but it was unsuccessful.
The branch got deleted after merging. Now you can install from the main branch: `pip install -U git+https://github.com/huggingface/trl.git`
Installation check
* v1 neftune
* docstring
* add doc + fix nit
* add more docs
* Apply suggestions from code review

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
What does this PR do?
This PR adds NEFTune, a new technique for enhancing supervised fine-tuning results, proposed in https://arxiv.org/abs/2310.05914.
I propose a very simple API: just pass a valid `neftune_noise_alpha` argument when initializing the `SFTTrainer`. To avoid any surprising behaviour, we revert to the original forward method at the end of training. This is handled inside `def train()`, which is a wrapper around `Trainer`'s `train` method. I still need to add a few lines in the documentation.
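For reference, the NEFTune paper samples uniform noise and scales it by `alpha / sqrt(L * d)`, where `L` is the sequence length and `d` the embedding dimension, then adds it to the embeddings during training. A minimal dependency-free sketch of that scaling (illustrative only, not the TRL implementation, which operates on torch tensors):

```python
import math
import random

def neftune_noise(embeddings, alpha):
    """Add NEFTune-style uniform noise, scaled by alpha / sqrt(L * d),
    to a list-of-lists embedding matrix of shape (L, d)."""
    L, d = len(embeddings), len(embeddings[0])
    scale = alpha / math.sqrt(L * d)
    return [
        [x + random.uniform(-scale, scale) for x in row]
        for row in embeddings
    ]

emb = [[0.0] * 8 for _ in range(4)]   # L=4 tokens, d=8 dims
noisy = neftune_noise(emb, alpha=5.0)

# Every perturbation is bounded in magnitude by alpha / sqrt(L * d).
bound = 5.0 / math.sqrt(4 * 8)
assert all(abs(x) <= bound for row in noisy for x in row)
```

The `1 / sqrt(L * d)` factor keeps the expected norm of the added noise roughly constant as sequence length and embedding size grow, so a single `neftune_noise_alpha` value transfers across models.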
Fixes: #870
cc @lvwerra @neelsjain @YuxinWenRick