- Overview
- Installation
- Dataset Preparation
- Model Preparation
- Training Pipeline
- Repository Structure
- Acknowledgements
- Citation
- License
## Overview

TIS-DPO enhances Direct Preference Optimization (DPO) by incorporating token-level importance sampling. While standard DPO treats the entire response as a single unit, TIS-DPO recognizes that not all tokens contribute equally to response quality.

Our approach assigns an importance weight to each token based on its estimated reward, focusing optimization on the most critical parts of the response. These weights are estimated from the difference in prediction probabilities between a pair of contrastive LLMs trained with forward and reverse DPO.
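For intuition, here is a minimal sketch (not the implementation in `trainers.py`) of how per-token importance weights enter a DPO-style objective; all tensor names and shapes are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def token_weighted_dpo_loss(logp_policy_chosen, logp_ref_chosen, w_chosen,
                            logp_policy_rejected, logp_ref_rejected, w_rejected,
                            beta=0.1):
    """Illustrative token-level importance-weighted DPO loss.

    Each argument is a (batch, seq_len) tensor of per-token log-probabilities
    (policy vs. frozen reference) or non-negative token weights; padding
    tokens are assumed to carry zero weight.
    """
    # Per-token log-ratios between the policy and the reference model.
    ratio_chosen = logp_policy_chosen - logp_ref_chosen
    ratio_rejected = logp_policy_rejected - logp_ref_rejected

    # Weight each token's contribution by its estimated importance,
    # then aggregate over the sequence (standard DPO would use w = 1).
    chosen_term = (w_chosen * ratio_chosen).sum(dim=-1)
    rejected_term = (w_rejected * ratio_rejected).sum(dim=-1)

    # Sigmoid preference loss on the weighted margin.
    return -F.logsigmoid(beta * (chosen_term - rejected_term)).mean()
```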
## Installation

Create and activate the conda environment:

```bash
conda env create -f environment.yml
conda activate tis-dpo
```

## Dataset Preparation

Download the required dataset from Hugging Face:

```bash
huggingface-cli download --resume-download exlaw/tis-dpo-data --local-dir datasets --repo-type=dataset
```

The data should be organized in the `datasets/` directory.
## Model Preparation

This repository uses the Qwen2.5-3B model as an example.

Download our fine-tuned starting model:

```bash
huggingface-cli download --resume-download exlaw/Qwen2.5-3B-sft --local-dir models/Qwen2.5-3B-sft
```

This model has been trained on the Alpaca dataset to provide instruction-following capabilities.
## Training Pipeline

Train the standard DPO model:

```bash
python -u train.py model=qwen model.name_or_path=models/Qwen2.5-3B-sft \
    datasets=[ultra-feedback] loss=dpo loss.beta=0.1 \
    gradient_accumulation_steps=2 batch_size=32 eval_batch_size=32 \
    trainer=FSDPTrainer sample_during_eval=false \
    base_data_dir=datasets/ reverse_dataset=false
```

The resulting model will be saved at `output/dpo_Qwen2.5-3B-sft_ultra-feedback_{timestamp}`.
Train the reverse DPO model:

```bash
python -u train.py model=qwen model.name_or_path=models/Qwen2.5-3B-sft \
    datasets=[ultra-feedback] loss=dpo loss.beta=0.1 \
    gradient_accumulation_steps=2 batch_size=32 eval_batch_size=32 \
    trainer=FSDPTrainer sample_during_eval=false \
    base_data_dir=datasets/ reverse_dataset=true
```

The resulting model will be saved at `output/dpo_Qwen2.5-3B-sft_ultra-feedback_reverse_{timestamp}`.
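Here `reverse_dataset=true` trains the contrastive counterpart: DPO with the preference pairs flipped, as described in the overview. Conceptually it amounts to something like the following hypothetical helper (the actual dataset field names live in `preference_datasets.py` and may differ):

```python
def reverse_preferences(example):
    """Swap chosen and rejected responses so DPO learns the opposite preference."""
    # "chosen" / "rejected" are hypothetical field names for illustration.
    example["chosen"], example["rejected"] = example["rejected"], example["chosen"]
    return example
```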
First, set the paths to your trained models (replacing `{timestamp}` with the actual values):

```bash
export MODEL_NAME_1="output/dpo_Qwen2.5-3B-sft_ultra-feedback_{timestamp}"
export MODEL_NAME_2="output/dpo_Qwen2.5-3B-sft_ultra-feedback_reverse_{timestamp}"
```

Then run the token weight estimation:

```bash
bash scripts/token_weight_estimation.sh
```
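Conceptually, the estimation derives each token's weight from how differently the forward- and reverse-DPO models score that token, as described in the overview. The sketch below is a minimal illustration with hypothetical names and clamp bounds; see `token_weight_estimation.py` for the actual procedure.

```python
import torch

@torch.no_grad()
def estimate_token_weights(logits_forward, logits_reverse, token_ids,
                           lower=0.0, upper=2.0):
    """Estimate per-token importance weights from two contrastive models.

    logits_forward / logits_reverse: (batch, seq_len, vocab) logits from the
    forward-DPO and reverse-DPO models over the same response tokens.
    token_ids: (batch, seq_len) ids of the response tokens being scored.
    """
    logp_fwd = torch.log_softmax(logits_forward, dim=-1)
    logp_rev = torch.log_softmax(logits_reverse, dim=-1)

    # Log-probability each model assigns to the actual token.
    ids = token_ids.unsqueeze(-1)
    fwd = logp_fwd.gather(-1, ids).squeeze(-1)
    rev = logp_rev.gather(-1, ids).squeeze(-1)

    # Tokens the forward model prefers much more than the reverse model
    # receive larger weights; clamping keeps the weights bounded.
    return (fwd - rev).clamp(lower, upper)
```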
Train using the estimated token weights:

```bash
python -u train.py model=qwen model.name_or_path=models/Qwen2.5-3B-sft \
    datasets=[ultra-feedback-tisdpo] loss=tisdpo loss.beta=0.1 \
    gradient_accumulation_steps=2 batch_size=32 eval_batch_size=32 \
    trainer=FSDPTrainer sample_during_eval=false \
    base_data_dir=generated-data/ reverse_dataset=false \
    transform.method=rank_based
```

Available transformation methods include (see the sketch below for a `rank_based` example):

- `rank_based`
- `random`
- `threshold_and_scale`
- `binary`
- `origin`
- `threshold`

The final model will be saved at `output/tisdpo_Qwen2.5-3B-sft_ultra-feedback_tisdpo_{transform_method}_{timestamp}`.
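The transformation method controls how the raw estimated weights are mapped to the values used during training. As a hypothetical illustration of a rank-based mapping (the real implementations live in `transform_config.py` and `config/transform/`):

```python
import torch

def rank_based_transform(weights, low=0.5, high=1.5):
    """Map raw token weights to evenly spaced values in [low, high] by rank.

    weights: 1-D tensor of raw importance weights for one response.
    Ranking makes training insensitive to the scale of the raw weights.
    """
    n = weights.numel()
    if n == 1:
        return torch.full_like(weights, (low + high) / 2)
    # ranks[i] = position of weights[i] in ascending order (0 .. n-1).
    ranks = weights.argsort().argsort().float()
    return low + (high - low) * ranks / (n - 1)
```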
## Repository Structure

```
.
├── README.md                        # Project documentation
├── LICENSE                          # Apache 2.0 license
├── environment.yml                  # Conda environment specification
├── train.py                         # Main training script
├── utils.py                         # Utility functions
├── token_weight_estimation.py       # Script for token weight estimation
├── preference_datasets.py           # Dataset loading and preprocessing
├── trainers.py                      # Training implementations
├── transform_config.py              # Token weight transformation methods
├── config/                          # Configuration files
│   ├── config.yaml                  # Main configuration
│   ├── transform/                   # Transformation configurations
│   ├── model/                       # Model configurations
│   └── loss/                        # Loss function configurations
├── scripts/                         # Utility scripts
│   └── token_weight_estimation.sh   # Token weight estimation script
├── models/                          # Directory for storing models
├── datasets/                        # Directory for storing datasets
├── generated-data/                  # Directory for storing generated data
└── output/                          # Directory for training outputs
```
## Acknowledgements

We would like to thank the authors of DPO for their foundational work.
## Citation

If you find this work useful, please consider citing:

```bibtex
@inproceedings{
    liu2025tisdpo,
    title={{TIS}-{DPO}: Token-level Importance Sampling for Direct Preference Optimization With Estimated Weights},
    author={Aiwei Liu and Haoping Bai and Zhiyun Lu and Yanchao Sun and Xiang Kong and Xiaoming Simon Wang and Jiulong Shan and Albin Madappally Jose and Xiaojiang Liu and Lijie Wen and Philip S. Yu and Meng Cao},
    booktitle={The Thirteenth International Conference on Learning Representations},
    year={2025},
    url={https://openreview.net/forum?id=oF6e2WwxX0}
}
```

## License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
