This package contains the adapter library for fine-tuning using the Hackable Diffusion library.
We recommend Python 3.12 and CUDA 13 for this project.
First install the gemma package.
From PyPI (Recommended)
pip install gemma

From Source
git clone https://github.com/google-deepmind/gemma.git
cd gemma
pip install .

Then install the additional jax[cuda13] dependencies:

pip install -U jax[cuda13]

Note
We have tested the library with CUDA 13 only; other versions can cause
NCCL errors. Don't mix in CUDA 12 packages such as
jax-cuda12-plugin or nvidia-nccl-cu12.
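
To check an existing environment for the CUDA 12 packages mentioned in the note, something like the following sketch may help. It is not part of this package: the helper names are hypothetical, and the `cuda12` / `cu12` name fragments are an assumption generalized from the two package names given above.

```python
from importlib import metadata

# Name fragments that suggest a CUDA 12 build. Assumption: generalized from
# the two packages named in the note (jax-cuda12-plugin, nvidia-nccl-cu12).
CUDA12_MARKERS = ("cuda12", "cu12")


def find_cuda12_packages(names):
    """Return the sorted package names that look like CUDA 12 builds."""
    return sorted(
        name
        for name in names
        if any(marker in name.lower().replace("_", "-") for marker in CUDA12_MARKERS)
    )


def installed_package_names():
    """List the names of all distributions installed in this environment."""
    return [dist.name for dist in metadata.distributions() if dist.name]


if __name__ == "__main__":
    conflicts = find_cuda12_packages(installed_package_names())
    if conflicts:
        print("CUDA 12 packages found; consider uninstalling:", ", ".join(conflicts))
    else:
        print("No CUDA 12 packages found.")
```

Run it inside the activated Python environment, after installing jax[cuda13].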
Before training, you need to download and prepare the datasets. Note that running these scripts requires first cloning the source repository.
To prepare the PubMedQA dataset, make sure your Python environment is activated, then run:
cd gemma/diffusion/hackable_diffusion_adapter/data/pubmedqa
bash prepare_pubmedqa_dataset.sh

To prepare the Sudoku dataset, you first need to configure your Kaggle API credentials:
- Generate your Kaggle access token.
- Set up the token on your machine:
mkdir -p ~/.kaggle && echo YOUR_KAGGLE_TOKEN > ~/.kaggle/access_token && chmod 600 ~/.kaggle/access_token
- Make sure your Python environment is activated, then run the preparation script:
cd gemma/diffusion/hackable_diffusion_adapter/data/sudoku
bash prepare_sudoku_dataset.sh

Once your environment is set up and the data is downloaded, you can kick off a minimal training run with the following commands.
The overrides set on the command line (XLA_FLAGS and the NCCL_* environment variables) are used to prevent compilation hangs and NCCL errors.
For PubMedQA, we fine-tune the model with rank-4 LoRA, using 2 canvases of size 128 each. We use a batch size of 2 and a peak learning rate of 1e-4, and train for 2000 steps. We recommend a machine with at least the compute capacity of 2 A100s.
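
To give a rough sense of what the LoRA rank controls, the sketch below counts the trainable values LoRA adds to a single weight matrix, using the standard low-rank factorization (an update B @ A with B of shape d_out x rank and A of shape rank x d_in). The layer size is hypothetical and chosen only for illustration; it is not taken from the model or from the configs in this package.

```python
def lora_trainable_params(rank, d_in, d_out):
    """Trainable values LoRA adds to one d_out x d_in weight matrix.

    LoRA learns a low-rank update B @ A, with B of shape (d_out, rank) and
    A of shape (rank, d_in), so it trains rank * (d_in + d_out) values.
    """
    return rank * (d_in + d_out)


# Hypothetical layer size, for illustration only (not the model's real shape).
D_IN = D_OUT = 2048
full_matrix = D_IN * D_OUT

# Rank 4 is used for PubMedQA; rank 8 is used for the Sudoku LoRA run.
for rank in (4, 8):
    lora = lora_trainable_params(rank, D_IN, D_OUT)
    share = lora / full_matrix
    print(f"rank {rank}: {lora} trainable values ({share:.2%} of the full matrix)")
```

Doubling the rank doubles the number of trainable values per adapted matrix, which is one reason LoRA runs fit on fewer GPUs than full weight updates.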
Run from the parent directory of the gemma directory:
env XLA_FLAGS="--xla_disable_hlo_passes=constant_folding" \
NCCL_ALGO="Ring" \
NCCL_PROTO="LL128" \
NCCL_NVLS_ENABLE="0" \
NCCL_CUMEM_ENABLE="0" \
python3 -m kauldron.main \
--cfg=gemma/diffusion/hackable_diffusion_adapter/configs/sft_pubmedqa.py \
--cfg.workdir=$(pwd)/xp_dir

For Sudoku with LoRA, we fine-tune the model with rank-8 LoRA, using 1 canvas of size 256. We use a batch size of 8 and a peak learning rate of 1.5e-4, and train for 2000 steps. We recommend a machine with at least the compute capacity of 2 A100s.
Run from the parent directory of the gemma directory:
env XLA_FLAGS="--xla_disable_hlo_passes=constant_folding" \
NCCL_ALGO="Ring" \
NCCL_PROTO="LL128" \
NCCL_NVLS_ENABLE="0" \
NCCL_CUMEM_ENABLE="0" \
python3 -m kauldron.main \
--cfg=gemma/diffusion/hackable_diffusion_adapter/configs/sft_sudoku.py \
--cfg.workdir=$(pwd)/xp_dir

For Sudoku with full fine-tuning, we update all model weights, using 1 canvas of size 256. We use a batch size of 8 and a peak learning rate of 1.5e-4, and train for 2000 steps. We also use Adafactor for optimization to reduce memory usage. We recommend a machine with at least the compute capacity of 8 A100s.
Run from the parent directory of the gemma directory:
env XLA_FLAGS="--xla_disable_hlo_passes=constant_folding" \
NCCL_ALGO="Ring" \
NCCL_PROTO="LL128" \
NCCL_NVLS_ENABLE="0" \
NCCL_CUMEM_ENABLE="0" \
python3 -m kauldron.main \
--cfg=gemma/diffusion/hackable_diffusion_adapter/configs/sft_sudoku_full.py \
--cfg.workdir=$(pwd)/xp_dir

Evaluation is run offline as a separate step after training. It loads a saved checkpoint, runs autoregressive (AR) sampling on the eval dataset, and reports task-specific metrics (e.g., accuracy for Sudoku, BLEU for PubMedQA).
Run from the parent directory of the gemma directory:
env XLA_FLAGS="--xla_disable_hlo_passes=constant_folding" \
XLA_PYTHON_CLIENT_PREALLOCATE="false" \
TF_FORCE_GPU_ALLOW_GROWTH="true" \
python3 -m gemma.diffusion.hackable_diffusion_adapter.eval_main \
--cfg=gemma/diffusion/hackable_diffusion_adapter/configs/sft_sudoku.py \
--task=sudoku \
--step=1000 \
--eval_names=sample_ar_steps64 \
--cfg.workdir=$(pwd)/xp_dir_sudoku_lora \
--cfg.eval_ds.batch_size=2 \
--cfg.aux.eval_num_batches=2 \
--cfg.aux.num_canvases=2

| Flag | Description |
|---|---|
| --cfg | Path to the training config file (same one used for training). |
| --task | Task to evaluate: sudoku or pubmedqa. Determines which metrics are reported. |
| --step | Checkpoint step to evaluate. If omitted, the latest checkpoint is used. |
| --eval_names | Comma-separated list of evaluators to run (e.g. sample_ar_steps64). If omitted, all evaluators are run. |
| --cfg.workdir | Working directory that contains the training checkpoints. |
| --cfg.eval_ds.batch_size | Eval batch size (reduce if running out of memory). |
| --cfg.aux.eval_num_batches | Number of eval batches to process. Set to a small value for quick sanity checks, or omit to run over the full eval set. |
| --cfg.aux.num_canvases | Number of AR canvases to generate per example. |
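
When scripting several evaluations, the flags above can be assembled programmatically. The following is a minimal sketch: `build_eval_command` is a hypothetical helper, not something this package provides, and it only formats the flags listed in the table, leaving out any optional flag whose value is not given.

```python
import shlex

EVAL_MODULE = "gemma.diffusion.hackable_diffusion_adapter.eval_main"


def build_eval_command(cfg, task, workdir, step=None, eval_names=None,
                       batch_size=None, num_batches=None, num_canvases=None):
    """Assemble the eval_main argument list from the flags in the table above."""
    argv = ["python3", "-m", EVAL_MODULE, f"--cfg={cfg}", f"--task={task}"]
    if step is not None:
        argv.append(f"--step={step}")
    if eval_names:
        # --eval_names takes a comma-separated list of evaluators.
        argv.append("--eval_names=" + ",".join(eval_names))
    argv.append(f"--cfg.workdir={workdir}")
    optional_overrides = {
        "--cfg.eval_ds.batch_size": batch_size,
        "--cfg.aux.eval_num_batches": num_batches,
        "--cfg.aux.num_canvases": num_canvases,
    }
    for flag, value in optional_overrides.items():
        if value is not None:
            argv.append(f"{flag}={value}")
    return argv


command = build_eval_command(
    cfg="gemma/diffusion/hackable_diffusion_adapter/configs/sft_sudoku.py",
    task="sudoku",
    workdir="xp_dir_sudoku_lora",
    step=1000,
    eval_names=["sample_ar_steps64"],
    batch_size=2,
    num_batches=2,
    num_canvases=2,
)
print(shlex.join(command))
```

The returned list can be passed to `subprocess.run` together with an `env` mapping that sets the XLA and TensorFlow variables from the command above.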
The evaluators are generated automatically from the config. The naming convention is:
- sample_ar_steps{N}: AR diffusion sampling with N denoising steps.
- sample_ar_steps{N}_early_stopping: Same as above, but with entropy-based early stopping.
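
The naming convention above can be captured in a couple of small helpers, for example to build the value passed to --eval_names or to recover the step count from a metric group. This is a sketch based only on the two patterns listed; the function names are hypothetical and the library may generate further evaluator names not covered here.

```python
import re

# Matches the two documented patterns:
#   sample_ar_steps{N} and sample_ar_steps{N}_early_stopping
_EVAL_NAME = re.compile(r"sample_ar_steps(?P<steps>\d+)(?P<early>_early_stopping)?")


def evaluator_name(steps, early_stopping=False):
    """Build an evaluator name for the given number of denoising steps."""
    suffix = "_early_stopping" if early_stopping else ""
    return f"sample_ar_steps{steps}{suffix}"


def parse_evaluator_name(name):
    """Return (denoising_steps, uses_early_stopping) for an evaluator name."""
    match = _EVAL_NAME.fullmatch(name)
    if match is None:
        raise ValueError(f"not a recognized evaluator name: {name!r}")
    return int(match["steps"]), match["early"] is not None


# Example: a comma-separated value suitable for --eval_names.
names = [evaluator_name(64), evaluator_name(64, early_stopping=True)]
print(",".join(names))
```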
Eval metrics are written to TensorBoard event files in the working directory. To view the results, launch TensorBoard pointing at the workdir:
tensorboard --logdir=$(pwd)/xp_dir_sudoku_lora

Metrics for each evaluator appear under the corresponding eval name
(e.g. sample_ar_steps64). For Sudoku, key metrics include overall accuracy,
cell accuracy and difficulty-stratified results. For PubMedQA, look for
accuracy and BLEU scores.