Code for the paper "Predict-Project-Renoise: Sampling Diffusion Models under Hard Constraints".
Consider the code a WIP that will get updated once the final version of the paper is published. This code itself is a cleaned-up version that has been tested, but not exhaustively. If you find any issues, have problems running the code, or have any general questions, do not hesitate to open an issue.
Each file has a small header describing its purpose. These headers and some of the docstrings were automatically generated and may be inaccurate. Again, feel free to open an issue if anything is unclear.
- Create a conda environment:

```sh
conda create --name ppr python=3.12
conda activate ppr
```

- Install in editable mode:

```sh
pip install -e .
```

For CUDA support:

```sh
pip install -e .[cuda]
```
```
ppr/                            # Core library
  diffusion/                    # Diffusion components
    denoisers.py                # Denoisers (Appa-like, MMPS)
    samplers.py                 # Samplers (DDIM, PC, PPR)
    baselines.py                # Baseline methods (DPS, MMPS, TS, CPS, PDM, HDC, PPR wrapper)
    solvers.py                  # ODE/SDE solver
    schedules.py                # Noise schedule
  networks/                     # Networks
    unet.py                     # UNet (KS experiment)
    mod_mlp.py                  # Modulated MLP (Data2D experiment)
    common_layers.py            # Shared layers (SineEncoding, MLP, Modulation)
  projection.py                 # PPR projection
  train_functions.py            # Shared training loop and setup
  utils.py                      # Data loading, checkpointing, misc helpers
numerics/
  solvers.py                    # Pseudo-spectral KS equation solver
experiments/
  data2d/                       # 2D distribution experiment
    exp_def.py                  # Training configs
    eval_config.py              # Evaluation configs (methods, ablations)
    train.py                    # Training entry point
    eval.py                     # Evaluation entry point
    constraints.py              # Constraint function definitions
    generation.py               # Sample generation logic
    compute_graph_metrics.py    # MST/KNN graph-based metrics
  ks/                           # Kuramoto-Sivashinsky experiment
    exp_def.py                  # Training configs
    eval_config.py              # Evaluation configs
    train.py                    # Training entry point
    eval.py                     # Evaluation entry point
    generate.py                 # KS simulation data generation
    generate_diffusion_prior.py # Unconditional diffusion sampling
    utils.py                    # Data paths and helpers
scripts/                        # Plotting scripts for paper figures (jupytext format)
```
The PPR logic is implemented in `ppr/diffusion/samplers.py:create_ppr_sampler` and `ppr/projection.py:make_ppr_projection`. A convenience wrapper combining both is available in `ppr/diffusion/baselines.py:create_PPR_sampler`.

`ppr/train_functions.py` contains the shared training loop (optimizer setup, checkpointing, W&B logging, validation). Experiment-specific training and evaluation scripts live in `experiments/`.
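For intuition, here is a heavily simplified, schematic sketch of one predict-project-renoise step; the actual implementation in `samplers.py` and `projection.py` differs in its details. It assumes a DDIM-style parameterization `x_s = alpha(s) * x0 + sigma(s) * eps`, and `denoise` and `project` are hypothetical stand-ins for the real denoiser and constraint projection.

```python
import numpy as np

def ppr_step(x_t, t, s, denoise, project, alpha, sigma):
    """One schematic PPR step from noise level t to lower noise level s."""
    # 1. Predict: estimate the clean sample from the current noisy state.
    x0_hat = denoise(x_t, t)
    # 2. Project: map the prediction onto the constraint set.
    x0_proj = project(x0_hat)
    # 3. Renoise: return to noise level s from the projected point.
    eps = np.random.randn(*np.shape(x_t))
    return alpha(s) * x0_proj + sigma(s) * eps
```

With `sigma(s) = 0` the step reduces to a deterministic predict-and-project, which makes the structure easy to check with toy callables.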
All experiment scripts are designed to be run from inside the corresponding experiment directory (e.g., cd experiments/ks).
Jobs are launched via dawgz, a lightweight job scheduler that supports local execution (`--backend async`) and SLURM cluster submission (`--backend slurm`). The `--backend async` option works reliably for training but may fail on some evaluations due to JAX VRAM allocation issues.
Each experiment follows the same pattern:

1. Define configs in `exp_def.py` (training) and `eval_config.py` (evaluation). Configs are plain Python dicts; named configs inherit from a `DEFAULT` dict. Inspect these files to see all available parameters.
2. Train a model with `train.py`.
3. Evaluate constrained sampling with `eval.py`.
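The config pattern above can be sketched as plain dict inheritance. The field names below (`dataset`, `learning_rate`, ...) are illustrative, not the actual parameters; consult `exp_def.py` for the real ones.

```python
# Named configs inherit from DEFAULT via dict unpacking and
# override only the fields they change.
DEFAULT = {
    "dataset": "checkerboard",
    "hidden_features": 256,
    "learning_rate": 1e-3,
    "steps": 10_000,
}

CONFIGS = {
    "checkerboard": {**DEFAULT},
    "banana": {**DEFAULT, "dataset": "banana", "steps": 20_000},
}

if __name__ == "__main__":
    cfg = CONFIGS["banana"]
    print(cfg["dataset"], cfg["steps"])
```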
SLURM resource requirements (CPUs, GPUs, RAM, wall-time) are set directly in the `dawgz.job(...)` calls inside each script; edit them there to match your cluster.
Train a model on 2D distributions (checkerboard, banana) and evaluate constrained sampling.
```sh
cd experiments/data2d

# Train (configs defined in exp_def.py: checkerboard, banana)
python train.py --backend slurm --exp-name checkerboard --partition <your-partition>

# Evaluate (configs defined in eval_config.py: DEBUG_CONFIG, EVAL_CONFIG, RENOISE_CONFIG, ...)
python eval.py --backend slurm --exp-name baseline-exp --config-name EVAL_CONFIG

# Compute graph-based metrics (MST/KNN) on saved results
python compute_graph_metrics.py --backend slurm --path results/baseline-exp
```

| Script | Key CLI arguments (non-exhaustive) |
|---|---|
| `train.py` | `--backend {async,slurm}`, `--exp-name NAME`, `--partition PART` |
| `eval.py` | `--backend {async,slurm}`, `--exp-name NAME`, `--config-name CONFIG`, `--debug`, `--account`, `--partition`, `--time` |
| `compute_graph_metrics.py` | `--backend {async,slurm}`, `--path RESULTS_DIR`, `--debug`, `--partition`, `--account`, `--time` |
- Training configs (`exp_def.py`): dataset type, network architecture, optimizer settings, checkpoint path...
- Eval configs (`eval_config.py`): checkpoint path, sampling methods and their hyperparameters, number of samples, constraint functions...
- Constraints (`constraints.py`): defines the constraint energy functions used during evaluation.
- Checkpoints are saved as Orbax checkpoints in the `ckpt` path specified in the training config.
- Results are saved as `.pkl` files in `results/<exp-name>/`.
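As a hypothetical example of what a constraint energy might look like (the actual functions in `constraints.py` may differ), a constraint `g(x) = 0` can be enforced through the energy `E(x) = g(x)^2`, which vanishes exactly on the constraint set. Here `g` restricts 2D points to a circle:

```python
import numpy as np

def circle_constraint(x, radius=1.0):
    """g(x): signed deviation of a 2D point from a circle of given radius."""
    return np.linalg.norm(x, axis=-1) - radius

def energy(x, radius=1.0):
    """E(x) = g(x)^2, zero exactly on the circle."""
    return circle_constraint(x, radius) ** 2

if __name__ == "__main__":
    print(energy(np.array([1.0, 0.0])))  # on the circle -> 0.0
    print(energy(np.array([2.0, 0.0])))  # off the circle -> 1.0
```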
Train a model on KS equation data and evaluate inverse-problem solvers.
```sh
cd experiments/ks

# 1. Generate simulation data (only needed once)
python generate.py --backend slurm

# 2. Train (configs defined in exp_def.py)
python train.py --backend slurm --exp-name unet --partition <your-partition>

# 3. Generate unconditional samples as a test reference
python generate_diffusion_prior.py --model-path checkpoints/unet --output-path diffusion_prior_test

# 4. Evaluate (configs defined in eval_config.py: DEBUG_CONFIG, EVAL_CONFIG, ...)
python eval.py --backend slurm --exp-name ks --config-name DEBUG_CONFIG
```

| Script | Key CLI arguments (non-exhaustive) |
|---|---|
| `train.py` | `--backend {async,slurm}`, `--exp-name NAME`, `--partition PART` |
| `eval.py` | `--backend {async,slurm}`, `--exp-name NAME`, `--config-name CONFIG`, `--debug`, `--num-test N`, `--sims-per-job N`, `--account`, `--partition`, `--time` |
| `generate.py` | `--backend {async,slurm}`, `--overwrite` |
| `generate_diffusion_prior.py` | `--model-path PATH`, `--output-path PATH`, `--num-batches N`, `--batch_size N`, `--sampling-steps N` |
- Data path: by default, simulation data is stored under `experiments/ks/scratch/ks/data/`. If the `$SCRATCH` environment variable is set, data goes to `$SCRATCH/ppr/ks/data/` instead. This is controlled in `utils.py`.
- Eval configs (`eval_config.py`): defines observation operators (e.g., `linear`, `sin`, `sin2`), sampling methods, and ablation sweeps.
- Results are saved as `.pkl` files in `results/<exp-name>/`.
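The data-path logic described above can be sketched as follows; this is a hedged approximation, and the real resolution in `experiments/ks/utils.py` may differ in its details.

```python
import os
from pathlib import Path

def ks_data_dir(default_root="experiments/ks/scratch"):
    """Resolve the KS data directory: $SCRATCH/ppr/ks/data/ if $SCRATCH is
    set, otherwise experiments/ks/scratch/ks/data/."""
    scratch = os.environ.get("SCRATCH")
    if scratch is not None:
        return Path(scratch) / "ppr" / "ks" / "data"
    return Path(default_root) / "ks" / "data"

if __name__ == "__main__":
    print(ks_data_dir())
```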
The weather experiment uses the external Appa codebase, from Andry et al. 2025, and is not included in this repository. I will release code for this experiment in the future.
| I want to... | Where to look |
|---|---|
| Change training hyperparameters | experiments/<exp>/exp_def.py |
| Change evaluation methods or sweeps | experiments/<exp>/eval_config.py |
| Add a new constraint (Data2D) | experiments/data2d/constraints.py |
| Add a new observation operator (KS) | experiments/ks/eval_config.py |
| Add a new sampling method | ppr/diffusion/baselines.py |
| Change SLURM resources (CPUs, GPUs, RAM, time) | dawgz.job(...) calls in train.py / eval.py |
| Change data storage paths (KS) | experiments/ks/utils.py or set $SCRATCH env var |
| Understand the PPR algorithm | ppr/diffusion/samplers.py + ppr/projection.py |
| Understand the training loop | ppr/train_functions.py |
Scripts for reproducing paper figures are in scripts/. These are jupytext-formatted .py files that can be run as scripts or converted to notebooks:
```sh
pip install jupytext
jupytext --to .ipynb scripts/plot_ks_samples.py  # .py → .ipynb
jupytext --to .py notebook.ipynb                 # .ipynb → .py
```

The plotting scripts expect results at relative paths like `../experiments/<exp>/results/`; adjust these if your directory layout differs.
If you find this code useful, please consider citing our paper:
```bibtex
@misc{rochmansharabi2026predictprojectrenoisesamplingdiffusionmodels,
  title={Predict-Project-Renoise: Sampling Diffusion Models under Hard Constraints},
  author={Omer Rochman-Sharabi and Gilles Louppe},
  year={2026},
  eprint={2601.21033},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2601.21033},
}
```