- CLIP Models
- Installation
- Quick Start
- Directories
- Extract Features
- Generate Bottlenecks from Medical Documents
- Train Grounding Functions
- Baselines
We release on Hugging Face the two CLIP models we trained for X-ray and skin lesion images:
- WhyXrayCLIP 🩻 : https://huggingface.co/yyupenn/whyxrayclip
- WhyLesionCLIP 👍🏽 : https://huggingface.co/yyupenn/whylesionclip
Clone the repo, then install the required dependencies and download the data by running the following commands:
git clone https://github.com/YueYANG1996/KnoBo.git
cd KnoBo
sh setup.sh
To get the results of KnoBo on X-ray datasets, you can run the following command:
python modules/cbm.py \
--mode binary \
--bottleneck PubMed \
--number_of_features 150 \
--add_prior True \
--modality xray \
--model_name whyxrayclip \
The output will be saved to ./data/results/. To get the results on Skin Lesion datasets, change --modality to skin and --model_name to whylesionclip.
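Conceptually, a concept bottleneck model like the one run above predicts in two stages. First, each concept's grounding function maps an image feature to the probability that the concept is present. Then a linear layer maps those concept probabilities to class scores. The plain-Python sketch below illustrates only that two-stage computation. It rests on our own simplifying assumptions (logistic grounding functions, a single linear layer, made-up toy weights). It is not the code in modules/cbm.py, and it leaves out the knowledge prior enabled by --add_prior. All function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def concept_scores(feature, grounding_weights, grounding_biases) -> List[float]:
    """Stage 1 (hypothetical): one logistic grounding function per concept."""
    return [sigmoid(dot(w, feature) + b) for w, b in zip(grounding_weights, grounding_biases)]


def class_logits(scores, linear_weights, linear_biases) -> List[float]:
    """Stage 2: a linear layer over concept probabilities, one weight row per class."""
    return [dot(w, scores) + b for w, b in zip(linear_weights, linear_biases)]


def predict(feature, grounding_weights, grounding_biases,
            linear_weights, linear_biases) -> Tuple[int, List[float]]:
    """Return (predicted class index, concept probabilities)."""
    scores = concept_scores(feature, grounding_weights, grounding_biases)
    logits = class_logits(scores, linear_weights, linear_biases)
    best = max(range(len(logits)), key=lambda k: logits[k])
    return best, scores


if __name__ == "__main__":
    # Toy example: 2-D image feature, 2 concepts, 2 classes (all numbers made up).
    label, probs = predict(
        feature=[1.0, 0.0],
        grounding_weights=[[4.0, 0.0], [0.0, 4.0]],
        grounding_biases=[0.0, 0.0],
        linear_weights=[[-2.0, 0.0], [2.0, 0.0]],
        linear_biases=[0.0, 0.0],
    )
    print(label, [round(p, 3) for p in probs])  # prints 1 [0.982, 0.5]
```

In this setup the class decision depends on the image only through the concept probabilities. That is why each learned weight can be read as "how much this concept supports this class".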
- data/: Contains the data for all experiments.
  - data/bottlenecks/: Contains the concept bottlenecks created from medical documents.
  - data/datasets/: Contains the splits for all datasets. You may need to download the images of each dataset from its original source. Please refer to DATASETS.md for more details.
  - data/features/: Contains the features extracted by different models.
  - data/grounding_functions/: Contains the grounding functions for each concept in the bottleneck.
  - data/results/: Contains the results of all experiments.
- modules/: Contains the scripts for all experiments.
  - modules/cbm.py: Contains the script for running linear models, including KnoBo, linear probing, and PCBM.
  - modules/extract_features.py: Contains the script for extracting image features using different models.
  - modules/train_grounding.py: Contains the script for training the grounding functions for each concept in the bottleneck.
  - modules/end2end.py: Contains the script for training end-to-end models, including ViT and DenseNet.
  - modules/LSL.py: Contains the script for fine-tuning CLIP with knowledge (Language-shaped Learning).
  - modules/models.py: Contains the models used in the experiments.
  - modules/utils.py: Contains the utility functions.
After running setup.sh, the data/features/ directory should contain the features extracted by the two CLIP models we trained. To extract features using other models, run the following command:
python modules/extract_features.py \
--dataset_name <NAME OF THE DATASET> \
--model_name <NAME OF THE MODEL> \
--image_dir <PATH TO THE IMAGE DIRECTORY>
The supported models are listed here. We provide a bash script extract_features.sh
to extract features for all datasets using the two CLIP models we trained.
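extract_features.sh covers the two CLIP models we trained. With any other model, you run the command above once per dataset. The sketch below is a hypothetical dry-run driver that builds those commands. It uses only the three flags shown above. The dataset names, model name, and image directories in the example are placeholders, and build_extract_commands is our own helper, not part of the repository.

```python
import shlex
from pathlib import Path
from typing import Dict, List, Union


def build_extract_commands(model_name: str,
                           image_dirs: Dict[str, Union[str, Path]]) -> List[List[str]]:
    """Build one modules/extract_features.py command per dataset.

    image_dirs maps a dataset name to the directory holding its images.
    """
    return [
        [
            "python", "modules/extract_features.py",
            "--dataset_name", dataset_name,
            "--model_name", model_name,
            "--image_dir", str(image_dir),
        ]
        for dataset_name, image_dir in image_dirs.items()
    ]


if __name__ == "__main__":
    # Hypothetical names and paths: replace them with a supported model and real datasets.
    commands = build_extract_commands(
        "my_model",
        {"dataset_a": "/data/images/dataset_a", "dataset_b": "/data/images/dataset_b"},
    )
    for cmd in commands:
        print(shlex.join(cmd))  # dry run: only print each command line
        # To actually run it from the repo root: subprocess.run(cmd, check=True)
```

Printing first lets you check the arguments before starting a potentially long extraction run.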
Our retrieval-based concept bottleneck generation pipeline is built on MedRAG. First, clone our fork of MedRAG and set up the environment by running the following commands:
git clone https://github.com/YueYANG1996/MedRAG.git
cd MedRAG
sh setup.sh
This may take a while because it downloads 5M PubMed documents (29.5 GB). After setting up the environment, you can test the RAG system by running test.py.
To generate the concept bottleneck from medical documents, you can run the following command:
python concept_generation.py \
--modality <xray or skin> \
--corpus_name <NAME OF THE CORPUS> \
--number_of_concepts <NUMBER OF CONCEPTS> \
--openai_key <OPENAI API KEY>
For --corpus_name, you can choose from PubMed_all (our version of PubMed with all paragraphs), PubMed (MedRAG's original version of PubMed, which only has abstracts), Textbooks, StatPearls, and Wikipedia. The generated bottleneck will be saved to ./data/bottlenecks/<modality>_<corpus>_<number_of_concepts>.txt.
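To use a generated bottleneck in your own code, you need its path and its list of concepts. The sketch below builds the path from the documented naming pattern and reads the concepts. One assumption is ours and is not stated above: the file holds one concept per line. The helper names bottleneck_path and load_concepts are hypothetical.

```python
from pathlib import Path
from typing import List, Union


def bottleneck_path(modality: str, corpus: str, number_of_concepts: int,
                    root: Union[str, Path] = "./data/bottlenecks") -> Path:
    """Path following the pattern <modality>_<corpus>_<number_of_concepts>.txt."""
    if modality not in ("xray", "skin"):
        raise ValueError(f"modality must be 'xray' or 'skin', got {modality!r}")
    return Path(root) / f"{modality}_{corpus}_{number_of_concepts}.txt"


def load_concepts(path: Union[str, Path]) -> List[str]:
    """Read concepts from a bottleneck file.

    ASSUMPTION: one concept per line; blank lines are skipped.
    """
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


if __name__ == "__main__":
    path = bottleneck_path("xray", "PubMed_all", 150)
    print(path)  # data/bottlenecks/xray_PubMed_all_150.txt
    if path.exists():
        print(len(load_concepts(path)), "concepts")
```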
Annotate concepts: You can annotate clinical reports for each concept in the bottleneck by running the following command:
python annotate_question.py \
--annotator <t5 or gpt4> \
--modality <xray or skin> \
--bottleneck <NAME OF THE BOTTLENECK> \
--number_of_reports <NUMBER OF REPORTS TO ANNOTATE> \
--openai_key <OPENAI API KEY>
The default LLM for annotation is Flan-T5-XXL. You can switch to GPT-4 by setting --annotator gpt4 (warning: this can be expensive). By default, 1000 reports are annotated. The annotated reports will be saved to ./data/concept_annotation_<modality>/annotations_<annotator>/.
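The sketch below locates the annotation output for a given modality and annotator, following the directory pattern above. It only lists the files it finds. The file format is not described here, so the sketch deliberately does not parse them. annotation_dir and list_annotation_files are hypothetical helpers, and accepting only t5 and gpt4 reflects the two annotators mentioned above.

```python
from pathlib import Path
from typing import List, Union


def annotation_dir(modality: str, annotator: str = "t5",
                   root: Union[str, Path] = "./data") -> Path:
    """Directory following ./data/concept_annotation_<modality>/annotations_<annotator>/."""
    if annotator not in ("t5", "gpt4"):
        raise ValueError(f"annotator must be 't5' or 'gpt4', got {annotator!r}")
    return Path(root) / f"concept_annotation_{modality}" / f"annotations_{annotator}"


def list_annotation_files(modality: str, annotator: str = "t5",
                          root: Union[str, Path] = "./data") -> List[Path]:
    """Sorted list of files in the annotation directory (empty if it does not exist)."""
    directory = annotation_dir(modality, annotator, root)
    if not directory.is_dir():
        return []
    return sorted(p for p in directory.iterdir() if p.is_file())
```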
To train the grounding functions for each concept in the bottleneck, you can run the following command:
python modules/train_grounding.py \
--modality <xray or skin> \
--bottleneck <NAME OF THE BOTTLENECK>
Each grounding function is a binary classifier that predicts whether the concept is present in the image. The output will be saved to ./data/grounding_functions/<modality>/<concept>/.
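A binary classifier over image features can be as simple as logistic regression. The sketch below trains one with full-batch gradient descent on binary cross-entropy. It is written in plain Python for readability and assumes features are fixed-length float vectors with 0/1 concept labels. This is an illustration of the idea, not necessarily the model or training procedure used in modules/train_grounding.py. The function names and hyperparameters are our own.

```python
import math
from typing import List, Sequence, Tuple


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def train_grounding_function(features: List[List[float]], labels: List[int],
                             lr: float = 0.5, epochs: int = 200,
                             l2: float = 0.0) -> Tuple[List[float], float]:
    """Fit logistic regression (weights, bias) for one concept via gradient descent."""
    n, dim = len(features), len(features[0])
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(features, labels):
            err = sigmoid(dot(w, x) + b) - y  # gradient of BCE w.r.t. the logit
            for j in range(dim):
                grad_w[j] += err * x[j]
            grad_b += err
        for j in range(dim):
            w[j] -= lr * (grad_w[j] / n + l2 * w[j])  # optional L2 penalty on weights
        b -= lr * grad_b / n
    return w, b


def concept_probability(w: Sequence[float], b: float, x: Sequence[float]) -> float:
    """Probability that the concept is present in an image with feature x."""
    return sigmoid(dot(w, x) + b)
```

In practice you would more likely use a library classifier such as scikit-learn's LogisticRegression, training one per concept on that concept's labels.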
- Linear Probing: python modules/cbm.py --mode linear_probe --modality <xray or skin> --model_name <vision backbone>
- PCBM-h: python modules/cbm.py --mode pcbm --bottleneck PubMed --number_of_features 150 --modality <xray or skin> --model_name <vision backbone>
- End-to-End: python modules/end2end.py --modality <xray or skin> --model_name <vit or densenet>
- LSL: First, fine-tune the CLIP model with knowledge using the following command:
  python modules/LSL.py \
  --modality <xray or skin> \
  --clip_model_name <base model, e.g., whyxrayclip> \
  --bottleneck <NAME OF THE BOTTLENECK> \
  --image_dir <PATH TO THE IMAGE DIRECTORY>
  Then, extract features using the fine-tuned CLIP model and get the final results in the same way as linear probing:
  python modules/cbm.py --mode linear_probe --modality <xray or skin> --model_name <fine-tuned vision backbone>
  We provide the models we fine-tuned on PubMed in the data/model_weights/ directory.
Please cite our paper if you find our work useful!
@article{yang2024textbook,
title={A Textbook Remedy for Domain Shifts: Knowledge Priors for Medical Image Analysis},
author={Yue Yang and Mona Gandhi and Yufei Wang and Yifan Wu and Michael S. Yao and Chris Callison-Burch and James C. Gee and Mark Yatskar},
journal={arXiv preprint arXiv:2405.14839},
year={2024}
}