Code for CVPR 2024 paper, "VoCo: A Simple-yet-Effective Volume Contrastive Learning Framework for 3D Medical Image Analysis"
Authors: Linshan Wu, Jiaxin Zhuang, and Hao Chen
This work presents VoCo, a simple-yet-effective contrastive learning framework for pre-training on large-scale 3D medical images. Our model pre-trained on 10k CT images is available, and our models pre-trained on 160k CT images are now available as well!
Our extension is at Large-Scale-Medical, which provides stronger models, larger-scale datasets, various training recipes, and more downstream tasks!
Self-Supervised Learning (SSL) has demonstrated promising results in 3D medical image analysis. However, the lack of high-level semantics in pre-training still heavily hinders the performance of downstream tasks. We observe that 3D medical images contain relatively consistent contextual position information, i.e., consistent geometric relations between different organs, which leads to a potential way for us to learn consistent semantic representations in pre-training. In this paper, we propose a simple-yet-effective Volume Contrast (VoCo) framework to leverage the contextual position priors for pre-training. Specifically, we first generate a group of base crops from different regions while enforcing feature discrepancy among them, where we employ them as class assignments of different regions. Then, we randomly crop sub-volumes and predict them belonging to which class (located at which region) by contrasting their similarity to different base crops, which can be seen as predicting contextual positions of different sub-volumes. Through this pretext task, VoCo implicitly encodes the contextual position priors into model representations without the guidance of annotations, enabling us to effectively improve the performance of downstream tasks that require high-level semantics. Extensive experimental results on six downstream tasks demonstrate the superior effectiveness of VoCo.
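The position-prediction idea in the abstract above can be illustrated with a minimal sketch. This is not the authors' implementation: it uses plain Python lists as stand-ins for encoder features, it assumes cosine similarity as the similarity measure (the text does not specify one), and the names `cosine`, `predict_region`, and `base_discrepancy` are hypothetical.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def predict_region(crop_feat, base_feats):
    """Return (index of the most similar base crop, all similarities)."""
    sims = [cosine(crop_feat, base) for base in base_feats]
    best = max(range(len(sims)), key=sims.__getitem__)
    return best, sims


def base_discrepancy(base_feats):
    """Mean pairwise similarity among base crops; lower means more distinct."""
    n = len(base_feats)
    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    if not pairs:
        return 0.0
    total = sum(cosine(base_feats[i], base_feats[j]) for i, j in pairs)
    return total / len(pairs)


# Toy features (an assumption) standing in for encoder outputs.
base_feats = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
crop_feat = [0.1, 0.9, 0.2]
region, sims = predict_region(crop_feat, base_feats)
```

Here each base crop plays the role of a region class, and the random crop is assigned to the base crop it is most similar to. A training objective would push `base_discrepancy` down so that the base crops stay distinguishable; the exact losses are described in the paper.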
Our 10K pre-trained checkpoint is available at VoCo_10k.pt. More results are coming.
Method | Dataset | Pre-trained model | Training log | BTCV |
---|---|---|---|---|
VoCo | 10k CT | VoCo_10k.pt | Part of Pre-training log | 84.51 |
import argparse

import torch
from monai.networks.nets import SwinUNETR

# Edge length of the cubic input ROI. The original snippet left `roi` undefined;
# 96 is an assumed placeholder, so set it to match your own configuration.
roi = 96

parser = argparse.ArgumentParser(description="Swin UNETR")
parser.add_argument("--roi_x", default=roi, type=int, help="roi size in x direction")
parser.add_argument("--roi_y", default=roi, type=int, help="roi size in y direction")
parser.add_argument("--roi_z", default=roi, type=int, help="roi size in z direction")
parser.add_argument("--feature_size", default=48, type=int, help="feature size")
parser.add_argument("--in_channels", default=1, type=int, help="number of input channels")
parser.add_argument("--out_channels", default=14, type=int, help="number of output channels")
parser.add_argument("--use_checkpoint", default=True, help="use gradient checkpointing to save memory")
# Path to the downloaded weights; this argument was used below but never declared.
parser.add_argument("--pretrained_checkpoint", default="VoCo_10k.pt", type=str, help="path to the pre-trained checkpoint")
args = parser.parse_args()

model = SwinUNETR(
    img_size=(args.roi_x, args.roi_y, args.roi_z),
    in_channels=args.in_channels,
    out_channels=args.out_channels,
    feature_size=args.feature_size,
    use_checkpoint=args.use_checkpoint,
    use_v2=True,
)

# Load the checkpoint on CPU so that no GPU is needed at this step.
model_dict = torch.load(args.pretrained_checkpoint, map_location=torch.device("cpu"))
state_dict = model_dict

# Strip the "module." prefix left by DataParallel / DistributedDataParallel.
if "module." in list(state_dict.keys())[0]:
    print("Tag 'module.' found in state dict - fixing!")
    for key in list(state_dict.keys()):
        state_dict[key.replace("module.", "")] = state_dict.pop(key)

# Rename the backbone prefix to the attribute name used by MONAI's SwinUNETR.
if "swin_vit" in list(state_dict.keys())[0]:
    print("Tag 'swin_vit' found in state dict - fixing!")
    for key in list(state_dict.keys()):
        state_dict[key.replace("swin_vit", "swinViT")] = state_dict.pop(key)

# strict=False ignores keys that do not match between checkpoint and model.
model.load_state_dict(state_dict, strict=False)
print("Using pretrained voco ema self-supervised Swin UNETR backbone weights !")
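The key-renaming step in the loading snippet above can also be written as a small pure function, which makes it easy to check without a checkpoint on disk. The sketch below is an illustration rather than code from this repo: the name `normalize_keys` and the toy keys are made up, and unlike the snippet above it applies both renames to every key instead of testing only the first one.

```python
def normalize_keys(state_dict):
    """Return a new dict with 'module.' stripped and 'swin_vit' renamed to 'swinViT'."""
    fixed = {}
    for key, value in state_dict.items():
        new_key = key.replace("module.", "").replace("swin_vit", "swinViT")
        fixed[new_key] = value
    return fixed


# Toy example: placeholder integers stand in for tensors, and the key names
# are invented for illustration.
raw = {"module.swin_vit.layers1.weight": 1, "module.encoder1.bias": 2}
clean = normalize_keys(raw)
```

Because the weights are loaded with `strict=False`, mismatched keys are skipped silently. `load_state_dict` returns the missing and unexpected keys, so printing that return value is a quick way to confirm that the backbone weights were actually matched.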
First, download the pre-training dataset. The datasets that make up the 10k collection are all open-source, so you can download them yourself or get them from our Hugging Face repo. Note: the 10k dataset was collected by Dr. Jiaxin Zhuang.
├── data
    ├── BTCV
    ├── TCIAcovid19
    ├── Luna16-jx
    ├── stoic21
    ├── Totalsegmentator_dataset
    ├── Flare23
    ├── LiDC
    └── HNSCC_convert_v1
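Before launching pre-training, it can help to confirm that the layout above is in place. The following is a minimal sketch, not part of this repo: the helper name `missing_datasets` is hypothetical, and it only checks that the sub-directories listed above exist under the given root, not that their contents are complete.

```python
from pathlib import Path

# Folder names copied from the directory tree above.
EXPECTED = [
    "BTCV",
    "TCIAcovid19",
    "Luna16-jx",
    "stoic21",
    "Totalsegmentator_dataset",
    "Flare23",
    "LiDC",
    "HNSCC_convert_v1",
]


def missing_datasets(root="data"):
    """Return the expected dataset folders that are absent under `root`."""
    root = Path(root)
    return [name for name in EXPECTED if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_datasets("data")
    print("missing:", missing or "none")
```

An empty list means every expected folder was found.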
(1) Note that this repo presents the version used for our 10k pre-training, so some details may differ from our paper.
(2) To accelerate training, we use MONAI's "PersistentDataset" to pre-cache the dataset, which requires extra storage. This is important in our code. If you don't have enough storage, you can change it back in "utils/data_utils.py".
To pre-train:
sh train.sh
Our fine-tuning code will be available soon; in the meantime, you can directly use the code in MONAI.
More fine-tuning implementations are in preparation!
We thank MONAI for part of their code.
If you find this repo useful for your research, please consider citing the paper as follows:
@article{wu2024large,
title={Large-Scale 3D Medical Image Pre-training with Geometric Context Priors},
author={Wu, Linshan and Zhuang, Jiaxin and Chen, Hao},
journal={arXiv preprint arXiv:2410.09890},
year={2024}
}
@InProceedings{voco,
author = {Wu, Linshan and Zhuang, Jiaxin and Chen, Hao},
title = {VoCo: A Simple-yet-Effective Volume Contrastive Learning Framework for 3D Medical Image Analysis},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2024},
pages = {22873-22882}
}