Generalizing to unseen domains via distribution matching

belaalb/G2DM

Generalizing to unseen Domains via Distribution Matching (G2DM)

Code to reproduce the experiments in the preprint "Generalizing to unseen domains via distribution matching".

Simple use

Clone or download the repository and run all_in_one.ipynb, or open the notebook in Colab.

Binder

Open the repository on Binder and run all_in_one.ipynb. If you get a memory error on mybinder, try decreasing batch_size, or run the notebook on Colab or a deep learning rig. ;)

Requirements

Python >= 3.6
pytorch >= 1.2
torchvision >= 0.4.1
Scikit-learn >= 0.19
h5py
tqdm
pandas
seaborn

To install the requirements:

  pip install -r requirements.txt
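To confirm that an existing environment already meets the minimum versions listed above, a small check along these lines can help. It is a sketch, not part of the repository: it relies on importlib.metadata (Python 3.8+, newer than the stated 3.6 minimum), and the pip distribution names (torch for PyTorch, scikit-learn) are our assumption about how the packages were installed.

```python
from importlib import metadata  # standard library in Python 3.8+

# Minimum versions from the Requirements list; keys are pip distribution names.
MIN_VERSIONS = {"torch": "1.2", "torchvision": "0.4.1", "scikit-learn": "0.19"}
# Listed without a minimum version.
UNPINNED = ["h5py", "tqdm", "pandas", "seaborn"]


def version_tuple(version):
    """Leading numeric release components, e.g. '1.13.1+cu117' -> (1, 13, 1)."""
    parts = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) < len(piece):
            break  # stop at pre-release tags such as '0rc1'
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Compare release tuples after zero-padding them to the same length."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_requirements():
    """Return a list of human-readable problems (empty if everything is fine)."""
    problems = []
    for name, minimum in MIN_VERSIONS.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name} is not installed")
            continue
        if not meets_minimum(installed, minimum):
            problems.append(f"{name} {installed} is older than {minimum}")
    for name in UNPINNED:
        try:
            metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name} is not installed")
    return problems


if __name__ == "__main__":
    issues = check_requirements()
    print("\n".join(issues) if issues else "All requirements satisfied.")
```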

Download VLCS

Download it from http://www.mediafire.com/file/7yv132lgn1v267r/vlcs.tar.gz/file, extract it, and move the extracted data to ./data/vlcs/prepared_data/.
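The extraction step can also be scripted. A minimal sketch with Python's tarfile module, assuming the archive was saved as vlcs.tar.gz in the current directory; the file name and the helper extract_vlcs are hypothetical. We have not checked the archive's internal layout, so look at the returned top-level entries and move files if the training scripts expect a different structure under ./data/vlcs/prepared_data/.

```python
import os
import tarfile


def extract_vlcs(archive_path="vlcs.tar.gz", dest_dir="./data/vlcs/prepared_data"):
    """Extract the VLCS archive into dest_dir and return its top-level entries."""
    os.makedirs(dest_dir, exist_ok=True)
    dest_root = os.path.realpath(dest_dir)
    with tarfile.open(archive_path, "r:gz") as tar:
        members = tar.getmembers()
        # Refuse members that would land outside dest_dir (e.g. "../x" or "/etc/x").
        for member in members:
            target = os.path.realpath(os.path.join(dest_root, member.name))
            if os.path.commonpath([dest_root, target]) != dest_root:
                raise ValueError(f"unsafe path in archive: {member.name}")
        # Use the "data" extraction filter where this Python version provides it.
        extra = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        tar.extractall(dest_root, members=members, **extra)
    return sorted(os.listdir(dest_root))
```

Calling extract_vlcs() with no arguments uses the defaults above.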

Download pre-trained AlexNet

Download it from https://drive.google.com/file/d/1wUJTH1Joq2KAgrUDeKJghP1Wf7Q9w4z-/view?usp=sharing and move it to ./

Or just run:

python download_alexnet_vlcs.py

Table 1

Example using Caltech101 as the target domain.

Running ours

cd vlcs-ours
python train.py --lr-task 0.001 --lr-domain 0.005 --l2 0.005 --smoothing 0.2 --lr-threshold 0.0001 --factor 0.3 --alpha 0.8 --rp-size 3500 --patience 60 --warmup-its 300 --source1 PASCAL --source2 LABELME --source3 SUN --target CALTECH
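The command above holds out CALTECH only; covering every VLCS domain means rotating the target. A sketch that builds one command per held-out domain, assuming the same hyperparameters are reused for every target (the README does not say so). The helper leave_one_out_commands is hypothetical; with CALTECH as the target it reproduces the command above exactly.

```python
import shlex

VLCS_DOMAINS = ["PASCAL", "LABELME", "SUN", "CALTECH"]

# Hyperparameters copied from the "Running ours" command for Table 1.
OURS_FLAGS = ["--lr-task", "0.001", "--lr-domain", "0.005", "--l2", "0.005",
              "--smoothing", "0.2", "--lr-threshold", "0.0001", "--factor", "0.3",
              "--alpha", "0.8", "--rp-size", "3500", "--patience", "60",
              "--warmup-its", "300"]


def leave_one_out_commands(domains, base_flags, script="train.py"):
    """One argv list per held-out domain; the other domains become --source1..N."""
    commands = []
    for target in domains:
        cmd = ["python", script, *base_flags]
        sources = [d for d in domains if d != target]
        for i, source in enumerate(sources, start=1):
            cmd += [f"--source{i}", source]
        cmd += ["--target", target]
        commands.append(cmd)
    return commands


if __name__ == "__main__":
    # Only prints; run each command from inside vlcs-ours/,
    # e.g. with subprocess.run(cmd, check=True).
    for cmd in leave_one_out_commands(VLCS_DOMAINS, OURS_FLAGS):
        print(shlex.join(cmd))
```

The ERM and IRM commands take the same --source1/--source2/--source3/--target flags, so passing script="baseline_train.py" with their own flags works the same way, as does PACS with the domains photo, cartoon, sketch and art_painting.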

Running ERM

cd vlcs-ours
python baseline_train.py --lr 0.001 --l2 0.00001 --patience 120 --source1 PASCAL --source2 LABELME --source3 SUN --target CALTECH

Running IRM

cd IRM-vlcs
python train.py --lr 0.0004898536566546834 --l2 0.00221589136 --penalty_weight 91257.18613115903 --penalty_anneal_epochs 78 --source1 PASCAL --source2 LABELME --source3 SUN --target CALTECH

Table 2

Example using SUN09 and Caltech-101 as the target domains.

Running ours

cd vlcs-2sources
python train.py --lr-task 0.001 --lr-domain 0.005 --l2 0.005 --smoothing 0.2 --lr-threshold 0.0001 --factor 0.3 --alpha 0.8 --rp-size 3500 --patience 60 --warmup-its 300 --source1 PASCAL --source2 LABELME --target1 SUN --target2 CALTECH

Running ERM

cd vlcs-2sources
python baseline_train.py --lr 0.001 --l2 0.00001 --patience 120 --source1 PASCAL --source2 LABELME --target1 SUN --target2 CALTECH
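With two sources and two targets, there are six ways to split the four VLCS domains. The README does not say which splits Table 2 reports, so this sketch simply enumerates all of them and formats the --source1/--source2/--target1/--target2 flags that the vlcs-2sources commands take; two_source_splits and split_flags are hypothetical names.

```python
from itertools import combinations

VLCS_DOMAINS = ["PASCAL", "LABELME", "SUN", "CALTECH"]


def two_source_splits(domains):
    """All (sources, targets) pairs that hold out two domains as targets."""
    splits = []
    for targets in combinations(domains, 2):
        sources = tuple(d for d in domains if d not in targets)
        splits.append((sources, targets))
    return splits


def split_flags(sources, targets):
    """Command-line flags for one split, in the order used by the README."""
    return ["--source1", sources[0], "--source2", sources[1],
            "--target1", targets[0], "--target2", targets[1]]


if __name__ == "__main__":
    for sources, targets in two_source_splits(VLCS_DOMAINS):
        print(" ".join(split_flags(sources, targets)))
```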

Table 3

Example using art painting as the target domain.

Download PACS

Prepare the HDF5 files for PACS

cd data/pacs
python prep_hdf.py --train-val-test train
python prep_hdf.py --train-val-test val
python prep_hdf.py --train-val-test test
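The three preparation calls can be chained so that a failure in one split stops the rest. A sketch, assuming it is launched from the repository root so that data/pacs contains prep_hdf.py; prep_commands and prepare_pacs are hypothetical names.

```python
import subprocess
import sys

SPLITS = ("train", "val", "test")


def prep_commands(script="prep_hdf.py"):
    """One argv list per split, using the same interpreter as the caller."""
    return [[sys.executable, script, "--train-val-test", split] for split in SPLITS]


def prepare_pacs(cwd="data/pacs"):
    """Run prep_hdf.py for every split, raising on the first non-zero exit."""
    for cmd in prep_commands():
        subprocess.run(cmd, cwd=cwd, check=True)
```

Call prepare_pacs() once the PACS data is in place.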

Running ours

cd pacs-ours
python train.py --lr-task 0.01 --lr-domain 0.0005 --l2 0.0005 --smoothing 0.2 --lr-threshold 0.00001 --factor 0.5 --alpha 0.8 --rp-size 1000 --patience 80 --warmup-its 300 --source1 photo --source2 cartoon --source3 sketch --target art_painting

Running ERM

cd pacs-ours
python baseline_train.py --lr 0.001 --l2 0.0001 --momentum 0.9 --patience 120 --source1 photo --source2 cartoon --source3 sketch --target art_painting

Running IRM

cd IRM-pacs
python train.py --lr 0.0004898536566546834 --l2 0.00221589136 --penalty_weight 91257.18613115903 --penalty_anneal_epochs 78 --source1 photo --source2 cartoon --source3 sketch --target art_painting

Figure 3

cd pacs-ours
python h_divergence.py --batch-size 500 --encoder-path path-to-trained-model --architecture 'adversarial'
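The README does not describe how h_divergence.py estimates the divergence. A common estimator is the proxy A-distance of Ben-David et al.: train a classifier to tell source features from target features and turn its held-out error err into d_A = 2(1 - 2 err), which is near 0 when the domains are indistinguishable and near 2 when they are perfectly separable. The sketch below uses a nearest-centroid classifier as a simple stand-in for the linear SVM that is usually used; it is not necessarily what the script does, and proxy_a_distance is a hypothetical name.

```python
import numpy as np


def proxy_a_distance(source_feats, target_feats, seed=0):
    """Proxy A-distance d_A = 2 * (1 - 2 * err) between two feature sets.

    err is the held-out error of a domain classifier fit on half of each set.
    A nearest-centroid classifier stands in for the usual linear SVM.
    """
    rng = np.random.default_rng(seed)
    src = np.asarray(source_feats, dtype=float)
    tgt = np.asarray(target_feats, dtype=float)
    if len(src) < 2 or len(tgt) < 2:
        raise ValueError("need at least two samples per domain")

    def halves(x):
        # Random disjoint fit/eval halves of one domain.
        idx = rng.permutation(len(x))
        mid = len(x) // 2
        return x[idx[:mid]], x[idx[mid:]]

    src_fit, src_eval = halves(src)
    tgt_fit, tgt_eval = halves(tgt)
    c_src, c_tgt = src_fit.mean(axis=0), tgt_fit.mean(axis=0)

    x_eval = np.concatenate([src_eval, tgt_eval])
    y_eval = np.concatenate([np.zeros(len(src_eval)), np.ones(len(tgt_eval))])
    # Predict "target" (1) when a sample is closer to the target centroid.
    closer_to_tgt = (np.linalg.norm(x_eval - c_tgt, axis=1)
                     < np.linalg.norm(x_eval - c_src, axis=1))
    err = float(np.mean(closer_to_tgt.astype(float) != y_eval))
    return 2.0 * (1.0 - 2.0 * err)
```

The features would come from the trained encoder applied to source and target images; lower values indicate that the encoder maps the domains closer together.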

Table 4

cd pacs-resnet
python train.py --train-model resnet18 --smoothing 0 --train-mode hv --nadir-slack 2.5 --alpha 0.8 --lr-task 0.01 --lr-domain 0.005 --patience 20 --l2 0.0005 --rp-size 512 --source1 photo --source2 cartoon --source3 sketch --target art_painting
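Table 4 trains with --train-mode hv and --nadir-slack 2.5. Our reading, not confirmed by this README, is that hv refers to hypervolume maximization over several losses, with the nadir point eta set to a slack factor times the largest loss, as in the multi-objective formulation used for GANs with multiple discriminators (Albuquerque et al., 2019). Under that assumption, the sketch below computes the loss and the per-objective gradient weights it implies; hypervolume_loss is a hypothetical helper, not code from pacs-resnet, and which losses the script combines is not stated here.

```python
import math


def hypervolume_loss(losses, nadir_slack):
    """Hypervolume loss -sum_k log(eta - l_k) with eta = nadir_slack * max_k l_k.

    Returns the loss value and d(loss)/d(l_k) for each objective, treating eta as
    a constant (an assumption; whether the script detaches eta is not stated here).
    """
    if nadir_slack <= 1.0:
        raise ValueError("nadir_slack must be > 1 so that eta exceeds every loss")
    if max(losses) <= 0.0:
        raise ValueError("expects positive losses (e.g. cross-entropies)")
    eta = nadir_slack * max(losses)
    value = -sum(math.log(eta - l) for l in losses)
    # Each objective's weight is 1 / (eta - l_k): the worst loss gets the largest weight.
    weights = [1.0 / (eta - l) for l in losses]
    return value, weights
```

With nadir_slack close to 1, the largest loss dominates the gradient; as it grows, the weights become nearly equal, approaching an evenly weighted (scaled) sum of the losses.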
