Official JAX codebase for the ImageNet experiments of Generative Modeling via Drifting. We provide training, inference, and pretrained weights for one-step image generation on ImageNet 256×256.
Uncurated conditional ImageNet 256×256 samples (1 NFE, CFG scale 1.0, FID 1.54):
The generated distribution q evolves toward the data distribution p during training. Try the interactive toy demo to see the algorithm in action:
| Middle Init | Far-Away Init | Collapsed Init |
|---|---|---|
| ![]() | ![]() | ![]() |
- Quick Start (Inference)
- Pretrained Models
- Environment Setup
- FID Evaluation
- Training
- Checkpoints and Logs
- Citation
The self-contained Colab notebook lets you generate samples interactively — no local setup required:
Default notebook configuration:
```
init_from = hf://latent_L_sota
class_ids = 95,22,88,108,386,296,483,698
```
Class indices follow the ImageNet-1k label order.
| Model | Space | Feature Encoder | Encoder HF ID | Generator HF ID | CFG | FID (repo / paper) | IS (repo / paper) |
|---|---|---|---|---|---|---|---|
| Drift-L | latent | MAE-640 (latent) | hf://mae_latent_640 | hf://latent_L_sota | 1.0 | 1.53 / 1.54 | 260.1 / 258.9 |
| Drift-B | latent | MAE-640 (latent) | hf://mae_latent_640 | hf://latent_B_sota | 1.1 | 1.74 / 1.75 | 263.4 / 263.2 |
| Drift-L | pixel | MAE-640 (pixel) | hf://mae_pixel_640 | hf://pixel_L_sota | 1.0 | 1.62 / 1.61 | 308.6 / 307.5 |
| Drift-B | pixel | MAE-640 (pixel) | hf://mae_pixel_640 | hf://pixel_B_sota | 1.0 | 1.73 / 1.76 | 300.1 / 299.7 |
| Ablation | latent | MAE-256 (latent) | hf://mae_latent_256 | hf://ablation | 2.0 | 8.49 / 8.46 | 144.0 / — |
| Model | Space | HF ID |
|---|---|---|
| MAE-640 (latent) | latent | hf://mae_latent_640 |
| MAE-640 (pixel) | pixel | hf://mae_pixel_640 |
| MAE-256 (ablation) | latent | hf://mae_latent_256 |
All artifacts are hosted on HuggingFace at Goodeat/drifting and are downloaded automatically.
```bash
conda create -n drifting-release python=3.10 -y
conda activate drifting-release
pip install -r requirements.txt
```
```bash
export JAX_PLATFORMS=tpu,cpu
```

For local TPU runs, keep `JAX_PLATFORMS=tpu,cpu` set in the shell before running latent-cache building, training, or evaluation. This keeps TPU as the default backend while still exposing a CPU backend for the Flax VAE / checkpoint-restore paths that expect it.
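JAX picks up `JAX_PLATFORMS` when its backends are first initialized, so the variable must be set before any JAX computation runs, whether exported in the shell as above or set from Python. A minimal sketch of the Python-side ordering (it does not require a TPU to run, and the commented lines only indicate what would happen with JAX installed):

```python
import os

# Set JAX_PLATFORMS before importing/using jax; once a backend has been
# initialized, changing this variable has no effect on the current process.
os.environ["JAX_PLATFORMS"] = "tpu,cpu"

# import jax
# jax.devices()        # would list TPU devices first (the default backend)
# jax.devices("cpu")   # the CPU backend is still available for VAE/restore paths

print(os.environ["JAX_PLATFORMS"])
```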
Download the ImageNet dataset and extract it to your desired location. The dataset should have the following structure:
```
imagenet/
├── train/
│   ├── n01440764/
│   ├── n01443537/
│   └── ...
└── val/
    ├── n01440764/
    ├── n01443537/
    └── ...
```
Before running training or evaluation, open `utils/env.py` and set these constants for your machine:

- `IMAGENET_PATH`: root of the ImageNet directory (expects `train/` and `val/` subdirectories).
- `IMAGENET_CACHE_PATH`: root of the latent cache directory (only needed for latent-generator training).
- `IMAGENET_FID_NPZ`: path to the ImageNet-256 FID reference stats `.npz`.
- `IMAGENET_PR_NPZ`: path to the ImageNet precision/recall reference stats `.npz`.
- `HF_ROOT`: local cache directory for downloaded HuggingFace artifacts.
- `HF_REPO_ID`: HuggingFace repo ID for the release checkpoints (keep as `Goodeat/drifting`).
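A configured `utils/env.py` might look like the following sketch. Only the constant names come from the list above; every path value is a placeholder for your machine:

```python
# utils/env.py -- machine-specific paths (all values below are placeholders)
IMAGENET_PATH = "/data/imagenet"              # expects train/ and val/ subdirectories
IMAGENET_CACHE_PATH = "/data/latent_cache"    # only needed for latent-generator training
IMAGENET_FID_NPZ = "/data/stats/imagenet256_fid.npz"
IMAGENET_PR_NPZ = "/data/stats/imagenet256_pr.npz"
HF_ROOT = "/data/hf_cache"                    # local cache for downloaded HF artifacts
HF_REPO_ID = "Goodeat/drifting"               # release checkpoint repo; keep as-is
```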
FID/PR reference stats can be downloaded from Google Drive (migrated from MeanFlow).
Only needed for latent-space generators.
```bash
python -m dataset.latent \
  --data-path /path/to/imagenet \
  --target-path /path/to/latent_cache \
  --local-batch-size 128 \
  --num-workers 8 \
  --pin-memory
```

This encodes ImageNet images through the VAE and writes `.pt` files to `/path/to/latent_cache/{train,val}/`. After building the cache, update `IMAGENET_CACHE_PATH` in `utils/env.py`.
Reproduce paper FID numbers on ImageNet-256 (50k samples, CFG=1.0):
```bash
# Latent model
python inference.py --init-from "hf://latent_L_sota" --cfg-scale 1.0 \
  --num-samples 50000 --eval-batch-size 256 --json-out results_latent.json

# Pixel model
python inference.py --init-from "hf://pixel_L_sota" --cfg-scale 1.0 \
  --num-samples 50000 --eval-batch-size 256 --json-out results_pixel.json
```

To stream metrics and preview images to W&B, add:

```bash
--use-wandb --wandb-entity YOUR_ENTITY_HERE --wandb-project YOUR_PROJECT_HERE
```

Expected FID numbers match the Pretrained Models table above. The output JSON contains `fid`, `isc_mean`, `isc_std`, `precision`, and `recall`. Precision/recall are only computed when `num_samples >= 50000`.
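Since the metric keys are plain JSON, a small helper can summarize a results file. This is just an illustrative sketch: the key names are the ones listed above, while the function name and formatting are our own:

```python
import json

def summarize(path: str) -> str:
    """Format the metrics JSON written by inference.py (--json-out) as one line."""
    with open(path) as f:
        m = json.load(f)
    return (f"FID {m['fid']:.2f} | IS {m['isc_mean']:.1f}±{m['isc_std']:.1f} | "
            f"precision {m['precision']:.3f} | recall {m['recall']:.3f}")
```

For example, `summarize("results_latent.json")` on a 50k-sample run would print one line with all five metrics; remember that `precision`/`recall` are only present for `num_samples >= 50000`.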
Requirements:
- TPU v4-8 (otherwise reduce `--eval-batch-size` to avoid OOM during VAE decoding)
- ImageNet-256 path configured in `utils/env.py`; images are generated using the class labels from the ImageNet validation set
- Precomputed FID/PR reference stats configured in `utils/env.py`
```bash
python main.py --gen --config configs/gen/latent_ablation.yaml --workdir runs/gen_latent_ablation
python main.py --gen --config configs/gen/latent_sota_B.yaml --workdir runs/gen_latent_sota_B
python main.py --gen --config configs/gen/latent_sota_L.yaml --workdir runs/gen_latent_sota_L
python main.py --gen --config configs/gen/pixel_sota_B.yaml --workdir runs/gen_pixel_sota_B
python main.py --gen --config configs/gen/pixel_sota_L.yaml --workdir runs/gen_pixel_sota_L
```

MAE pretrained weights are downloaded automatically from HuggingFace via the `feature.mae_path` config field. There is no need to train an MAE unless you are experimenting with custom feature extractors.
FID is evaluated during training at intervals set by train.eval_per_step.
Ablation run intermediate FID (EMA model, best CFG):
| Steps | CFG | FID |
|---|---|---|
| 5k | 3.5 | 35.20 |
| 10k | 2.5 | 13.33 |
| 15k | 2.0 | 10.70 |
| 20k | 2.0 | 9.47 |
| 25k | 2.0 | 8.84 |
| 30k | 2.0 | 8.34 |
We used 64 TPU v6e for the ablation run and 128 TPU v6e for the SOTA runs. Each host maintains its own memory bank (16 hosts for ablation, 32 for SOTA). When using fewer hosts (e.g., DDP on one H100 node = 8 hosts), increase push_per_step to keep the memory bank update rate sufficient.
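The scaling rule above can be made concrete with a small calculation. The assumption here (ours, not stated in the configs) is that the quantity to preserve is the global update rate, i.e. `hosts * push_per_step`; the reference values below are examples:

```python
def scaled_push_per_step(ref_hosts: int, ref_push: int, new_hosts: int) -> int:
    """Keep the global memory-bank update rate (hosts * push_per_step)
    roughly constant when the number of hosts changes."""
    updates_per_step = ref_hosts * ref_push   # global rate to preserve
    return -(-updates_per_step // new_hosts)  # ceiling division

# A 32-host SOTA run with push_per_step=1, reproduced on one 8-GPU node
# (8 hosts), needs a 4x larger push_per_step:
print(scaled_push_per_step(32, 1, 8))  # -> 4
```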
Pretrained MAE weights are already available at hf://mae_latent_640, hf://mae_latent_256, and hf://mae_pixel_640. Training code is provided for users who want to train their own:
```bash
python main.py --config configs/mae/latent_ablation_256.yaml --workdir runs/mae_latent_ablation_256
python main.py --config configs/mae/latent_640.yaml --workdir runs/mae_latent_640
python main.py --config configs/mae/pixel_640.yaml --workdir runs/mae_pixel_640
```

- Train an MAE (see above).
- Point the generator config at the MAE workdir:

  ```yaml
  feature:
    mae_path: /abs/path/to/runs/mae_latent_640
    use_mae: true
    use_convnext: false
    use_post_x: false
  ```

- Run generator training.
Each `--workdir <dir>` produces:

```
<dir>/
├── checkpoints/          # Orbax checkpoints (full training state)
├── params_ema/           # EMA-only artifact
│   ├── ema_params.msgpack
│   └── metadata.json
└── log/
    ├── metrics.jsonl     # metrics (when use_wandb: false)
    └── images/*.jpg      # sample preview grids
```
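Since `ema_params.msgpack` and `metadata.json` sit in a fixed location under the workdir, a quick sanity check of the EMA artifact is easy to script. The fields of `metadata.json` are not documented here, so this hypothetical helper just reports whatever keys it finds:

```python
import json
from pathlib import Path

def inspect_ema_dir(workdir: str) -> list[str]:
    """Print the EMA artifact size and list the metadata keys in <workdir>/params_ema/."""
    ema_dir = Path(workdir) / "params_ema"
    meta = json.loads((ema_dir / "metadata.json").read_text())
    size_mb = (ema_dir / "ema_params.msgpack").stat().st_size / 1e6
    print(f"ema_params.msgpack: {size_mb:.1f} MB, metadata keys: {sorted(meta)}")
    return sorted(meta)
```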
Local artifacts in params_ema/ can be loaded directly for inference:
```bash
python inference.py --init-from /path/to/workdir --cfg-scale 1.0 \
  --num-samples 50000 --eval-batch-size 256
```

```bibtex
@article{deng2026generative,
  title={Generative Modeling via Drifting},
  author={Deng, Mingyang and Li, He and Li, Tianhong and Du, Yilun and He, Kaiming},
  journal={arXiv preprint arXiv:2602.04770},
  year={2026}
}
```

We thank Hanhong Zhao for sanity-checking this repository.















