We provide the PyTorch implementation of our MIDL 2023 submission "Memory-Efficient 3D Denoising Diffusion Models for Medical Image Processing".
Check out the project page!
The implementation is based on Diffusion Models for Medical Anomaly Detection and openai/guided-diffusion (MIT-License).
Install the necessary Python packages as defined in environment.yaml.
We recommend using mambaforge.
You can create the environment using
mamba env create -n your_env_name --file environment.yaml
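After creating and activating the environment, a small standard-library check can confirm that the key packages resolve. This is only a sketch: the helper `missing_packages` is hypothetical and the package list is an illustrative assumption, so take the actual requirements from environment.yaml.

```python
import importlib.util


def missing_packages(names):
    """Return the top-level module names that cannot be found in this environment."""
    # find_spec returns None when a top-level module is not importable.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Hypothetical list; replace it with the packages from environment.yaml.
    required = ["torch", "numpy", "nibabel", "tensorboard"]
    missing = missing_packages(required)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All listed packages are importable.")
```

If something is reported as missing, re-create the environment or try different package versions as described below.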
If you run into problems, you can try using different versions of these packages.
You can use the run.sh file to run both the training and the sampling for the different models.
The relevant parameters are broken out at the top of the file; adjust them according to which
model you want to train or sample from and which part of the data you want to use.
Training and sampling are visualized using TensorBoard.
The model checkpoints are saved in a subdirectory of the runs folder generated by TensorBoard.
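To pick up the newest checkpoint, for example before sampling, a small helper can search the run subdirectories. This sketch assumes checkpoints are written as *.pt files directly inside each run's subdirectory of runs; adjust `pattern` if the files are named differently. The helper `latest_checkpoint` is hypothetical and not part of the repository.

```python
from pathlib import Path


def latest_checkpoint(runs_dir="runs", pattern="*.pt"):
    """Return the most recently modified file matching `pattern` in any
    direct subdirectory of `runs_dir`, or None if nothing matches."""
    runs = Path(runs_dir)
    if not runs.is_dir():
        return None
    # Each run writes into its own subdirectory, so search exactly one level down.
    candidates = [p for p in runs.glob(f"*/{pattern}") if p.is_file()]
    if not candidates:
        return None
    # Choose by modification time, so the newest checkpoint across all runs wins.
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

Selecting by modification time avoids depending on how run directories are named; if you compare several runs, pass a single run's directory instead of the whole runs folder.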
To view and compare the different runs, run tensorboard --logdir=runs --bind_all and open the provided link in your browser.
We provide a torch.utils.data.Dataset implementation for BraTS2020 data, normalized as
described in the paper. The implementation assumes that the data is stored in a directory structure
like this:
root
    dataroot
        000001
            brats_train_001_t1_000_w.nii.gz
            brats_train_001_t1ce_000_w.nii.gz
            brats_train_001_t2_000_w.nii.gz
            brats_train_001_flair_000_w.nii.gz
            brats_train_001_seg_000_w.nii.gz
        000002
            brats_train_002_t1_000_w.nii.gz
            brats_train_002_t1ce_000_w.nii.gz
            ...
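To check that a local copy matches this layout before training, the subject folders can be indexed with the standard library alone. This is a hedged sketch, not the repository's Dataset: the helper `index_subjects` is hypothetical, the file-name pattern is inferred from the example tree, and loading the volumes and applying the paper's normalization are left out.

```python
import re
from pathlib import Path

# Modalities expected per subject folder, as in the example tree.
MODALITIES = ("t1", "t1ce", "t2", "flair", "seg")
# Assumed naming scheme: brats_train_<id>_<modality>_<index>_w.nii.gz
_NAME_RE = re.compile(r"^brats_train_\d+_(t1ce|t1|t2|flair|seg)_\d+_w\.nii\.gz$")


def index_subjects(dataroot):
    """Return one {"subject": name, <modality>: path, ...} dict per subject folder.

    Folders that lack any of the expected modalities are skipped.
    """
    subjects = []
    for folder in sorted(p for p in Path(dataroot).iterdir() if p.is_dir()):
        files = {}
        for f in folder.iterdir():
            match = _NAME_RE.match(f.name)
            if match:
                files[match.group(1)] = f
        if all(m in files for m in MODALITIES):
            subjects.append({"subject": folder.name, **files})
    return subjects
```

A Dataset's `__len__` and `__getitem__` could then be built on top of this list, which makes incomplete subject folders visible early instead of failing mid-training.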
Copyright 2023 Center of Image Analysis and Navigation, University of Basel