[NeurIPS 2024] Official PyTorch implementation of the paper: "Classification Diffusion Models: Revitalizing Density Ratio Estimation"
Shahar Yadin, Noam Elata, Tomer Michaeli, Technion - Israel Institute of Technology.
A denoising diffusion model (DDM) functions as an MMSE denoiser conditioned on the noise level, whereas a classification diffusion model (CDM) operates as a classifier. Given a noisy image, a CDM outputs a probability vector predicting the noise level, such that the …
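To make the classifier view concrete, here is a toy sketch built on assumptions of our own, not on this repository's code: one-dimensional Gaussian data, five hypothetical noise levels in `SIGMAS`, and a uniform prior over those levels. In this setting the optimal noise-level classifier has a closed form, and the difference of its log-probabilities for two levels equals the log-ratio of the corresponding noisy densities, which is the density-ratio connection named in the paper's title. In a real CDM, a neural network would presumably supply the logits that the analytic `log_normal` terms provide here.

```python
import math

# Toy, analytically solvable setup (an assumption for illustration, not the repo's model):
# clean data x0 ~ N(0, 1) in 1-D and x_t = x0 + sigma_t * eps with eps ~ N(0, 1),
# so the noisy marginal at level t is N(0, 1 + sigma_t**2).
SIGMAS = [0.1, 0.5, 1.0, 2.0, 5.0]  # hypothetical noise levels, one class each


def log_normal(x: float, var: float) -> float:
    """Log-density of N(0, var) evaluated at x."""
    return -0.5 * (math.log(2.0 * math.pi * var) + x * x / var)


def noise_level_log_posterior(x: float) -> list[float]:
    """Bayes-optimal log p(t | x) under a uniform prior over the noise levels."""
    logits = [log_normal(x, 1.0 + s * s) for s in SIGMAS]
    m = max(logits)  # log-sum-exp with max subtraction for numerical stability
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def noise_level_posterior(x: float) -> list[float]:
    """The probability vector a perfect noise-level classifier would output."""
    return [math.exp(lp) for lp in noise_level_log_posterior(x)]


def log_density_ratio(x: float, t: int, s: int) -> float:
    """log p_t(x) - log p_s(x), read off the classifier (the uniform prior cancels)."""
    log_post = noise_level_log_posterior(x)
    return log_post[t] - log_post[s]


print([round(p, 3) for p in noise_level_posterior(0.3)])
print(log_density_ratio(0.3, t=0, s=4))
```

Near the origin the smallest noise level gets the highest probability, while for a far-out input such as `x = 20.0` the largest level wins, matching the intuition that heavily noised samples spread further from the data.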
Samples from CDMs trained on CelebA 64x64 and on CIFAR-10
To install the necessary dependencies, install the packages listed in the requirements file or run the following command:
pip install torch torchvision accelerate wandb matplotlib
To train a CDM, set the mode in config/config.yaml to "training". For CIFAR-10, no additional setup is required.
To train the model on the CelebA dataset instead:
- Download the CelebA dataset.
- Add the dataset path to the dataset_path field in config/config.yaml.
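Switching between modes means editing the same few fields repeatedly. The helper below is a hypothetical convenience, not part of the repository, and it assumes that `mode`, `dataset_path`, `ckpt_folder`, `ckpt_file`, `sampler` and `num_sampling_steps` are unindented top-level keys in config/config.yaml, which we have not verified. It avoids a YAML dependency by rewriting matching `key: value` lines directly.

```python
import re
from pathlib import Path


def set_config_fields(text: str, updates: dict) -> str:
    """Return `text` with each top-level `key: value` line set from `updates`.

    Hypothetical helper: assumes the fields are unindented scalar keys. Nested
    keys are left untouched, and a trailing comment on an edited line is dropped.
    """
    for key, value in updates.items():
        # ^ with MULTILINE only matches unindented (top-level) lines.
        pattern = re.compile(rf"^{re.escape(key)}:.*$", flags=re.MULTILINE)
        line = f"{key}: {value}"
        # A function replacement avoids re-interpreting backslashes in paths.
        text, count = pattern.subn(lambda _match, line=line: line, text)
        if count == 0:
            raise KeyError(f"top-level key {key!r} not found")
    return text


def update_config_file(path: str, updates: dict) -> None:
    """Read, edit and write back a config file in place."""
    cfg = Path(path)
    cfg.write_text(set_config_fields(cfg.read_text(), updates))
```

For example, `update_config_file("config/config.yaml", {"mode": "training", "dataset_path": "/path/to/celeba"})` prepares a CelebA run, where the path is a placeholder. The same call with `ckpt_folder` and `ckpt_file` would cover the sampling and likelihood modes, provided those keys are top-level.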
To generate random samples, first train a CDM (as described above), then:
- Change the mode in config/config.yaml to "sampling".
- Add the checkpoint folder and file to the ckpt_folder and ckpt_file fields in config/config.yaml.
The sampling method and the number of sampling steps can be controlled via the sampler and num_sampling_steps fields in config/config.yaml.
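Sampling with fewer steps than were used in training typically means visiting only a subset of the training timesteps. The sketch below shows one common way to pick that subset, an evenly spaced schedule in the spirit of the "uniform" skip in the DDIM codebase. Whether this repository's samplers use the same rule, and the training horizon of 1000 steps used in the example, are assumptions; `sampling_timesteps` is a hypothetical name.

```python
def sampling_timesteps(num_train_steps: int, num_sampling_steps: int) -> list[int]:
    """Evenly spaced timesteps, ordered from noisiest to cleanest.

    Hypothetical sketch of a strided schedule; the spacing rule is an assumption.
    """
    if not 0 < num_sampling_steps <= num_train_steps:
        raise ValueError("require 0 < num_sampling_steps <= num_train_steps")
    # i * T // n is strictly increasing in i when n <= T, so no timestep repeats.
    return [i * num_train_steps // num_sampling_steps for i in reversed(range(num_sampling_steps))]
```

With `sampling_timesteps(1000, 10)` the sampler would visit `[900, 800, ..., 0]`, trading sample quality for a 100x reduction in network evaluations compared with using every step.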
To evaluate likelihoods, first train a CDM (as described above), then:
- Change the mode in config/config.yaml to "likelihood_eval".
- Add the checkpoint folder and file to the ckpt_folder and ckpt_file fields in config/config.yaml.
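Image likelihoods are usually reported in bits per dimension rather than as a raw per-image negative log-likelihood. We have not checked which unit the likelihood evaluation here prints, so the conversion below is a generic sketch. The optional `log_scale_per_dim` term is an assumption about preprocessing: if the NLL refers to pixels rescaled as `x / 127.5 - 1`, adding `log(127.5)` per dimension expresses it on the original 8-bit scale.

```python
import math


def nats_to_bits_per_dim(nll_nats: float, num_dims: int, log_scale_per_dim: float = 0.0) -> float:
    """Convert a per-image negative log-likelihood in nats to bits per dimension.

    log_scale_per_dim: hypothetical correction for rescaled data, e.g.
    math.log(127.5) when the model sees x / 127.5 - 1 instead of x in [0, 255].
    """
    # Per-dimension NLL in nats, shifted to the original pixel scale, then nats -> bits.
    return (nll_nats / num_dims + log_scale_per_dim) / math.log(2)
```

For CIFAR-10 the number of dimensions is 3 × 32 × 32 = 3072, and for CelebA 64x64 it is 3 × 64 × 64 = 12288.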
All operations (training, sampling, and likelihood evaluation) can be run using one of the following commands:
python main.py
accelerate launch main.py
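Both commands start the same entry point; `accelerate launch` runs it through Hugging Face Accelerate, which is what enables multi-process or multi-GPU execution. The hypothetical helper below picks between the two. `--multi_gpu` and `--num_processes` are standard `accelerate launch` options, whereas running `accelerate launch main.py` with no flags falls back to whatever `accelerate config` saved.

```python
def build_launch_command(num_gpus: int = 1) -> list[str]:
    """Hypothetical helper choosing between the two documented entry points."""
    if num_gpus <= 1:
        # Plain single-process run.
        return ["python", "main.py"]
    # Explicit multi-GPU launch through Hugging Face Accelerate.
    return ["accelerate", "launch", "--multi_gpu", "--num_processes", str(num_gpus), "main.py"]
```

The resulting list can be passed to `subprocess.run(build_launch_command(2), check=True)`; the operation performed is whichever mode config/config.yaml currently selects.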
If you use this code for your research, please cite our paper:
@article{yadin2024classification,
title={Classification Diffusion Models},
author={Yadin, Shahar and Elata, Noam and Michaeli, Tomer},
journal={arXiv preprint arXiv:2402.10095},
year={2024}
}
The code in models/diffusion.py was adapted from the Denoising Diffusion Implicit Models (DDIM) codebase.


