We propose a method that leverages initial, imprecise segmentation masks to provide domain knowledge to prompt-based segmentation foundation models. Our pipeline adapts the Segment Anything Model (SAM) for pseudo-label refinement in a semi-supervised setting. With automatically generated prompts, SAM refines initial segmentation masks derived from a limited amount of labelled data and thus provides pseudo labels, making unlabelled data usable for training.
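For intuition, here is a minimal sketch of the idea using SAM's official `segment_anything` package: a bounding box and a few positive/negative points are derived from a coarse mask and passed to `SamPredictor`. The sampling strategy, the variable names `image` and `coarse_mask`, and the checkpoint path are illustrative assumptions, not the exact pipeline implemented in this repository.

```python
# Sketch only: derive prompts from a coarse binary mask and let SAM refine it.
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

def prompts_from_mask(mask: np.ndarray, n_pos: int = 3, n_neg: int = 3):
    """Derive a box and positive/negative point prompts from a coarse mask.
    The sampling strategy here is illustrative, not the paper's exact one."""
    ys, xs = np.nonzero(mask)
    box = np.array([xs.min(), ys.min(), xs.max(), ys.max()])            # x0, y0, x1, y1
    pos_idx = np.random.choice(len(xs), size=min(n_pos, len(xs)), replace=False)
    pos = np.stack([xs[pos_idx], ys[pos_idx]], axis=1)                   # points inside the mask
    bg_ys, bg_xs = np.nonzero(mask == 0)
    neg_idx = np.random.choice(len(bg_xs), size=min(n_neg, len(bg_xs)), replace=False)
    neg = np.stack([bg_xs[neg_idx], bg_ys[neg_idx]], axis=1)             # points outside the mask
    points = np.concatenate([pos, neg], axis=0)
    labels = np.concatenate([np.ones(len(pos)), np.zeros(len(neg))])     # 1 = foreground, 0 = background
    return box, points, labels

sam = sam_model_registry["vit_h"](checkpoint="data/sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)
predictor.set_image(image)                      # image: HxWx3 uint8 array, assumed already loaded
box, points, labels = prompts_from_mask(coarse_mask)
refined, _, _ = predictor.predict(
    point_coords=points, point_labels=labels, box=box, multimask_output=False
)                                               # refined[0]: HxW boolean mask
```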
This repository contains code from the official SAM repository. Please give them credit for sharing their amazing work.
These instructions explain how to reproduce our results, using the GrazPedWri dataset as an example. For the other dataset, please refer to the branch `addional_ds`, and for the implementation of the comparison methods, please refer to the branch `master`.
We rely on ClearML to manage and store our models. While it is possible to rewrite the code to store the models locally without ClearML, we recommend using ClearML, since its free plan is sufficient to reproduce our results.
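For orientation, this is roughly how a training script registers with ClearML so that saved checkpoints appear as output models in the web UI; the project and task names below are placeholders, not the ones used in this repository.

```python
# Hypothetical sketch of ClearML experiment tracking; names are placeholders.
from clearml import Task

task = Task.init(project_name="sam-pseudo-labels", task_name="unet_training")
# Once a task is initialised, checkpoints written via torch.save(...) are picked
# up automatically and listed under the experiment's output models.
```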
Our code is implemented in PyTorch. Please use the provided YAML file (`environment.yml`) to create the environment:

```
conda env create -f environment.yml
```
Please download the dataset using the link provided in the original paper and preprocess it with their provided notebooks to obtain the 8-bit images. Afterwards, please use our preprocessing script `scripts/copy_and_process_graz_imgs` to create the homogeneous dataset (all images flipped to the left); a rough sketch of this step is shown below.
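As an illustration of what this preprocessing does, here is a sketch; the folder paths and the hypothetical `is_right_sided` helper are assumptions, and the actual logic lives in `scripts/copy_and_process_graz_imgs`.

```python
# Illustrative sketch only: mirror right-sided radiographs so all hands face left.
from pathlib import Path
from PIL import Image, ImageOps

def is_right_sided(name: str) -> bool:
    ...  # placeholder: e.g. look up the laterality in the dataset's metadata

src, dst = Path("data/img_8bit"), Path("data/img_8bit_left")   # assumed paths
dst.mkdir(parents=True, exist_ok=True)
for path in src.glob("*.png"):
    img = Image.open(path)
    if is_right_sided(path.stem):
        img = ImageOps.mirror(img)                              # horizontal flip
    img.save(dst / path.name)
```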
Our radiologists' segmentation masks for 64 representative images were annotated in CVAT and are stored in the `annotations_*.xml` files. The decoding is done by our custom `utils.cvat_parser`, which is already used in our PyTorch dataset implementation `scripts.seg_gratpedwri_dataset.py`.
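For reference, decoding CVAT polygon annotations into binary masks generally looks like the following generic sketch; the repository's `utils.cvat_parser` may differ in details such as label handling.

```python
# Generic sketch of turning CVAT "images" XML polygon annotations into binary masks.
import xml.etree.ElementTree as ET
import numpy as np
from PIL import Image, ImageDraw

def masks_from_cvat(xml_path: str) -> dict[str, np.ndarray]:
    masks = {}
    for image in ET.parse(xml_path).getroot().iter("image"):
        w, h = int(image.get("width")), int(image.get("height"))
        canvas = Image.new("L", (w, h), 0)
        draw = ImageDraw.Draw(canvas)
        for poly in image.iter("polygon"):
            # CVAT stores polygon points as "x1,y1;x2,y2;..."
            pts = [tuple(map(float, p.split(","))) for p in poly.get("points").split(";")]
            draw.polygon(pts, outline=1, fill=1)
        masks[image.get("name")] = np.array(canvas, dtype=np.uint8)
    return masks
```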
To train our initial U-Net, run `python -m unet_training.training --gpu_id 0`, leaving all hyperparameters at their defaults.

Next, we predict the initial, unrefined segmentation masks for our unlabelled data with `python -m scripts.save_segmentations`, where you have to adjust the model id in line 19 to your ClearML model id (in the ClearML experiment: artefacts → output models → bone_segmentator).
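If you are unsure how such a model id resolves to a checkpoint, the following sketch shows one way to fetch it from ClearML by id; the id string is a placeholder, and we assume the checkpoint was stored with `torch.save`.

```python
# Sketch: resolve a ClearML model id to a local checkpoint file.
import torch
from clearml import InputModel

model = InputModel(model_id="YOUR_BONE_SEGMENTATOR_MODEL_ID")   # placeholder id
checkpoint = torch.load(model.get_local_copy(), map_location="cpu")
```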
To use SAM to refine the segmentation masks, we have to set up two things first:
- Download the model checkpoint for the ViT-H from the official repository and place it in the `data` folder.
- Precompute SAM's image embeddings to speed up the refinement process (believe me, you will rerun SAM on the same image at some point, and then you will be happy to have the embeddings precomputed); see the sketch after this list. To do so, run `python -m scripts.save_refined_segmentations`. Please use the same model id as before to load its predictions as the initial segmentations.
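The embedding caching mentioned above could look roughly like this; the checkpoint name, image folder, and output layout are assumptions, not necessarily what `scripts.save_refined_segmentations` does internally.

```python
# Sketch: precompute and cache SAM image embeddings (paths are assumptions).
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_h"](checkpoint="data/sam_vit_h_4b8939.pth").to("cuda")
predictor = SamPredictor(sam)

emb_dir = Path("data/sam_embeddings")
emb_dir.mkdir(parents=True, exist_ok=True)
for path in Path("data/img_8bit_left").glob("*.png"):
    image = np.array(Image.open(path).convert("RGB"))
    predictor.set_image(image)                  # runs the heavy ViT-H image encoder once
    torch.save(predictor.get_image_embedding().cpu(), emb_dir / f"{path.stem}.pt")
```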
As a last step, we can train the final U-Net with `python -m unet_training.training_on_pseudo_labels --gpu_id 0 --pseudo_label sam --prompt1st box --prompt2nd pos_points neg_points --num_train_samples 43`, leaving all other hyperparameters at their defaults.
To evaluate the final model, you can reuse the `python -m scripts.save_segmentations` script by adjusting the model id to the final one.
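If you want a quick sanity check of the saved predictions against the radiologists' masks, a plain Dice score is sufficient; the helper below is a generic sketch, not the repository's evaluation code.

```python
# Sketch: Dice score between a predicted and a reference binary mask.
import numpy as np

def dice(pred: np.ndarray, target: np.ndarray, eps: float = 1e-7) -> float:
    pred, target = pred.astype(bool), target.astype(bool)
    return (2.0 * np.logical_and(pred, target).sum() + eps) / (pred.sum() + target.sum() + eps)
```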