Official implementation of Weight Averaging Improves Knowledge Distillation under Domain Shift
Valeriy Berezovskiy, Nikita Morozov
ICCV 2023 Out Of Distribution Generalization in Computer Vision Workshop
We highly recommend using conda for experiments.
After installing conda, create a new environment:
conda create python=3.10 --name dist --yes
conda activate dist
Install the dependencies from requirements.txt:
conda install --file requirements.txt --yes
The appropriate PyTorch version may differ depending on your GPU.
Download the PACS and Office-Home datasets:
chmod +x ./src/scripts/load_pacs.sh ./src/scripts/load_officehome.sh
./src/scripts/load_pacs.sh && ./src/scripts/load_officehome.sh
Hard label training:
python src/scripts/train.py --device [ID OF CUDA DEVICE] --config src/configs/[SELECT CONFIG TO RUN] --test [TEST_DOMAINS SETS]
Distillation:
python src/scripts/train.py --device [ID OF CUDA DEVICE] --config src/configs/[SELECT CONFIG TO RUN] --test [TEST_DOMAINS SETS] --dist
For PACS, the available domains are art_painting, photo, sketch, and cartoon. For Office-Home, the available domains are art, clipart, product, and real_world. You can select several test domains at once, for example: --test photo cartoon.
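As a rough illustration of how the flags and domain names fit together, the sketch below assembles a train.py command and rejects domain names that do not belong to the chosen dataset. The helper `build_train_command` and the config path `src/configs/CONFIG_NAME` are hypothetical placeholders, not part of the repository; only the script path, the flags, and the domain names come from this README.

```python
import shlex

# Domain names as listed in this README for each dataset.
DOMAINS = {
    "PACS": ["art_painting", "photo", "sketch", "cartoon"],
    "OfficeHome": ["art", "clipart", "product", "real_world"],
}


def build_train_command(dataset, config_path, test_domains, device=0, distill=False):
    """Return the train.py invocation as a list of arguments (hypothetical helper)."""
    unknown = [d for d in test_domains if d not in DOMAINS[dataset]]
    if unknown:
        raise ValueError(f"unknown {dataset} domains: {unknown}")
    cmd = [
        "python", "src/scripts/train.py",
        "--device", str(device),
        "--config", config_path,
        "--test", *test_domains,
    ]
    if distill:
        # --dist switches from hard-label training to distillation.
        cmd.append("--dist")
    return cmd


# "src/configs/CONFIG_NAME" is a placeholder, not a real file from the repository.
cmd = build_train_command("PACS", "src/configs/CONFIG_NAME", ["photo", "cartoon"], distill=True)
print(shlex.join(cmd))
```

Checking the domain names up front catches mistakes such as passing an Office-Home domain to a PACS run before any training starts.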
Before starting distillation, you need to train the teacher model.
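To see why the teacher has to exist first, it may help to look at a generic distillation loss, which compares the student's predictions with the teacher's on the same input. The sketch below is the standard temperature-softened KL formulation in plain Python. It is an assumption for illustration only: the exact loss, temperature, and weighting used by this repository are not stated here, and the logits are made up.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions, scaled by T^2."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(p * math.log(p / q) for p, q in zip(p_teacher, p_student) if p > 0)
    return temperature ** 2 * kl


teacher = [4.0, 1.0, 0.5]  # made-up teacher logits for one image
student = [2.0, 1.5, 0.2]  # made-up student logits for the same image
print(distillation_loss(student, teacher))
```

The loss is zero when the student reproduces the teacher's distribution and grows as the two diverge, so the teacher's outputs serve as the training signal for the student.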
Let's look at the config structure.
The random seed is given at the end of each config name. You can use any model or augmentation from torchvision. For the model, you must also specify the parameters of the last linear layer. A DeiT model is also available.
For the dataset, you must specify its name (PACS or OfficeHome) and the list of domains.
You can also change the training parameters: any optimizer from torch.optim, the batch size, the SWAD parameters, and so on.
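For readers unfamiliar with SWAD, its core operation is averaging model weights collected over many training iterations. The sketch below shows only that averaging step, using plain Python dictionaries of scalars in place of tensors. It is a simplified illustration, not this repository's implementation: the class `RunningWeightAverage` is hypothetical, and SWAD's validation-loss-based choice of where averaging starts and ends is omitted.

```python
class RunningWeightAverage:
    """Uniform running average of model weights, kept as a name -> value dict."""

    def __init__(self):
        self.count = 0
        self.average = {}

    def update(self, weights):
        """Fold one snapshot of weights into the average."""
        self.count += 1
        for name, value in weights.items():
            previous = self.average.get(name, 0.0)
            # Incremental mean: avg_n = avg_{n-1} + (w_n - avg_{n-1}) / n
            self.average[name] = previous + (value - previous) / self.count


# Made-up scalar "weights" from three consecutive training iterations.
snapshots = [{"w": 1.0, "b": 0.0}, {"w": 2.0, "b": 0.5}, {"w": 3.0, "b": 1.0}]
averager = RunningWeightAverage()
for snapshot in snapshots:
    averager.update(snapshot)
print(averager.average)
```

The incremental form avoids storing every snapshot, which matters when weights are averaged at every iteration.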
All experiments were logged using wandb.
@article{berezovskiy2023weight,
title={Weight Averaging Improves Knowledge Distillation under Domain Shift},
author={Berezovskiy, Valeriy and Morozov, Nikita},
journal={arXiv preprint arXiv:2309.11446},
year={2023}
}
If you have any questions, feel free to contact us via email ([email protected] or [email protected]).