Implementation of "Weight Averaging Improves Knowledge Distillation under Domain Shift" (ICCV 2023 OOD-CV Workshop)

vorobeevich/distillation-in-dg

Introduction

Official implementation of "Weight Averaging Improves Knowledge Distillation under Domain Shift".

Valeriy Berezovskiy, Nikita Morozov

ICCV 2023 Out Of Distribution Generalization in Computer Vision Workshop

Preparation

We highly recommend using conda for experiments.

After installing conda, create a new environment and activate it:

conda create python=3.10 --name dist --yes

conda activate dist

Install the dependencies from requirements.txt:

conda install --file requirements.txt --yes

The torch version you need may differ depending on your GPU.

Data

Load the PACS and Office-Home datasets with the provided scripts:

chmod +x ./src/scripts/load_pacs.sh ./src/scripts/load_officehome.sh

./src/scripts/load_pacs.sh && ./src/scripts/load_officehome.sh

Usage

Hard label training:

python src/scripts/train.py --device [ID OF CUDA DEVICE] --config src/configs/[SELECT CONFIG TO RUN] --test [TEST_DOMAINS SETS]

Distillation:

python src/scripts/train.py --device [ID OF CUDA DEVICE] --config src/configs/[SELECT CONFIG TO RUN] --test [TEST_DOMAINS SETS] --dist

For PACS, the available domains are art_painting, photo, sketch, and cartoon. For Office-Home, the available domains are art, clipart, product, and real_world. You can select several test domains at once, for example: --test photo cartoon.
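To make the command-line options above concrete, here is a minimal sketch of how such flags could be parsed with argparse. It is not the repository's actual train.py: the function names build_parser and check_domains, the integer type of --device, the placeholder config path, and the idea that all non-test domains are used for training are assumptions made for illustration. Only the flag names and the domain lists come from this README.

```python
import argparse

# Domain names as listed in this README; the dictionary itself is illustrative.
DOMAINS = {
    "PACS": ["art_painting", "photo", "sketch", "cartoon"],
    "OfficeHome": ["art", "clipart", "product", "real_world"],
}


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags used in the commands above."""
    parser = argparse.ArgumentParser(description="Training entry point (sketch).")
    parser.add_argument("--device", type=int, required=True, help="ID of the CUDA device")
    parser.add_argument("--config", type=str, required=True, help="path to a config file")
    # nargs="+" lets several test domains follow a single --test flag.
    parser.add_argument("--test", nargs="+", required=True, help="test domains")
    # store_true: distillation stays off unless --dist is passed.
    parser.add_argument("--dist", action="store_true", help="run distillation")
    return parser


def check_domains(dataset: str, test_domains: list[str]) -> list[str]:
    """Validate the test domains and return the remaining domains.

    Assumption: the domains that are not held out for testing are used for training.
    """
    available = DOMAINS[dataset]
    unknown = [d for d in test_domains if d not in available]
    if unknown:
        raise ValueError(f"Unknown {dataset} domains: {unknown}")
    return [d for d in available if d not in test_domains]


if __name__ == "__main__":
    # "src/configs/example.yaml" is a placeholder, not a real file in the repository.
    args = build_parser().parse_args(
        ["--device", "0", "--config", "src/configs/example.yaml",
         "--test", "photo", "cartoon", "--dist"]
    )
    print(args.test, args.dist)
    print(check_domains("PACS", args.test))
```

With nargs="+", a misspelled domain can be rejected before training starts instead of failing later during data loading.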

Before starting distillation, you need to train the teacher model.

Let's look at the config structure.

The config name ends with the random seed. You can use any model or augmentation from torchvision. For the model, you must include the parameters of the last linear layer. A DeiT model is also available.

For the dataset, you must specify the name (PACS or OfficeHome) and the list of domains.

You can also change the training parameters: any optimizer from torch.optim, the batch size, the SWAD parameters, and so on.
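The config description above can be illustrated with a small sketch. The repository's real config format and key names are not shown in this README, so everything below is hypothetical: the file name example_42.yaml, the dictionary keys, and the helpers seed_from_config_name and missing_fields. The sketch only mirrors the stated rules: the seed sits at the end of the config name, the model needs the parameters of its last linear layer, and the dataset needs a name and a list of domains.

```python
import re
from pathlib import Path

# Hypothetical required fields, derived from the rules described above.
REQUIRED = {
    "model": ("name", "last_layer"),
    "dataset": ("name", "domains"),
}


def seed_from_config_name(config_path: str) -> int:
    """Read the trailing integer of the file stem (assumed naming scheme)."""
    stem = Path(config_path).stem
    match = re.search(r"(\d+)$", stem)
    if match is None:
        raise ValueError(f"No trailing seed in config name: {stem!r}")
    return int(match.group(1))


def missing_fields(config: dict) -> list[str]:
    """List required 'section.key' entries that are absent from the config."""
    missing = []
    for section, keys in REQUIRED.items():
        body = config.get(section, {})
        for key in keys:
            if key not in body:
                missing.append(f"{section}.{key}")
    return missing


# Illustrative config; the key names and values are assumptions, not the repo's schema.
example_config = {
    "model": {
        "name": "resnet50",
        # Parameters of the last linear layer (PACS has 7 classes).
        "last_layer": {"in_features": 2048, "out_features": 7},
    },
    "dataset": {
        "name": "PACS",
        "domains": ["art_painting", "photo", "sketch", "cartoon"],
    },
    "optimizer": {"name": "SGD", "kwargs": {"lr": 0.01}},
    "batch_size": 64,
}

if __name__ == "__main__":
    print(seed_from_config_name("src/configs/example_42.yaml"))
    print(missing_fields(example_config))
```

Checking the required fields up front is a design choice in this sketch, so that an incomplete config fails fast rather than partway through a run.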

Visualization

All experiments were logged using wandb.

Citation

@article{berezovskiy2023weight,
  title={Weight Averaging Improves Knowledge Distillation under Domain Shift},
  author={Berezovskiy, Valeriy and Morozov, Nikita},
  journal={arXiv preprint arXiv:2309.11446},
  year={2023}
}

Contact

If you have any questions, feel free to contact us via email ([email protected] or [email protected]).
