Public PyTorch implementation for our paper "Unified Cross-Modal Image Synthesis with Hierarchical Mixture of Product-of-Experts", currently under review.
If you find this code useful for your research, please cite the following paper:
@article{dorent2024unified,
title={Unified Cross-Modal Image Synthesis with Hierarchical Mixture of Product-of-Experts},
author={Dorent, Reuben and Haouchine, Nazim and Golby, Alexandra and Frisken, Sarah and Kapur, Tina and Wells, William},
journal={arXiv preprint arXiv:2410.19378},
year={2024}
}
We propose a deep mixture of multimodal hierarchical variational auto-encoders called MMHVAE that synthesizes missing images from observed images in different modalities. MMHVAE’s design focuses on tackling four challenges:
- creating a complex latent representation of multimodal data to generate high-resolution images
- encouraging the variational distributions to estimate the missing information needed for cross-modal image synthesis
- learning to fuse multimodal information in the context of missing data
- leveraging dataset-level information to handle incomplete datasets at training time
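The fusion challenge above is what the "mixture of product-of-experts" in the title refers to. The sketch below shows only the generic, single-level version of that idea for diagonal Gaussian experts: a product of experts (PoE) combines the observed modalities by precision weighting, and a mixture takes one PoE per non-empty subset of observed modalities. It is not MMHVAE's hierarchical implementation. The function names, the uniform mixture weights, and the use of plain Python lists instead of PyTorch tensors are all assumptions made for illustration.

```python
import math
import random
from itertools import combinations

# Each expert is a diagonal Gaussian given as a (mu, var) pair of equal-length lists.


def product_of_experts(experts, prior=None):
    """Fuse diagonal Gaussian experts by precision weighting.

    The product of N(mu_i, var_i) is proportional to a Gaussian whose precision
    is sum_i 1/var_i and whose mean is var * sum_i mu_i / var_i.
    """
    experts = ([prior] if prior is not None else []) + list(experts)
    if not experts:
        raise ValueError("need at least one expert")
    mu_out, var_out = [], []
    for d in range(len(experts[0][0])):
        precision = sum(1.0 / var[d] for _, var in experts)
        weighted_mean = sum(mu[d] / var[d] for mu, var in experts)
        var_out.append(1.0 / precision)
        mu_out.append(weighted_mean / precision)
    return mu_out, var_out


def mixture_of_poe(experts_by_modality, prior=None):
    """Build one PoE component per non-empty subset of the observed modalities.

    Hypothetical choice: every component gets the same mixture weight.
    """
    names = sorted(experts_by_modality)
    components = []
    for k in range(1, len(names) + 1):
        for subset in combinations(names, k):
            fused = product_of_experts([experts_by_modality[n] for n in subset], prior)
            components.append((subset, fused))
    weight = 1.0 / len(components)
    return [(subset, weight, fused) for subset, fused in components]


def sample_mixture(components, rng):
    """Pick a component uniformly, then draw with the reparameterisation mu + sigma * eps."""
    _, _, (mu, var) = rng.choice(components)
    return [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(mu, var)]
```

For example, with hypothetical experts `{"T2": ([0.0], [1.0]), "US": ([2.0], [1.0])}`, `mixture_of_poe` returns three components ({T2}, {US}, {T2, US}); the fused {T2, US} component has mean 1.0 and variance 0.5. Missing modalities are handled simply by leaving them out of the dictionary.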
Example of multimodal synthesis using a) an ultrasound or b) a T2 scan as input. Three MR sequences (FLAIR, contrast-enhanced T1, T2) and ultrasound images are generated.
Once all the Python packages are installed and the pre-trained models are downloaded (see below), you can test MMHVAE on your own data with:
python pred.py --model_dir pretrained/mmhvae_f0/ --input test/ --output output_test/
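If you prefer to call the prediction step from Python (for example, inside a larger pipeline), a minimal wrapper could look like the sketch below. It uses only the three flags shown in the command above and passes the paths through unchanged, in case `pred.py` expects the trailing slashes. The helper names `build_pred_command` and `run_pred`, and the choice to pre-create the output folder, are assumptions and not part of this repository.

```python
import subprocess
import sys
from pathlib import Path


def build_pred_command(model_dir, input_dir, output_dir, script="pred.py"):
    """Assemble the pred.py command line using only the documented flags."""
    return [
        sys.executable, script,
        "--model_dir", str(model_dir),
        "--input", str(input_dir),
        "--output", str(output_dir),
    ]


def run_pred(model_dir, input_dir, output_dir):
    """Check that the model and input folders exist, then run pred.py."""
    for path in (model_dir, input_dir):
        if not Path(path).is_dir():
            raise FileNotFoundError(f"missing directory: {path}")
    # Hypothetical convenience step: create the output folder up front.
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    cmd = build_pred_command(model_dir, input_dir, output_dir)
    subprocess.run(cmd, check=True)  # raises CalledProcessError if pred.py fails
```

Usage, from the repository root: `run_pred("pretrained/mmhvae_f0/", "test/", "output_test/")`.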
- Clone this repository.
- Create a virtual environment (e.g., with venv or conda) and install all the required packages:
  pip install -r requirements.txt
- Optional: Download the pre-trained models.
Example of training command:
WEIGHT=0.025
FOLD=0
python train.py \
--model_dir ./models/fold$FOLD/mmhvae_$WEIGHT/ \
--dataset_split splits/split_empty_$FOLD.csv \
--path_data $WORK/data/TPAMI/data_unnorm/ \
--path_data_norm $WORK/data/TPAMI/data_norm/ \
--w_dis $WEIGHT
Example of inference command:
WEIGHT=0.025
FOLD=0
python inference.py \
--model_dir ./models/fold$FOLD/mmhvae_$WEIGHT/ \
--dataset_split splits/split_empty_$FOLD.csv \
--path_data $WORK/data/TPAMI/data_unnorm/ \
--path_data_norm $WORK/data/TPAMI/data_norm/ \
--save_images
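The training and inference commands share the same `--model_dir`, `--dataset_split`, `--path_data` and `--path_data_norm` values, so inference reads the model trained for the same fold and weight. The sketch below builds both commands from one place to keep those values in sync. The helper names, running inference immediately after training, and reading the data root from the `WORK` environment variable by default are assumptions; the weight is kept as a string so that the folder name matches `mmhvae_$WEIGHT` exactly.

```python
import os
import subprocess
import sys


def common_args(fold, weight, work):
    """Arguments shared by train.py and inference.py in the examples above."""
    return [
        "--model_dir", f"./models/fold{fold}/mmhvae_{weight}/",
        "--dataset_split", f"splits/split_empty_{fold}.csv",
        "--path_data", f"{work}/data/TPAMI/data_unnorm/",
        "--path_data_norm", f"{work}/data/TPAMI/data_norm/",
    ]


def build_commands(fold, weight, work):
    """Return the (train, inference) command lists for one fold and weight."""
    shared = common_args(fold, weight, work)
    train = [sys.executable, "train.py", *shared, "--w_dis", weight]
    infer = [sys.executable, "inference.py", *shared, "--save_images"]
    return train, infer


def run_fold(fold, weight="0.025", work=None):
    """Hypothetical driver: train, then run inference, for a single fold."""
    work = work if work is not None else os.environ["WORK"]
    for cmd in build_commands(fold, weight, work):
        subprocess.run(cmd, check=True)  # stop if either step fails
```

Calling `run_fold(0)` reproduces the two commands above for `FOLD=0` and `WEIGHT=0.025`; other folds can be run by changing the first argument.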
