MMedPO: Aligning Medical Vision-Language Models with Clinical-Aware Multimodal Preference Optimization

💡 Overview

MMedPO is a clinical-aware multimodal preference optimization approach for aligning medical vision-language models. In this repository, visual preference data are curated with MedKLIP (attention-map scoring and locally noised images), and a LLaVA-Med-1.5 model is then trained with DPO on the resulting multimodal preference data.

📦 Requirements

  1. Clone this repository and navigate to the MMedPO folder:
git clone https://github.com/aiming-lab/MMedPO.git
cd MMedPO
  2. Install packages: create the conda environment and install the dependencies:
conda create -n MMedPO python=3.10 -y
conda activate MMedPO
pip install --upgrade pip  # enable PEP 660 support
pip install -r requirements.txt
pip install trl
  3. Download the required LLaVA-Med-1.5 model checkpoint from Hugging Face.

  4. For MMedPO model checkpoints, we have released four checkpoints on Hugging Face (a minimal download sketch is shown after this list).

  5. For the medical datasets, you must first apply for access and then download each dataset.
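
As a minimal sketch (not part of the official instructions), the checkpoints above can be fetched with the huggingface-cli tool; the LLaVA-Med-1.5 repo ID below is an assumption, and the MMedPO repo ID is a placeholder to replace with the released checkpoint you want to use.

pip install -U "huggingface_hub[cli]"
# LLaVA-Med-1.5 base model (assumed repo ID; verify against the link above)
huggingface-cli download microsoft/llava-med-v1.5-mistral-7b --local-dir ./checkpoints/llava-med-1.5
# One of the released MMedPO checkpoints (replace the placeholder with the actual repo ID)
huggingface-cli download <mmedpo-checkpoint-repo-id> --local-dir ./checkpoints/mmedpo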

🪧 Data Curation

We use MedKLIP to generate visual preference data. Use the following command, or run the script inference_attention-map_score.sh in ./scripts:

python ./inference_attention-map_score.py \
    --config ./MedKLIP_config.yaml \
    --model_path /path/to/MedKLIP_model.pth \
    --dataset_name /dataset/name \
    --dataset_type caption \
    --image_root /path/to/dataset/image_folder \
    --annotation_save_root /path/to/save/annotation \
    --noised_image_save_root /path/to/save/noised_image

🏋️ Train

Use the script train_dpo_visual-text.sh in ./scripts or the following command; make sure to specify the necessary data paths and the checkpoint saving location.

deepspeed --include localhost:0,1,2,3 ./train/dpo/train_dpo_visual-text.py \
    --model_name_or_path /path/to/llava-med_model_checkpoint \
    --deepspeed ./scripts/zero3.json \
    --version v1 \
    --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
    --data_path /path/to/data_json \
    --image_folder /path/to/img_folder \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length True \
    --bf16 True \
    --output_dir /path/to/output_checkpoint_saving_location \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 200 \
    --save_total_limit 1 \
    --learning_rate 1e-7 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --report_to wandb \
    --tf32 True \
    --model_max_length 1024 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True
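
Because training runs with LoRA enabled, the output directory holds a LoRA adapter rather than full model weights, so it typically needs to be merged with (or loaded on top of) the base LLaVA-Med checkpoint before evaluation. A minimal sketch, assuming the merge_lora_weights.py utility from upstream LLaVA is available in your environment; the paths below are placeholders:

# Assumption: LLaVA's merge utility is present; adjust the script path and model paths to your setup.
python scripts/merge_lora_weights.py \
    --model-path /path/to/output_checkpoint_saving_location \
    --model-base /path/to/llava-med_model_checkpoint \
    --save-model-path /path/to/merged_mmedpo_checkpoint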

📚 Citation

@article{zhu2024mmedpo,
  title={MMedPO: Aligning Medical Vision-Language Models with Clinical-Aware Multimodal Preference Optimization},
  author={Zhu, Kangyu and Xia, Peng and Li, Yun and Zhu, Hongtu and Wang, Sheng and Yao, Huaxiu},
  journal={arXiv preprint arXiv:2412.06141},
  year={2024}
}

🙏 Acknowledgement

We use code from LLaVA-Med, RULE, and MedKLIP. We thank the authors for releasing their code.
