
[SFTTrainer] Adds NEFTune into SFTTrainer#871

Merged
younesbelkada merged 5 commits into main from add-neftune
Oct 17, 2023

Conversation

@younesbelkada
Contributor

@younesbelkada younesbelkada commented Oct 13, 2023

What does this PR do?

This PR adds NEFTune, a new technique for enhancing supervised fine-tuning results, proposed in: https://arxiv.org/abs/2310.05914

(screenshot: results figure from the NEFTune paper)

I propose a very simple API: just pass a valid neftune_noise_alpha argument when initializing the SFTTrainer. To avoid any surprising behaviour, we revert to the original forward method at the end of training. This is handled inside train(), which is a wrapper around Trainer's train method.

from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
    neftune_noise_alpha=5,
)
trainer.train()
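For context on what neftune_noise_alpha controls: per the paper, NEFTune perturbs the token embeddings during training with uniform noise scaled by alpha / sqrt(L * d), where L is the sequence length and d the embedding dimension. A minimal standalone sketch of that math (not the TRL implementation itself):

```python
import torch

def neftune_noise(embeddings: torch.Tensor, noise_alpha: float = 5.0) -> torch.Tensor:
    # NEFTune (https://arxiv.org/abs/2310.05914) adds Uniform(-1, 1) noise to
    # the token embeddings, scaled by alpha / sqrt(L * d), where L is the
    # sequence length and d the embedding dimension.
    seq_len, dim = embeddings.size(-2), embeddings.size(-1)
    scale = noise_alpha / (seq_len * dim) ** 0.5
    noise = torch.zeros_like(embeddings).uniform_(-1.0, 1.0) * scale
    return embeddings + noise

# Example: a batch of 2 sequences, 512 tokens each, 1024-dim embeddings
emb = torch.randn(2, 512, 1024)
noised = neftune_noise(emb, noise_alpha=5.0)
```

The per-element perturbation is bounded by alpha / sqrt(L * d), so for long sequences and wide embeddings the noise stays small relative to the embeddings themselves.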

I still need to add a few lines to the documentation.

Fixes: #870

cc @lvwerra @neelsjain @YuxinWenRick

Comment on lines +255 to +266
# After training we make sure to retrieve back the original forward pass method
# for the embedding layer
if self.neftune_noise_alpha is not None:

    if isinstance(self.model, PreTrainedModel):
        embeddings = self.model.get_input_embeddings()
    elif isinstance(self.model, PeftModel):
        embeddings = self.model.base_model.get_input_embeddings()

    if hasattr(embeddings, "_trl_old_forward"):
        embeddings.forward = embeddings._trl_old_forward
        del embeddings._trl_old_forward
Contributor Author


Here we make sure to retrieve the original behaviour after training
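The activation side that this revert undoes might look like the following. This is a hypothetical sketch, not the exact PR code; `activate_neftune_sketch` is an illustrative name:

```python
import functools
import torch

def activate_neftune_sketch(model, neftune_noise_alpha):
    # Hypothetical mirror of the revert logic above: stash the original forward
    # on the embedding layer under `_trl_old_forward`, then wrap it so that
    # uniform noise of magnitude alpha / sqrt(L * d) is injected only while the
    # module is in training mode.
    embeddings = model.get_input_embeddings()
    embeddings._trl_old_forward = embeddings.forward  # what the revert restores

    @functools.wraps(embeddings._trl_old_forward)
    def noised_forward(input_ids):
        out = embeddings._trl_old_forward(input_ids)
        if embeddings.training:
            mag = neftune_noise_alpha / (out.size(-2) * out.size(-1)) ** 0.5
            out = out + torch.zeros_like(out).uniform_(-mag, mag)
        return out

    embeddings.forward = noised_forward
    return embeddings
```

Because the original bound method is saved on the module, the revert shown above is just an attribute swap plus a `del`.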

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@younesbelkada younesbelkada requested a review from lvwerra October 13, 2023 16:19
@imrankh46

I got this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
[<ipython-input-13-340218026309>](https://localhost:8080/#) in <cell line: 3>()
      1 from trl import SFTTrainer
      2 
----> 3 supervised_finetuning_trainer = SFTTrainer(
      4     base_model,
      5     train_dataset=formatted_dataset["train"],

TypeError: SFTTrainer.__init__() got an unexpected keyword argument 'neftune_noise_alpha'

@BenjaminBossan
Member

I wonder if the implementation would be cleaner by using a post forward hook for the embedding layer instead of replacing the forward method completely.

@younesbelkada
Contributor Author

younesbelkada commented Oct 15, 2023

@imrankh46 that feature is not merged into the TRL main branch yet. To use it, please run:

pip install -U git+https://github.com/huggingface/trl.git@add-neftune

@BenjaminBossan , yes this is possible indeed and I think it is cleaner, however I found it easier to understand for future users to have a standalone forward method. Would it also hurt existing hooks that accelerate attaches in case we manipulate forward post hooks?

@BenjaminBossan
Member

> yes this is possible indeed and I think it is cleaner, however I found it easier to understand for future users to have a standalone forward method. Would it also hurt existing hooks that accelerate attaches in case we manipulate forward post hooks?

My thinking was that with a forward hook, you don't need to monkey patch (which can break stuff sometimes) and there is no need to call torch.nn.functional.embedding explicitly (what if the embedding layer of a model does something extra?). Whether it's easier to read or not, I don't know.

Existing hooks shouldn't be affected. When registering the hook, you get back a handle for this specific hook, which allows you to remove the hook once you don't need it anymore.
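The hook-based alternative Benjamin describes could look roughly like this. A sketch under stated assumptions: the hook function and the `neftune_noise_alpha` attribute name are illustrative, not the PR's code:

```python
import torch

def neftune_post_forward_hook(module, input, output):
    # Post-forward hook: perturb the embedding output with Uniform(-mag, mag)
    # noise during training, where mag = alpha / sqrt(L * d). No monkey
    # patching: the embedding layer's own forward stays untouched, so any
    # extra logic it performs still runs before the hook fires.
    if module.training:
        mag = module.neftune_noise_alpha / (output.size(-2) * output.size(-1)) ** 0.5
        output = output + torch.zeros_like(output).uniform_(-mag, mag)
    return output

embeddings = torch.nn.Embedding(100, 32)
embeddings.neftune_noise_alpha = 5.0
handle = embeddings.register_forward_hook(neftune_post_forward_hook)
# ... training loop would go here ...
handle.remove()  # restores the clean forward pass; other hooks are untouched
```

As noted above, `register_forward_hook` returns a handle scoped to this one hook, so removing it cannot disturb hooks registered by accelerate or anything else.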

Member

@lewtun lewtun left a comment


Thanks for adding this amazing trick to boost SFT performance @younesbelkada 🔥 !

I'll leave the decision about hooks vs monkey patching to @lvwerra but otherwise this looks great to me.

Member

@lvwerra lvwerra left a comment


Looks good to me! 🚀

@cuongtran-uva

> @imrankh46 that feature is not merged yet in TRL main branch to use it please run:
>
> pip install -U git+https://github.com/huggingface/trl.git@add-neftune
>
> @BenjaminBossan , yes this is possible indeed and I think it is cleaner, however I found it easier to understand for future users to have a standalone forward method. Would it also hurt existing hooks that accelerate attaches in case we manipulate forward post hooks?

I tried to update trl with NEFTune using your command, but it was unsuccessful.

Cloning https://github.com/huggingface/trl.git (to revision add-neftune) to /tmp/pip-req-build-copawfzq
  Running command git clone --quiet https://github.com/huggingface/trl.git /tmp/pip-req-build-copawfzq
  WARNING: Did not find branch or tag 'add-neftune', assuming revision or ref.
  Running command git checkout -q add-neftune
  error: pathspec 'add-neftune' did not match any file(s) known to git.
  error: subprocess-exited-with-error

@lvwerra
Member

lvwerra commented Oct 31, 2023

The branch got deleted after merging. Now you can use @main instead!

@daehuikim

daehuikim commented Dec 11, 2023

> The branch got deleted after merging. now you can use @main instead!

$ pip install -U git+https://github.com/huggingface/trl.git@main
This works for me. Thanks!

$ pip freeze | grep trl
trl @ git+https://github.com/huggingface/trl.git@(codes)

Installation check

lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
* v1 neftune

* docstring

* add doc + fix nit

* add more docs

* Apply suggestions from code review

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
yxliu-TAMU pushed a commit to mincheolseong/ECEN743-GRPO-Project-Proposal that referenced this pull request Apr 20, 2025

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Integrate NEFT into SFTTrainer

8 participants