dsaragih/diffuse-gen

Using diffusion models to generate synthetic labeled data for medical image segmentation

Repository for the paper Using diffusion models to generate synthetic labeled data for medical image segmentation.

Abstract

Purpose

Medical image analysis has become a prominent area where machine learning has been applied. However, high-quality, publicly available data are limited either due to patient privacy laws or the time and cost required for experts to annotate images. In this retrospective study, we designed and evaluated a pipeline to generate synthetic labeled polyp images for augmenting medical image segmentation models with the aim of reducing this data scarcity.

Methods

We trained diffusion models on the HyperKvasir dataset, comprising 1000 images of polyps in the human GI tract from 2008 to 2016. Qualitative expert review, Fréchet Inception Distance (FID), and Multi-Scale Structural Similarity (MS-SSIM) were tested for evaluation. Additionally, various segmentation models were trained with the generated data and evaluated using Dice score (DS) and Intersection over Union (IoU).

Results

Our pipeline produced images more akin to real polyp images based on FID scores. Segmentation model performance also showed improvements over GAN methods when trained entirely, or partially, with synthetic data, despite requiring less compute for training. Moreover, the improvement persists when tested on different datasets, showcasing the transferability of the generated images.

Conclusions

The proposed pipeline produced realistic image and mask pairs which could reduce the need for manual data annotation when performing a machine learning task. We support this use case by showing that the methods proposed in this study enhanced segmentation model performance, as measured by Dice and IoU scores, when trained fully or partially on synthetic data.
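For readers unfamiliar with the two segmentation metrics named above, the following is a minimal sketch of how Dice score and IoU can be computed for a pair of binary masks. It is not the evaluation code used in the paper; the function name `dice_and_iou` and the convention of returning 1.0 when both masks are empty are assumptions made for illustration.

```python
def dice_and_iou(pred, target):
    """Compute (Dice, IoU) for two equally sized 2D binary masks (lists of 0/1)."""
    p = [bool(v) for row in pred for v in row]
    t = [bool(v) for row in target for v in row]
    if len(p) != len(t):
        raise ValueError("masks must have the same number of pixels")

    intersection = sum(a and b for a, b in zip(p, t))
    union = sum(a or b for a, b in zip(p, t))
    total = sum(p) + sum(t)

    # Assumed convention: two empty masks count as a perfect match.
    dice = 2 * intersection / total if total else 1.0
    iou = intersection / union if union else 1.0
    return dice, iou


pred = [[1, 1, 0],
        [0, 1, 0]]
target = [[1, 0, 0],
          [0, 1, 1]]
dice, iou = dice_and_iou(pred, target)
print(f"Dice = {dice:.3f}, IoU = {iou:.3f}")
```

Both metrics measure overlap between predicted and ground-truth masks. Dice weights the intersection twice against the sum of the two mask areas, whereas IoU divides the intersection by the union.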

Running the code

Environment

PyTorch 1.12.1, Python 3.9

$ conda create --name diffgen python=3.9
$ conda activate diffgen
$ pip install ftfy regex matplotlib lpips kornia opencv-python color-matcher blobfile scikit-learn pytorch_msssim
$ pip install torch==1.12.1 torchvision==0.13.1
$ pip install git+https://github.com/openai/CLIP.git
$ pip install mpi4py  # or: conda install mpi4py

Model download

To generate images, download the pre-trained diffusion model(s):

Full Model LINK ~2 GB

Cluster (20) Models LINK ~40GB

Download the model(s) into the ./checkpoints folder.

Clusters

The .pkl file containing the cluster assignments may be found in the clustering folder. Note that cluster assignments for 10, 20, 30, and 40 are provided, as well as a simple notebook to generate the clusters.
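As a quick sanity check before training, it can help to inspect how many images fall into each cluster. The sketch below assumes the .pkl file holds a dictionary mapping image filenames to integer cluster indices. That layout is an assumption and is not confirmed here, so check the notebook in the clustering folder for the actual structure. The toy file and the `cluster_sizes` helper are hypothetical.

```python
import pickle
import tempfile
from collections import Counter
from pathlib import Path


def cluster_sizes(pkl_path):
    """Return a Counter of {cluster index: number of images}.

    Assumes the pickle stores a dict of {filename: cluster index};
    adapt this if the real file uses a different layout.
    """
    with open(pkl_path, "rb") as f:
        assignments = pickle.load(f)
    return Counter(assignments.values())


# Toy stand-in for a real cluster file (hypothetical filenames).
toy_assignments = {"img_a.jpg": 0, "img_b.jpg": 1, "img_c.jpg": 1}

with tempfile.TemporaryDirectory() as tmp:
    toy_path = Path(tmp) / "toy_clusters.pkl"
    with open(toy_path, "wb") as f:
        pickle.dump(toy_assignments, f)
    sizes = cluster_sizes(toy_path)

for index, count in sorted(sizes.items()):
    print(f"cluster {index}: {count} image(s)")
```

Only unpickle files from sources you trust, since loading a pickle can execute arbitrary code.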

Unfortunately, we aren't able to provide pre-trained models for $k \neq 20$, but the training code is provided below.

Training Models

# Cluster Models

python ~/diffusion-gen/guided_diffusion/BOI_train.py \
    --data_dir ~/diffusion-gen/guided_diffusion/segmented-images \
    --cluster_file ~/diffusion-gen/clusters.pkl \
    --cluster_index 1 \
    --image_size 256 \
    --out_dir ~/diffusion-gen/checkpoints \
    --batch_size 1

# Full Model

python ~/diffusion-gen/guided_diffusion/image_train.py \
    --data_dir ~/diffusion-gen/guided_diffusion/segmented-images \
    --image_size 256 \
    --out_dir ~/diffusion-gen/checkpoints \
    --batch_size 1

For cluster models, --cluster_index may be set to any cluster index present in the cluster file.
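Training one model per cluster means repeating the cluster command with a different --cluster_index each time. The sketch below only builds the argument lists, using the flags shown above, and does not launch anything. The helper name `cluster_train_command`, the `root` default, and the assumption that the 20 clusters are indexed 0 to 19 are illustrative.

```python
def cluster_train_command(cluster_index, root="~/diffusion-gen"):
    """Build the BOI_train.py argument list for a single cluster."""
    return [
        "python", f"{root}/guided_diffusion/BOI_train.py",
        "--data_dir", f"{root}/guided_diffusion/segmented-images",
        "--cluster_file", f"{root}/clusters.pkl",
        "--cluster_index", str(cluster_index),
        "--image_size", "256",
        "--out_dir", f"{root}/checkpoints",
        "--batch_size", "1",
    ]


# Assumption: 20 clusters with zero-based indices.
commands = [cluster_train_command(i) for i in range(20)]

# Print the commands so they can be reviewed or pasted into a job script.
for cmd in commands[:2]:
    print(" ".join(cmd))
```

If the lists are passed to `subprocess.run` without a shell, note that `~` is not expanded automatically; pass an absolute `root` or apply `os.path.expanduser` first.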

Sample images from cluster models

python ~/diffusion-gen/main.py \
    --cluster_path ~/diffusion-gen/clusters.pkl \
    --output_path ~/diffusion-gen/cluster_image_samples \
    --model_path ~/diffusion-gen/checkpoints/cluster_models/model110000.pt \
    --diff_iter 100 \
    --timestep_respacing 200 \
    --skip_timesteps 80 \
    --model_output_size 256 \
    --num_samples 1 \
    --batch_size 1 \
    --use_noise_aug_all \
    --use_colormatch \
    -fti -sty -inp -spi

The cluster path, output path, and model path can be changed to your own path.

Note the format of the model path: each model will be of the form cluster_{cluster index}_model_{model iteration}.pt. For example, cluster_0_model_110000.pt is the model for cluster 0 at iteration 110000. In the main.py file, we loop through all the cluster models with the given iteration number and generate samples for each cluster.
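The naming convention described above can be captured in a pair of small helpers, which is handy when checking that every cluster has a checkpoint at the requested iteration. This is a sketch and is not taken from main.py. The function names are hypothetical, and the iteration is written without zero padding, matching the cluster_0_model_110000.pt example.

```python
import re

CHECKPOINT_PATTERN = re.compile(r"^cluster_(\d+)_model_(\d+)\.pt$")


def checkpoint_name(cluster_index, iteration):
    """Format a checkpoint filename as cluster_{index}_model_{iteration}.pt."""
    return f"cluster_{cluster_index}_model_{iteration}.pt"


def parse_checkpoint_name(name):
    """Return (cluster_index, iteration), or None if the name does not match."""
    match = CHECKPOINT_PATTERN.match(name)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


# Expected filenames for 20 clusters at iteration 110000 (zero-based indices assumed).
expected = [checkpoint_name(i, 110000) for i in range(20)]
print(expected[0])
print(parse_checkpoint_name(expected[0]))
```

Comparing `expected` against the contents of the checkpoints folder gives a quick way to spot missing cluster models before sampling.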

-fti selects a random image from the cluster to use as the target image for image-guided generation

-sty applies styling

-inp uses the inpainting technique

-spi samples {num_samples} images for each input image, rather than {num_samples} images in total

Sample images from full model

python ~/diffusion-gen/main.py \
    --data_dir ~/diffusion-gen/guided_diffusion/segmented-images/masked-images \
    --output_path ~/diffusion-gen/image_samples \
    --model_path ~/diffusion-gen/logs/256models/model200000.pt \
    --diff_iter 100 \
    --timestep_respacing 200 \
    --skip_timesteps 80 \
    --model_output_size 256 \
    --num_samples 1 \
    --batch_size 1 \
    --use_noise_aug_all \
    --use_colormatch \
    -fti -sty -inp -spi

Refer to the section above for an explanation of the arguments.

Our source code relies on Blended-diffusion, guided-diffusion, flexit, splicing vit, and DiffuseIT.

Citation

If you find this work useful, please consider citing:

@article{saragihUsingDiffusionModels2024,
  title = {Using Diffusion Models to Generate Synthetic Labeled Data for Medical Image Segmentation},
  author = {Saragih, Daniel G. and Hibi, Atsuhiro and Tyrrell, Pascal N.},
  year = {2024},
  month = aug,
  journal = {International Journal of Computer Assisted Radiology and Surgery},
  volume = {19},
  number = {8},
  pages = {1615--1625},
  issn = {1861-6429},
  doi = {10.1007/s11548-024-03213-z},
  urldate = {2024-08-19},
}
