
✨ REFLECT ✨

A PyTorch Implementation for Unsupervised Brain Anomaly Detection

This repository hosts the official PyTorch implementation of our paper accepted at MICCAI 2025:
"REFLECT: Rectified Flows for Efficient Brain Anomaly Correction Transport".


🎨 Approach

[Figure: overview of the REFLECT method]


⚙️ Setup

🛠️ Environment

Our experiments run on Python 3.11. Install all the required packages by executing:

pip3 install -r requirements.txt

📁 Datasets

Prepare your data as follows:

  1. Data Registration & Preprocessing:

    • Register each volume to the MNI152 1mm template.
    • Preprocess, normalize, pad, and extract axial slices (a minimal preprocessing sketch follows the directory tree below).
  2. Dataset Organization:

    • Ensure training and validation sets contain only normal, healthy data.
    • The test set should include abnormal slices, each with a corresponding segmentation mask.
    • Organize your files using this structure:
    ├── Data
        ├── train
        │   ├── {train_image_id}-slice_{slice_idx}-{modality}.png
        │   ├── {train_image_id}-slice_{slice_idx}-brainmask.png
        │   └── ...
        └── test
            ├── {test_image_id}-slice_{slice_idx}-{modality}.png
            ├── {test_image_id}-slice_{slice_idx}-brainmask.png
            ├── {test_image_id}-slice_{slice_idx}-segmentation.png
            └── ...
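
For reference, the following minimal preprocessing sketch (not part of this repository) converts a volume that is already registered to MNI152 1mm into normalized, zero-padded axial PNG slices following the naming scheme above. The normalization, padding, and axis choices are illustrative assumptions:

# Illustrative preprocessing sketch; assumes the volume is already registered
# to MNI152 1mm (e.g., with ANTs or FSL) and that the axial axis is last.
from pathlib import Path

import nibabel as nib
import numpy as np
from PIL import Image

def export_axial_slices(nifti_path, out_dir, image_id, modality, size=256):
    """Normalize a registered volume and save its axial slices as 8-bit PNGs."""
    volume = nib.load(str(nifti_path)).get_fdata()
    # Min-max normalization to [0, 1] over the whole volume (one possible choice).
    volume = (volume - volume.min()) / (volume.max() - volume.min() + 1e-8)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for slice_idx in range(volume.shape[2]):
        axial = volume[:, :, slice_idx]
        # Zero-pad each slice to size x size (MNI152 1mm slices fit inside 256x256).
        h, w = axial.shape
        padded = np.zeros((size, size), dtype=np.float32)
        padded[(size - h) // 2:(size - h) // 2 + h,
               (size - w) // 2:(size - w) // 2 + w] = axial
        Image.fromarray((padded * 255).astype(np.uint8)).save(
            out_dir / f"{image_id}-slice_{slice_idx}-{modality}.png")

# Placeholder path and ids for illustration.
export_axial_slices("sub-001_T1_mni.nii.gz", "Data/train", "sub-001", "T1")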
    

🔧 Pretrained Weights & VAE Fine-Tuning

Pretrained VAE Models

To jumpstart your experiments, we provide pretrained weights adapted for 1-channel medical brain images. These models are available on HuggingFace.
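
These weights can be fetched with the standard huggingface_hub client; in the sketch below, the repository id and filename are placeholders, so substitute the ones from the HuggingFace link:

# Placeholder repo id and filename; use the ones linked from this README.
from huggingface_hub import hf_hub_download

vae_ckpt = hf_hub_download(
    repo_id="<hf-user>/<reflect-vae>",   # placeholder
    filename="kl_f8.ckpt",               # placeholder
)
print(f"VAE checkpoint downloaded to: {vae_ckpt}")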

Train & Fine-Tune VAE

If you prefer to train your own VAE from scratch, please refer to the LDM-VAE repository for detailed instructions.


🔗 DTD Embedding Download

The training script requires a precomputed DTD (Describable Textures Dataset) embedding file.

  • Download the DTD embeddings that match your chosen VAE model (kl_f4 or kl_f8).

  • Copy the downloaded file to the directory you specify with the --dtd-dir argument.
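
As a quick sanity check before training, you can inspect the downloaded file with torch; the file name below is a placeholder, and the stored format (a single tensor vs. a dict of tensors) is an assumption:

import torch

# Placeholder file name; use the embedding file you actually downloaded.
dtd_embeddings = torch.load("./dtd_embeddings_kl_f8.pt", map_location="cpu")
print(type(dtd_embeddings))
if torch.is_tensor(dtd_embeddings):
    print(dtd_embeddings.shape, dtd_embeddings.dtype)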


🚄 Training REFLECT-1

To train the REFLECT-1 model, run the following command. This example uses a UNet_M architecture and integrates a pretrained VAE (with scale factor 8) for the T1 modality of the BraTS dataset:

torchrun train_REFLECT.py \
            --dataset BraTS \
            --model UNet_M \
            --image-size 256 \
            --vae kl_f8 \
            --modality T1 \
            --dtd-dir . \
            --data-dir .
            

Where:

  • --dataset: BraTS or ATLAS
  • --model: UNet_XS, UNet_S, UNet_M, UNet_L, or UNet_XL
  • --vae: kl_f8 or kl_f4
  • --modality: For BraTS: T1, T2, FLAIR, or T1CE; for ATLAS: T1
  • --dtd-dir: Path to the directory containing the DTD embedding file.
  • --data-dir: Path to the root data directory.
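
Because the script is launched with torchrun, the standard distributed-launch flags apply. For example, assuming the training script initializes its process group from the torchrun environment, a single-node four-GPU run would be:

torchrun --nproc_per_node=4 train_REFLECT.py \
            --dataset BraTS \
            --model UNet_M \
            --image-size 256 \
            --vae kl_f8 \
            --modality T1 \
            --dtd-dir . \
            --data-dir .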

🚄 Training REFLECT-2

To train the REFLECT-2 model, first ensure you have completed REFLECT-1 training. Then, launch REFLECT-2 training as shown below.

  • Note: train_REFLECT-2.py automatically loads the required arguments from the YAML config file found in the parent directory of the specified REFLECT-1 model path, so you do not need to specify them manually:
torchrun train_REFLECT-2.py \
            --dtd-dir . \
            --data-dir . \
            --REFLECT-1-path ./REFLECT_BraTS_UNet_M_T1_256_kl_f8/006-UNet_M-T1/checkpoints/last.pt

Where:

  • --REFLECT-1-path: Path to the trained REFLECT-1 model checkpoint.
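
For reference, the loaded YAML records the REFLECT-1 training configuration. The excerpt below is a hypothetical illustration whose keys simply mirror the training flags above; the repository's actual file may use different names:

# Hypothetical config excerpt; actual key names may differ.
dataset: BraTS
model: UNet_M
image-size: 256
vae: kl_f8
modality: T1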

🚦 Evaluating REFLECT

To evaluate a trained REFLECT model, use the following command. Note: evaluate_REFLECT.py also loads its configuration and arguments from the YAML file located in the parent directory of the given model checkpoint path. The script computes four evaluation metrics and saves per-image visualizations in the parent folder of the model path:

torchrun evaluate_REFLECT.py \
            --data-dir . \
            --model-path ./REFLECT_BraTS_UNet_M_T1_256_kl_f8/001-UNet_M-T1/checkpoints/last.pt

Where:

  • --model-path: Path to the trained model checkpoint (either REFLECT-1 or REFLECT-2).
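
The four metrics are computed by the script itself. Purely as an illustrative point of reference (the actual metrics may differ), a pixel-level Dice score between a thresholded anomaly map and the ground-truth segmentation could be computed as:

import numpy as np

def dice_score(anomaly_map, segmentation, threshold=0.5):
    """Dice between a binarized anomaly map and a binary ground-truth mask."""
    pred = (anomaly_map >= threshold).astype(np.uint8)
    gt = (segmentation > 0).astype(np.uint8)
    denom = pred.sum() + gt.sum()
    return 2.0 * (pred & gt).sum() / denom if denom > 0 else 1.0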

📸 Sample Results

[Figure: sample results]


📚 Citation & Reference

If you find REFLECT useful in your research, please cite our work:

@article{beizaee2025reflect,
  title={REFLECT: Rectified Flows for Efficient Brain Anomaly Correction Transport},
  author={Beizaee, Farzad and Hajimiri, Sina and Ayed, Ismail Ben and Lodygensky, Gregory and Desrosiers, Christian and Dolz, Jose},
  journal={arXiv preprint arXiv:2508.02889},
  year={2025}
}
