hanshen95/SEAL

SEAL is an LLM fine-tuning framework with safety-enhancing data selection (see the paper). The implementation is based on OpenRLHF, DeepSpeed, Transformers, and PyTorch.

Introduction

SEAL fine-tuning first trains a data selector by solving a bilevel optimization problem. It then filters the fine-tuning dataset with the trained selector via hard-thresholding. Finally, the LLM is fine-tuned on the filtered dataset.
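Schematically, the bilevel procedure alternates a lower-level model update on the selector-weighted fine-tuning loss with an upper-level selector update driven by the loss on safe data. The following toy sketch illustrates this alternation on a scalar problem with a crude one-step hypergradient heuristic; all quantities and names here are illustrative stand-ins, not the repository's actual losses or API:

```python
import math

def seal_bilevel_sketch(data, safe_target, steps=200, lr=0.05):
    """Toy alternating bilevel updates (illustration only, not SEAL's exact math).

    Lower level: scalar "model" parameter w minimizes a selector-weighted
    squared loss over the training samples.
    Upper level: per-sample selector logits s are nudged to reduce the
    "safe" loss (w - safe_target)^2.
    """
    w = 0.0                      # stand-in for the model parameters
    s = [0.0] * len(data)        # one selector logit per training sample

    def sigmoid(x):
        return 1.0 / (1.0 + math.exp(-x))

    for _ in range(steps):
        # Lower level: gradient step on (1/n) * sum_i p_i * (w - x_i)^2
        p = [sigmoid(si) for si in s]
        grad_w = sum(pi * 2.0 * (w - xi) for pi, xi in zip(p, data)) / len(data)
        w -= lr * grad_w

        # Upper level: one-step hypergradient heuristic. Up-weighting sample i
        # pulls w toward x_i, so samples that move w toward safe_target get
        # their logit increased, and the rest decreased.
        for i, xi in enumerate(data):
            s[i] += lr * -(w - safe_target) * (xi - w)

    return w, [sigmoid(si) for si in s]
```

Running the sketch with one "safe-aligned" sample and one "misaligned" sample shows the selector learning to up-weight the former, which is the behavior the hard-thresholding stage then exploits.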

SEAL framework


This framework and its implementation offer the following features:

  • Effective: We evaluate SEAL on test datasets including Anthropic HH, Slim Orca and HEx-PHI. SEAL consistently outperforms multiple baselines across different models including Llama-3-8b-Instruct, Merlinite-7b and Pythia-2.8b.

  • Flexible and transferable: Performance is relatively robust to the data selection percentage, and a trained selector can be transferred across the fine-tuning of different models.

  • Distributed training: This implementation is based on OpenRLHF, which uses DeepSpeed for efficient distributed training and Transformers for easy customization.

Example Results

Evaluation metric and datasets. We follow AlpacaEval and use the win rate over the test dataset to evaluate model quality. We evaluate the win rate on Anthropic HH, Slim Orca, and HEx-PHI test prompts.
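Concretely, the win rate is the percentage of test prompts on which a judge prefers the evaluated model's response over the reference response. A minimal sketch (the judgment labels are hypothetical, and the tie convention shown is one common choice; check your evaluator for its exact rule):

```python
def win_rate(judgments):
    """Percentage of pairwise comparisons won by the model.

    judgments: list of 'win', 'lose', or 'tie' labels from a judge.
    Ties are counted as half a win here (a common convention,
    not necessarily the one every evaluator uses).
    """
    score = sum(1.0 if j == "win" else 0.5 if j == "tie" else 0.0
                for j in judgments)
    return 100.0 * score / len(judgments)
```

Under this convention, a model indistinguishable from the reference scores 50, which is why Standard SFT appears as the 50-point baseline in the table below.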

We give an example on the Llama-3-8b-Instruct model as follows.





| Method (win rate, %) | Anthropic HH test | SlimOrca test | HEx-PHI |
| --- | --- | --- | --- |
| Standard SFT | 50 | 50 | 50 |
| Random selection | 50.78 | 50.8 | 56.31 |
| DSIR | 57.57 | 55.84 | 53.95 |
| SafeInstr | 57.97 | 54.22 | 64.49 |
| SEAL | 60.22 | 53.88 | 69.29 |
| SEAL+SafeInstr | 67.19 | 53.91 | 77.28 |

Installation

Create conda environment

conda create -n seal python=3.10
conda activate seal

We recommend first installing a PyTorch build compatible with your machine:

pip install torch==<your version>

To install the other dependencies, navigate to the root directory and run

pip install -r requirements.txt 

On our machine, torch==2.2.2 is compatible with the package versions specified in "requirements.txt".

Then install SEAL:

pip install -e .

Running Example

Navigate to the scripts folder:

cd examples/scripts

Data selector training

In SEAL, we first train a data selector, e.g., with the following script:

deepspeed ../train_sft_selector.py \
    --max_len 2048 \
    --dataset <upper-level safe dataset> \
    --new_dataset <original fine-tuning dataset> \
    --upperlevel_weight <initial safe loss weight, between (0,1]> \
    --upperlevel_weight_decay <weight decay each epoch> \
    --train_batch_size 64 \
    --micro_train_batch_size 1 \
    --max_samples <dataset size limit> \
    --pretrain <aligned model> \
    --selector_activation <softmax or sigmoid> \
    --selector_name <selector name> \
    --save_steps -1 \
    --logging_steps 1 \
    --eval_steps -1 \
    --zero_stage 3 \
    --max_epochs 3 \
    --bf16 \
    --learning_rate 1e-5 \
    --selector_learning_rate 5e-3 \
    --selector_lr_scheduler <deepspeed lr scheduler, e.g., constant> \
    --lr_scheduler <deepspeed lr scheduler, e.g., cosine> \
    --gradient_checkpointing \
    --flash_attn \
    --lora_rank 16 \
    --lora_alpha 16 \
    --target_modules q_proj v_proj
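The --selector_activation flag chooses how raw selector logits become per-sample weights. A sketch of the two options (an illustration of the standard activations, not the repository's exact code): sigmoid scores each sample independently in (0, 1), while softmax normalizes the scores over the batch so they sum to 1.

```python
import math

def selector_weights(logits, activation="sigmoid"):
    """Turn raw per-sample selector logits into selection weights.

    Illustrative helper: 'sigmoid' gives independent weights in (0, 1);
    'softmax' gives batch-normalized weights summing to 1.
    """
    if activation == "sigmoid":
        return [1.0 / (1.0 + math.exp(-z)) for z in logits]
    if activation == "softmax":
        m = max(logits)                      # subtract max for numerical stability
        exps = [math.exp(z - m) for z in logits]
        total = sum(exps)
        return [e / total for e in exps]
    raise ValueError(f"unknown activation: {activation}")
```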

For example, we can run on Llama-3-8b-Instruct with our default settings. Note that we provide pre-trained data selectors in the ckpt folder, so you can skip this step for a quick run.

# SEAL data selector training
bash train_selector_llama3.sh

Fine-tuning stage

Then we run SFT with SEAL data selection:

deepspeed ../train_sft.py \
    --max_len 2048 \
    --dataset <original fine-tuning dataset> \
    --selector_path <data selector path> \
    --topp <between (0,1], data selection rate> \
    --train_batch_size 64 \
    --micro_train_batch_size 1 \
    --max_samples <dataset size limit> \
    --pretrain <initial model> \
    --save_path <save path> \
    --save_steps -1 \
    --logging_steps 1 \
    --eval_steps -1 \
    --zero_stage 3 \
    --max_epochs 3 \
    --bf16 \
    --lr_scheduler <deepspeed lr scheduler, e.g., cosine> \
    --learning_rate 1e-5 \
    --gradient_checkpointing \
    --flash_attn \
    --lora_rank 16 \
    --lora_alpha 32 \
    --target_modules q_proj v_proj

For example, the user can run on Llama-3-8b-Instruct with the default settings:

# Fine-tuning with data selection
bash train_seal_sft_llama3.sh

To run SFT without data selection, set --topp to 1.0 or omit the --selector_path argument. For example, to run SFT on Llama-3-8b-Instruct without data selection under the default setup, use:

# standard SFT on Llama-3-8b-Instruct
bash train_sft_llama3.sh
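The --topp filtering itself amounts to keeping the fraction topp of the dataset with the highest selector scores; a minimal sketch of this hard-thresholding step (function and variable names are illustrative, not the repository's API):

```python
def select_top_p(samples, scores, topp):
    """Keep the topp fraction of samples with the highest selector scores.

    topp in (0, 1]; topp == 1.0 keeps everything, which reduces
    to standard SFT on the full dataset.
    """
    assert 0.0 < topp <= 1.0
    k = max(1, int(round(topp * len(samples))))
    # Rank sample indices by selector score, highest first
    ranked = sorted(range(len(samples)), key=lambda i: scores[i], reverse=True)
    keep = sorted(ranked[:k])          # preserve the original dataset order
    return [samples[i] for i in keep]
```

For instance, with topp=0.5 only the top-scoring half of the dataset survives, while topp=1.0 returns the dataset unchanged.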

Citation

If you find our work interesting, please consider citing this paper:

@article{shen2024seal,
  title={SEAL: Safety-enhanced Aligned LLM Fine-tuning via Bilevel Data Selection},
  author={Shen, Han and Chen, Pin-Yu and Das, Payel and Chen, Tianyi},
  journal={arXiv preprint arXiv:2410.07471},
  year={2024}
}

About

An implementation of SEAL: Safety-Enhanced Aligned LLM fine-tuning via bilevel data selection.
