
TIS-DPO: Token-level Importance Sampling for Direct Preference Optimization With Estimated Weights

Code and data for our paper, accepted at ICLR 2025.

📌 Table of Contents

  • 🔍 Overview
  • 🔧 Installation
  • 📊 Dataset Preparation
  • 🤖 Model Preparation
  • 🚀 Training Pipeline
  • 📝 Repository Structure
  • 🙏 Acknowledgements
  • 📄 Citation
  • 📜 License

🔍 Overview

Figure: TIS-DPO method overview.

TIS-DPO enhances Direct Preference Optimization by incorporating token-level importance sampling. While standard DPO treats the entire response as a single unit, TIS-DPO recognizes that not all tokens contribute equally to response quality.

Our approach assigns importance weights to each token based on its estimated reward, focusing optimization on the most critical parts of the response. These weights are estimated using the difference in prediction probabilities from a pair of contrastive LLMs trained with forward and reverse DPO.
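As a rough illustration of this estimation step, the sketch below scores each response token by the difference in log-probability it receives from the forward-DPO and reverse-DPO models. This is a minimal sketch assuming Hugging Face transformers and simple prompt/response concatenation, not the repository's exact implementation (see token_weight_estimation.py for that); the function name and paths are illustrative only.

# Minimal sketch of the weight-estimation idea (illustrative, not the repo's exact code):
# score each response token by the difference in log-probability assigned to it
# by the forward-DPO model versus the reverse-DPO model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def per_token_weight_scores(prompt, response, pos_model, neg_model, tokenizer):
    """One score per response token: logp under the forward-DPO model
    minus logp under the reverse-DPO model (larger = more 'preferred' token)."""
    input_ids = tokenizer(prompt + response, return_tensors="pt").input_ids
    prompt_len = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]

    def token_logps(model):
        with torch.no_grad():
            logits = model(input_ids).logits[:, :-1, :]      # position i predicts token i+1
        logps = torch.log_softmax(logits, dim=-1)
        targets = input_ids[:, 1:]                           # shifted labels
        return logps.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    diff = token_logps(pos_model) - token_logps(neg_model)
    return diff[0, prompt_len - 1:]                          # keep response tokens only

# Example usage (paths are illustrative):
# tok = AutoTokenizer.from_pretrained("models/Qwen2.5-3B-sft")
# pos = AutoModelForCausalLM.from_pretrained("output/dpo_...")           # forward DPO
# neg = AutoModelForCausalLM.from_pretrained("output/dpo_..._reverse")   # reverse DPO
# scores = per_token_weight_scores("Question: ...\n", "Answer ...", pos, neg, tok)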

🔧 Installation

Create and activate the conda environment:

conda env create -f environment.yml
conda activate tis-dpo

📊 Dataset Preparation

Download the required dataset from Hugging Face:

huggingface-cli download --resume-download exlaw/tis-dpo-data --local-dir datasets --repo-type=dataset

This places the data under the datasets/ directory.

🤖 Model Preparation

This repository uses the Qwen2.5-3B model as an example.

Download our fine-tuned starting model:

huggingface-cli download --resume-download exlaw/Qwen2.5-3B-sft --local-dir models/Qwen2.5-3B-sft

This model was fine-tuned on the Alpaca dataset to provide instruction-following capabilities.

🚀 Training Pipeline

Step 1: Training DPO and Reverse DPO Models

Train the standard DPO model:

python -u train.py model=qwen model.name_or_path=models/Qwen2.5-3B-sft \
  datasets=[ultra-feedback] loss=dpo loss.beta=0.1 \
  gradient_accumulation_steps=2 batch_size=32 eval_batch_size=32 \
  trainer=FSDPTrainer sample_during_eval=false \
  base_data_dir=datasets/ reverse_dataset=false

The resulting model will be saved at output/dpo_Qwen2.5-3B-sft_ultra-feedback_{timestamp}

Train the reverse DPO model:

python -u train.py model=qwen model.name_or_path=models/Qwen2.5-3B-sft \
  datasets=[ultra-feedback] loss=dpo loss.beta=0.1 \
  gradient_accumulation_steps=2 batch_size=32 eval_batch_size=32 \
  trainer=FSDPTrainer sample_during_eval=false \
  base_data_dir=datasets/ reverse_dataset=true

The resulting model will be saved at output/dpo_Qwen2.5-3B-sft_ultra-feedback_reverse_{timestamp}

Step 2: Token Weight Estimation

First, set the paths to your trained models (replacing {timestamp} with the actual values):

export MODEL_NAME_1="output/dpo_Qwen2.5-3B-sft_ultra-feedback_{timestamp}"
export MODEL_NAME_2="output/dpo_Qwen2.5-3B-sft_ultra-feedback_reverse_{timestamp}"

Then run the token weight estimation:

bash scripts/token_weight_estimation.sh

Step 3: Training with TIS-DPO

Train using the estimated token weights:

python -u train.py model=qwen model.name_or_path=models/Qwen2.5-3B-sft \
  datasets=[ultra-feedback-tisdpo] loss=tisdpo loss.beta=0.1 \
  gradient_accumulation_steps=2 batch_size=32 eval_batch_size=32 \
  trainer=FSDPTrainer sample_during_eval=false \
  base_data_dir=generated-data/ reverse_dataset=false \
  transform.method=rank_based

Available transformation methods include the following (a rough sketch of the rank-based transform is shown after the list):

  • rank_based
  • random
  • threshold_and_scale
  • binary
  • origin
  • threshold

The final model will be saved at output/tisdpo_Qwen2.5-3B-sft_ultra-feedback_tisdpo_{transform_method}_{timestamp}
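
For intuition, the tisdpo loss can be read as the standard DPO objective with each token's policy/reference log-ratio scaled by its importance weight before summation. The following is a simplified sketch under that reading, with hypothetical tensor inputs; the actual loss implementation lives in trainers.py.

# Simplified sketch of a token-weighted DPO loss (illustration only).
# Each token's log-ratio between the policy and the reference model is
# scaled by its estimated importance weight before being summed.
import torch
import torch.nn.functional as F

def weighted_dpo_loss(policy_chosen_logps, ref_chosen_logps, chosen_weights,
                      policy_rejected_logps, ref_rejected_logps, rejected_weights,
                      beta: float = 0.1):
    """All inputs are per-token tensors of shape (batch, seq_len); padding is assumed masked out."""
    chosen_term = (chosen_weights * (policy_chosen_logps - ref_chosen_logps)).sum(-1)
    rejected_term = (rejected_weights * (policy_rejected_logps - ref_rejected_logps)).sum(-1)
    return -F.logsigmoid(beta * (chosen_term - rejected_term)).mean()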

📝 Repository Structure

.
├── README.md                  # Project documentation
├── LICENSE                    # Apache 2.0 license
├── environment.yml            # Conda environment specification
├── train.py                   # Main training script
├── utils.py                   # Utility functions
├── token_weight_estimation.py # Script for token weight estimation
├── preference_datasets.py     # Dataset loading and preprocessing
├── trainers.py                # Training implementations
├── transform_config.py        # Token weight transformation methods
├── config/                    # Configuration files
│   ├── config.yaml            # Main configuration
│   ├── transform/             # Transformation configurations
│   ├── model/                 # Model configurations
│   └── loss/                  # Loss function configurations
├── scripts/                   # Utility scripts
│   └── token_weight_estimation.sh  # Token weight estimation script
├── models/                    # Directory for storing models
├── datasets/                  # Directory for storing datasets
├── generated-data/            # Directory for storing generated data
└── output/                    # Directory for training outputs

🙏 Acknowledgements

We would like to thank the authors of DPO for their foundational work.

📄 Citation

If you find this work useful, please consider citing:

@inproceedings{liu2025tisdpo,
  title={{TIS}-{DPO}: Token-level Importance Sampling for Direct Preference Optimization With Estimated Weights},
  author={Aiwei Liu and Haoping Bai and Zhiyun Lu and Yanchao Sun and Xiang Kong and Xiaoming Simon Wang and Jiulong Shan and Albin Madappally Jose and Xiaojiang Liu and Lijie Wen and Philip S. Yu and Meng Cao},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=oF6e2WwxX0}
}

📜 License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
