
Predict-Project-Renoise (PPR)

Code for the paper "Predict-Project-Renoise: Sampling Diffusion Models under Hard Constraints".

Consider this code a work in progress that will be updated once the final version of the paper is published. It is a cleaned-up version that has been tested, though not exhaustively. If you find any issues, have trouble running the code, or have general questions, do not hesitate to open an issue.

Each file has a small header describing its purpose. These headers and some of the docstrings were automatically generated and may be inaccurate. Again, feel free to open an issue if anything is unclear.

Installation

  1. Create a conda environment

    conda create --name ppr python=3.12
    conda activate ppr
  2. Install in editable mode

    pip install -e .

    For CUDA support:

    pip install -e ".[cuda]"

Repository Structure

ppr/                        # Core library
  diffusion/                # Diffusion components
    denoisers.py            #   Denoisers (Appa-like, MMPS)
    samplers.py             #   Samplers (DDIM, PC, PPR)
    baselines.py            #   Baseline methods (DPS, MMPS, TS, CPS, PDM, HDC, PPR wrapper)
    solvers.py              #   ODE/SDE solver
    schedules.py            #   Noise schedule
    networks/               #   Networks
      unet.py               #     UNet (KS experiment)
      mod_mlp.py            #     Modulated MLP (Data2D experiment)
      common_layers.py      #     Shared layers (SineEncoding, MLP, Modulation)
  projection.py             # PPR projection
  train_functions.py        # Shared training loop and setup
  utils.py                  # Data loading, checkpointing, misc helpers
numerics/
  solvers.py                # Pseudo-spectral KS equation solver
experiments/
  data2d/                   # 2D distribution experiment
    exp_def.py              #   Training configs
    eval_config.py          #   Evaluation configs (methods, ablations)
    train.py                #   Training entry point
    eval.py                 #   Evaluation entry point
    constraints.py          #   Constraint function definitions
    generation.py           #   Sample generation logic
    compute_graph_metrics.py #  MST/KNN graph-based metrics
  ks/                       # Kuramoto-Sivashinsky experiment
    exp_def.py              #   Training configs
    eval_config.py          #   Evaluation configs
    train.py                #   Training entry point
    eval.py                 #   Evaluation entry point
    generate.py             #   KS simulation data generation
    generate_diffusion_prior.py # Unconditional diffusion sampling
    utils.py                #   Data paths and helpers
scripts/                    # Plotting scripts for paper figures (jupytext format)

The PPR logic is implemented in ppr/diffusion/samplers.py:create_ppr_sampler and ppr/projection.py:make_ppr_projection. A convenience wrapper combining both is available in ppr/diffusion/baselines.py:create_PPR_sampler.

ppr/train_functions.py contains the shared training loop (optimizer setup, checkpointing, W&B logging, validation). Experiment-specific training and evaluation scripts live in experiments/.
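As a rough mental model only, the predict-project-renoise step can be sketched as follows. This is an illustrative NumPy sketch, not the repository's API: the function names, the DDIM-style parameterization, and the variance-preserving schedule are all assumptions; the real implementation lives in the files listed above.

```python
import numpy as np

def ppr_step(x_t, denoise, project, alpha_t, alpha_s):
    """One illustrative predict-project-renoise update (hypothetical sketch).

    x_t: current noisy sample; alpha_t / alpha_s: signal levels at the
    current and next step of a variance-preserving noise schedule.
    """
    # Predict: estimate the clean sample x0 from the noisy state.
    x0_hat = denoise(x_t, alpha_t)
    # Project: enforce the hard constraint on the clean-sample prediction.
    x0_proj = project(x0_hat)
    # Renoise: push the projected estimate back to noise level alpha_s,
    # reusing the noise direction implied by the original prediction.
    eps = (x_t - np.sqrt(alpha_t) * x0_hat) / np.sqrt(1.0 - alpha_t)
    return np.sqrt(alpha_s) * x0_proj + np.sqrt(1.0 - alpha_s) * eps
```

Note that at the final step (signal level 1) the output is exactly the projected prediction, so the hard constraint is satisfied by construction in this sketch.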

Running Experiments

Overview

All experiment scripts are designed to be run from inside the corresponding experiment directory (e.g., cd experiments/ks).

Jobs are launched via dawgz, a lightweight job scheduler that supports local execution (--backend async) and SLURM cluster submission (--backend slurm). The --backend async option works reliably for training but may fail on some evaluations due to JAX GPU memory (VRAM) preallocation issues.

Each experiment follows the same pattern:

  1. Define configs in exp_def.py (training) and eval_config.py (evaluation). Configs are plain Python dicts; named configs inherit from a DEFAULT dict. Inspect these files to see all available parameters.
  2. Train a model with train.py.
  3. Evaluate constrained sampling with eval.py.

SLURM resource requirements (CPUs, GPUs, RAM, wall-time) are set directly in the dawgz.job(...) calls inside each script — edit them there to match your cluster.
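For instance, the named-config inheritance pattern from step 1 typically looks like the following. The field names here are hypothetical placeholders; see exp_def.py and eval_config.py for the real parameters.

```python
# Hypothetical illustration of the named-config pattern: configs are plain
# Python dicts, and a named config inherits from DEFAULT via dict unpacking.
DEFAULT = {
    "dataset": "checkerboard",
    "network": "mod_mlp",
    "learning_rate": 3e-4,
    "batch_size": 256,
}

# Override only the fields that differ; everything else comes from DEFAULT.
BANANA = {**DEFAULT, "dataset": "banana", "learning_rate": 1e-4}
```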

Data2D

Train a model on 2D distributions (checkerboard, banana) and evaluate constrained sampling.

cd experiments/data2d

# Train (configs defined in exp_def.py: checkerboard, banana)
python train.py --backend slurm --exp-name checkerboard --partition <your-partition>

# Evaluate (configs defined in eval_config.py: DEBUG_CONFIG, EVAL_CONFIG, RENOISE_CONFIG, ...)
python eval.py --backend slurm --exp-name baseline-exp --config-name EVAL_CONFIG

# Compute graph-based metrics (MST/KNN) on saved results
python compute_graph_metrics.py --backend slurm --path results/baseline-exp
Key CLI arguments per script (non-exhaustive):
  • train.py: --backend {async,slurm}, --exp-name NAME, --partition PART
  • eval.py: --backend {async,slurm}, --exp-name NAME, --config-name CONFIG, --debug, --account, --partition, --time
  • compute_graph_metrics.py: --backend {async,slurm}, --path RESULTS_DIR, --debug, --partition, --account, --time
  • Training configs (exp_def.py): dataset type, network architecture, optimizer settings, checkpoint path...
  • Eval configs (eval_config.py): checkpoint path, sampling methods and their hyperparameters, number of samples, constraint functions...
  • Constraints (constraints.py): defines the constraint energy functions used during evaluation.
  • Checkpoints are saved as Orbax checkpoints in the ckpt path specified in the training config.
  • Results are saved as .pkl files in results/<exp-name>/.
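A minimal way to inspect those saved results is sketched below. The results/<exp-name>/ layout is from the text above, but the contents of each pickle depend on the eval config, and this helper is not part of the repository.

```python
import pickle
from pathlib import Path

def load_results(exp_name, results_root="results"):
    """Load every .pkl file saved under results/<exp-name>/ into a dict."""
    results = {}
    for path in sorted(Path(results_root, exp_name).glob("*.pkl")):
        with open(path, "rb") as f:
            # Keys are the file stems, e.g. results["run0"] for run0.pkl.
            results[path.stem] = pickle.load(f)
    return results
```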

Kuramoto-Sivashinsky

Train a model on KS equation data and evaluate inverse-problem solvers.

cd experiments/ks

# 1. Generate simulation data (only needed once)
python generate.py --backend slurm

# 2. Train (configs defined in exp_def.py)
python train.py --backend slurm --exp-name unet --partition <your-partition>

# 3. Generate unconditional samples as a test reference
python generate_diffusion_prior.py --model-path checkpoints/unet --output-path diffusion_prior_test

# 4. Evaluate (configs defined in eval_config.py: DEBUG_CONFIG, EVAL_CONFIG, ...)
python eval.py --backend slurm --exp-name ks --config-name DEBUG_CONFIG
Key CLI arguments per script (non-exhaustive):
  • train.py: --backend {async,slurm}, --exp-name NAME, --partition PART
  • eval.py: --backend {async,slurm}, --exp-name NAME, --config-name CONFIG, --debug, --num-test N, --sims-per-job N, --account, --partition, --time
  • generate.py: --backend {async,slurm}, --overwrite
  • generate_diffusion_prior.py: --model-path PATH, --output-path PATH, --num-batches N, --batch_size N, --sampling-steps N
  • Data path: by default, simulation data is stored under experiments/ks/scratch/ks/data/. If the $SCRATCH environment variable is set, data goes to $SCRATCH/ppr/ks/data/ instead. This is controlled in utils.py.
  • Eval configs (eval_config.py): defines observation operators (e.g., linear, sin, sin2), sampling methods, and ablation sweeps.
  • Results are saved as .pkl files in results/<exp-name>/.
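The data-path resolution described above (default location vs. $SCRATCH) can be sketched roughly as follows. The actual logic lives in experiments/ks/utils.py; the function name here is hypothetical.

```python
import os
from pathlib import Path

def ks_data_dir():
    """Hypothetical sketch of the data-path logic in experiments/ks/utils.py."""
    # Prefer $SCRATCH when it is set; otherwise fall back to the in-repo path.
    scratch = os.environ.get("SCRATCH")
    if scratch:
        return Path(scratch) / "ppr" / "ks" / "data"
    return Path("experiments/ks/scratch/ks/data")
```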

Weather / Appa

The weather experiment uses the external Appa codebase (Andry et al., 2025) and is not included in this repository. I will release the code for this experiment in the future.

Customization (examples)

I want to...
  • Change training hyperparameters: experiments/<exp>/exp_def.py
  • Change evaluation methods or sweeps: experiments/<exp>/eval_config.py
  • Add a new constraint (Data2D): experiments/data2d/constraints.py
  • Add a new observation operator (KS): experiments/ks/eval_config.py
  • Add a new sampling method: ppr/diffusion/baselines.py
  • Change SLURM resources (CPUs, GPUs, RAM, time): dawgz.job(...) calls in train.py / eval.py
  • Change data storage paths (KS): experiments/ks/utils.py, or set the $SCRATCH env var
  • Understand the PPR algorithm: ppr/diffusion/samplers.py + ppr/projection.py
  • Understand the training loop: ppr/train_functions.py

Plotting

Scripts for reproducing paper figures are in scripts/. These are jupytext-formatted .py files that can be run as scripts or converted to notebooks:

pip install jupytext
jupytext --to ipynb scripts/plot_ks_samples.py    # .py → .ipynb
jupytext --to py notebook.ipynb                   # .ipynb → .py

The plotting scripts expect results at relative paths like ../experiments/<exp>/results/ — adjust these if your directory layout differs.

Citation

If you find this code useful, please consider citing our paper:

  @misc{rochmansharabi2026predictprojectrenoisesamplingdiffusionmodels,
        title={Predict-Project-Renoise: Sampling Diffusion Models under Hard Constraints}, 
        author={Omer Rochman-Sharabi and Gilles Louppe},
        year={2026},
        eprint={2601.21033},
        archivePrefix={arXiv},
        primaryClass={cs.LG},
        url={https://arxiv.org/abs/2601.21033}, 
  }
