1ShanghaiTech University
2MoE Key Laboratory of Intelligent Perception and Human-Machine Collaboration
3Fudan University
[arXiv] [Project page]
We reveal that the diffusion bridge with Doob's h-transform is a special case of our stochastic optimal control (SOC) formulation, recovered when the terminal penalty coefficient in the SOC cost tends to infinity.
Install the dependencies with Anaconda and activate the environment:
conda create --name UniDB python=3.9
conda activate UniDB
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
- Prepare datasets.
- Download pretrained checkpoints here
- Modify options, including dataroot_GT, dataroot_LQ and pretrain_model_G.
- Choose a model to sample (default: UniDB) via the test function in codes/models/denoising_model.py, then run:
python test.py -opt=options/test.yml
The test results will be saved in \results.
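Before running test.py, it can help to verify that the paths configured in options/test.yml actually exist. The sketch below is a hypothetical helper, not part of the repository: it assumes PyYAML is available and simply searches the option file for the dataroot_GT, dataroot_LQ and pretrain_model_G keys (listed under Important Option Details below), whatever nesting they sit under.

```python
# check_option_paths.py -- hypothetical helper, not part of the UniDB repository.
# Recursively looks up a few path-like keys in options/test.yml and reports
# whether each configured path exists on disk. Assumes PyYAML is installed.
import os
import yaml

KEYS = {"dataroot_GT", "dataroot_LQ", "pretrain_model_G"}

def find_paths(node, found=None):
    """Collect (key, value) pairs for the keys above from a nested YAML dict/list."""
    if found is None:
        found = []
    if isinstance(node, dict):
        for key, value in node.items():
            if key in KEYS and isinstance(value, str):
                found.append((key, value))
            find_paths(value, found)
    elif isinstance(node, list):
        for item in node:
            find_paths(item, found)
    return found

with open("options/test.yml", "r") as f:
    opt = yaml.safe_load(f)

for key, path in find_paths(opt):
    status = "ok" if os.path.exists(path) else "MISSING"
    print(f"{key}: {path} [{status}]")
```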
For reference, we computed the average distance between the high-quality and low-quality images of the three datasets used in the experiments below (CelebA-HQ, Rain100H, and DIV2K).
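As an illustration of how such an average distance can be computed, here is a minimal sketch; it is our assumption rather than the exact script used for the paper. It pairs HQ and LQ images by filename and averages a per-image L2 distance in pixel space; the folder paths are placeholders.

```python
# avg_distance.py -- illustrative sketch only; the metric (mean per-image L2
# distance in pixel space) and the folder layout are assumptions.
import os
import numpy as np
from PIL import Image

def average_distance(hq_dir, lq_dir):
    """Average L2 distance between HQ/LQ image pairs matched by filename."""
    dists = []
    for name in sorted(os.listdir(hq_dir)):
        hq = np.asarray(Image.open(os.path.join(hq_dir, name)).convert("RGB"), dtype=np.float64)
        lq_img = Image.open(os.path.join(lq_dir, name)).convert("RGB")
        if lq_img.size != (hq.shape[1], hq.shape[0]):  # e.g. super-resolution pairs
            lq_img = lq_img.resize((hq.shape[1], hq.shape[0]))
        lq = np.asarray(lq_img, dtype=np.float64)
        dists.append(np.linalg.norm(hq - lq) / hq.size)  # size-normalized L2
    return float(np.mean(dists))

if __name__ == "__main__":
    # placeholder dataset paths
    print(average_distance("datasets/CelebA-HQ/GT", "datasets/CelebA-HQ/LQ"))
```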
- Prepare datasets.
- Modify options, including dataroot_GT, dataroot_LQ.
python train.py -opt=options/train.yml for a single GPU.
torchrun --nproc_per_node=2 --master_port=1111 train.py -opt=options/train.yml --launcher pytorch for multiple GPUs. Attention: see Important Option Details.
- For the DIV2K dataset, your GPU memory needs to be greater than 34 GB.
- You can modify the parameter gamma in UniDB-GOU/utils/sde_utils.py to balance the control term and the terminal penalty term in the stochastic optimal control objective, which can yield better image quality (see the schematic objective after this section).
Here, we mainly focus on modifying the GOU (Generalized Ornstein-Uhlenbeck) process. For modifications related to VE and VP, readers can refer to the derivations in the appendix of our paper and make the changes themselves (which only require modifying one or two lines of code). We will also release the next version as soon as possible.
The training log will be saved in \experiments.
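For intuition on the role of gamma, the stochastic optimal control problem behind UniDB can be written schematically as below (subject to the controlled forward SDE). This is a simplified sketch, not the paper's exact statement; the precise formulation and the closed-form solution are in the paper and its appendix. The first term is the control cost and gamma weights the terminal penalty; letting gamma tend to infinity enforces the terminal constraint exactly and recovers the Doob's h-transform bridge.

```latex
% Schematic SOC objective (simplified; see the paper for the exact formulation):
\min_{u}\;\mathbb{E}\!\left[\int_{0}^{T}\tfrac{1}{2}\,\lVert u_{t}\rVert^{2}\,\mathrm{d}t
\;+\;\tfrac{\gamma}{2}\,\lVert x_{T}-x_{\mathrm{target}}\rVert^{2}\right],
\qquad \gamma\to\infty \;\Rightarrow\; \text{Doob's } h\text{-transform bridge.}
```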
We provide interface.py for deraining, which generates the HQ image from an LQ input alone:
- Prepare options/test.yml and fill in the LQ path.
- Run python interface.py. The interface will be served on the local machine at 127.0.0.1.
Other tasks can be supported by imitating this script; a rough sketch follows.
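Since the demo is served on a local address, a web-UI library such as Gradio is a natural fit. The sketch below is a hypothetical stand-in for adapting the idea to another task, not the repository's interface.py; the restore function, its checkpoint loading, and the option path are placeholders you would replace with the actual UniDB model calls.

```python
# my_interface.py -- hypothetical sketch of a local image-restoration demo.
# It assumes the Gradio library; `restore` is a placeholder where the UniDB
# model (configured via options/test.yml) would run its reverse sampling.
import gradio as gr
import numpy as np

def restore(lq_image: np.ndarray) -> np.ndarray:
    """Placeholder: replace with UniDB inference on the low-quality input."""
    # e.g. preprocess -> model sampling -> postprocess
    return lq_image  # identity, just to keep the demo runnable

demo = gr.Interface(
    fn=restore,
    inputs=gr.Image(type="numpy", label="Low-quality input"),
    outputs=gr.Image(type="numpy", label="Restored output"),
    title="UniDB demo (sketch)",
)

if __name__ == "__main__":
    demo.launch(server_name="127.0.0.1")  # served on the local machine
```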
Important Option Details:
- dataroot_GT: Ground-Truth (high-quality) data path.
- dataroot_LQ: Low-Quality data path.
- pretrain_model_G: Pretrained model path.
- GT_size, LQ_size: Size of the data cropped during training.
- niter: Total training iterations.
- val_freq: Frequency of validation during training.
- save_checkpoint_freq: Frequency of saving checkpoints during training.
- gpu_ids: In multi-GPU training, GPU ids are separated by commas.
- batch_size: In multi-GPU training, must satisfy batch_size / num_gpu > 1.
We provide a brief guideline for computing the FID between two sets of images:
- Install the FID library: pip install pytorch-fid
- Compute FID:
python -m pytorch_fid GT_images_file_path generated_images_file_path --batch-size 1
If all the images are the same size, you can remove --batch-size 1 to accelerate the computation.
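The same library can also be called from Python. The sketch below assumes pytorch-fid's calculate_fid_given_paths helper; check your installed version for the exact signature, and replace the placeholder folder paths.

```python
# compute_fid.py -- sketch using the pytorch-fid Python API; the function name and
# arguments follow recent pytorch-fid releases and may differ in other versions.
import torch
from pytorch_fid.fid_score import calculate_fid_given_paths

device = "cuda" if torch.cuda.is_available() else "cpu"
fid = calculate_fid_given_paths(
    ["GT_images_file_path", "generated_images_file_path"],  # the two image folders
    batch_size=1,      # keep at 1 if the images have different sizes
    device=device,
    dims=2048,         # default InceptionV3 feature dimension
)
print(f"FID: {fid:.3f}")
```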
If you find this repository useful in your research, please consider citing:
@inproceedings{
zhu2025unidb,
title={Uni{DB}: A Unified Diffusion Bridge Framework via Stochastic Optimal Control},
author={Kaizhen Zhu and Mokai Pan and Yuexin Ma and Yanwei Fu and Jingyi Yu and Jingya Wang and Ye Shi},
booktitle={Forty-second International Conference on Machine Learning},
year={2025},
url={https://openreview.net/forum?id=uqCfoVXb67}
}



