Alan N. Amin, Nate Gruver, and Andrew Gordon Wilson.
Masking discrete diffusion makes use of a fundamental difference between continuous and discrete Markov processes: discrete Markov processes evolve by discontinuous jumps at a fixed rate and, unlike other discrete diffusion models, masking diffusion builds in the known distribution of jump times and only learns where to jump to. We show that we can similarly bake in the known distribution of jump times into any discrete diffusion model. The resulting models -- schedule-conditioned diffusion (SCUD) -- generalize classical discrete diffusion and masking diffusion. By applying SCUD to models with noising processes that incorporate inductive biases on images, text, and protein data, we build diffusion models that outperform masking.
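The idea can be illustrated with a toy simulation: for a continuous-time discrete Markov process with a known rate, the jump times follow a Poisson process, so only the jump destinations remain to be modeled. Below is a minimal conceptual sketch, not the repo's code; the uniform choice of destination is purely illustrative.

```python
import numpy as np

def simulate_jump_process(x0, rate, T, num_states, rng):
    """Simulate a continuous-time discrete Markov chain on [0, T].

    Jump times come from a Poisson process with the known `rate`;
    the destination of each jump is drawn from a transition
    distribution (here uniform over the other states, purely
    for illustration -- this is the part a model would learn).
    """
    t, x = 0.0, x0
    trajectory = [(t, x)]
    while True:
        t += rng.exponential(1.0 / rate)  # known jump-time distribution
        if t > T:
            break
        # Only "where to jump" is left to the model.
        others = [s for s in range(num_states) if s != x]
        x = int(rng.choice(others))
        trajectory.append((t, x))
    return trajectory

rng = np.random.default_rng(0)
traj = simulate_jump_process(x0=0, rate=2.0, T=5.0, num_states=4, rng=rng)
```

Conditioning on the jump schedule amounts to fixing the jump times above and learning only the destination distribution.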
This codebase implements schedule-conditioned diffusion (SCUD). We provide instructions to train models on image and protein data. We also include code to train masking diffusion or classical diffusion models.
Install dependencies by running pip install . with a recent version of Python.
To train a small U-net on CIFAR10 with 64 states, run python3 train.py.
To train protein models, you can download Uniref50 data from here. Place this data in data/uniref_2020/uniref50/.
Also download the BLOSUM62 matrix from here and place it in data/blosum62-special-MSA.mat.
Then you can train a SCUD model with a small CARP architecture by running python3 train.py --config-name=basic_protein.
You can change the model's hyperparameters, and even train masking or classical discrete diffusion models, by modifying the config file configs/basic.cfg.
To change the training parameters, modify the entries under train.
To change the model architecture, modify the entries under architecture.
Changing data.N for image data changes the number of states in CIFAR images (up to 256).
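Setting data.N below 256 implies quantizing the 8-bit pixel values into N discrete states. A minimal sketch of one such mapping (an assumption about the preprocessing, not the repo's exact code):

```python
import numpy as np

def quantize_pixels(x, num_states):
    """Map 8-bit pixel values (0..255) to num_states discrete states
    via equal-width binning."""
    return (x.astype(np.int64) * num_states) // 256

pixels = np.array([0, 63, 64, 128, 255], dtype=np.uint8)
states = quantize_pixels(pixels, 64)  # states in 0..63
```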
model.model can be set to SCUD, Masking, or Classical.
model.gamma controls the conditioning parameter. model.schedule_type controls the noise rate function: linear, cos, or our choice from the paper, mutual_information (see the note below for this choice).
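For intuition, the linear and cos options correspond to standard noise-rate shapes from the diffusion literature; the sketch below shows those standard forms, with illustrative endpoints. The repo's exact parameterization (and the mutual_information schedule) may differ.

```python
import math

def linear_schedule(t, beta_min=0.1, beta_max=10.0):
    """Noise rate increasing linearly in t on [0, 1] (standard form;
    the endpoint values here are illustrative, not the repo's defaults)."""
    return beta_min + (beta_max - beta_min) * t

def cos_schedule(t):
    """Cosine-shaped noise rate on t in [0, 1): the derivative of
    -log(alpha_bar) for alpha_bar(t) = cos(pi * t / 2) ** 2."""
    return math.pi * math.tan(math.pi * t / 2)
```

Both rates start small and grow as t approaches 1, so most corruption happens late in the forward process.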
model.forward_kwargs controls the forward process; see the get_inf_gen function in scud/utils.py for the available choices.
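A forward process is specified by its infinitesimal generator (rate matrix). As a hedged sketch of what such a choice looks like -- not the actual get_inf_gen implementation -- here is the generator of the uniform forward process, whose rows sum to zero and whose off-diagonal entries are nonnegative:

```python
import numpy as np

def uniform_generator(num_states):
    """Rate matrix Q of the uniform forward process: from any state,
    jump to each other state at equal rate; each row sums to zero."""
    Q = np.ones((num_states, num_states)) / num_states
    Q -= np.eye(num_states)
    return Q

Q = uniform_generator(4)
```

Structured processes (e.g., ones encoding pixel or amino-acid similarity) replace the uniform off-diagonal entries with non-uniform rates while keeping the same row-sum constraint.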
model.logistic_pars toggles the logistic parameterization for image data.
Set model.restart to the folder of a checkpoint to restart training.
The noise schedule is precomputed and cached in data/save_alphas before training begins.
Be sure to account for this taking up to an hour the first time you train on a new dataset (calculating this schedule for classical discrete diffusion can take much longer).
