Official PyTorch implementation of "OAMixer: Object-aware Mixing Layer for Vision Transformers" (CVPRW 2022) by Hyunwoo Kang*, Sangwoo Mo*, and Jinwoo Shin.
Our code is heavily built upon the DeiT and timm repositories. We use a newer version of timm than DeiT does, in order to borrow the updated mixer implementations.
Our main contributions are in (a) the `models` directory, which defines the base masked model class and its specific instantiations for ViT, MLP-Mixer, and ConvMixer, and (b) the `transforms` directory, which defines the paired transformations of an image and its corresponding patch labels (e.g., BigBiGAN, ReLabel).
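For intuition, a paired transform must apply the same random spatial operation to both the image and its patch-label map so the two stay aligned. Below is a minimal sketch of this idea; the class name and the label-map format are assumptions for illustration, not the repository's exact API.

```python
import random
import torchvision.transforms.functional as TF

class PairedRandomHorizontalFlip:
    """Minimal sketch of a paired transform: the image and its patch-label
    map receive the same spatial operation so they stay aligned.
    (Hypothetical class, not the repository's exact API.)"""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image, patch_labels):
        # image: PIL image or CHW tensor; patch_labels: (H_p, W_p) label map
        if random.random() < self.p:
            image = TF.hflip(image)
            patch_labels = TF.hflip(patch_labels)
        return image, patch_labels
```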
Install required libraries.
pip install -r requirements.txt
Create BigBiGAN patch labels.
You can download the pretrained U-Net weights (e.g., trained on ImageNet) from the original repository.
Then, place the pretrained weights in `patch_models/pretrained`.
python generate_mask.py --data-set [DATASET] --output_dir [OUTPUT_PATH]
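Conceptually, the pixel-level foreground mask predicted by the U-Net is reduced to patch-level labels by pooling over the patch grid. A minimal sketch of this reduction, assuming a 16-pixel patch size and an (H, W) score map; the function name is hypothetical and `generate_mask.py` may differ in details:

```python
import torch
import torch.nn.functional as F

def mask_to_patch_labels(mask, patch_size=16):
    """Reduce a pixel-level foreground mask to patch-level labels by
    average-pooling over the patch grid. (Illustrative sketch only.)
    mask: (H, W) float tensor in [0, 1], e.g., a U-Net foreground score map.
    Returns: (H // patch_size, W // patch_size) foreground ratio per patch.
    """
    return F.avg_pool2d(mask[None, None], kernel_size=patch_size)[0, 0]

# Example: a 224x224 mask yields a 14x14 grid of patch labels.
patch_labels = mask_to_patch_labels(torch.rand(224, 224))
print(patch_labels.shape)  # torch.Size([14, 14])
```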
Create ReLabel patch labels.
python3 generate_label.py [DATASET_PATH] [OUTPUT_PATH] --model dm_nfnet_f6 --pretrained --img-size 576 -b 32 --crop-pct 1.0
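Unlike the binary foreground masks above, ReLabel provides dense class-score maps from a strong teacher (here `dm_nfnet_f6`). These can likewise be pooled to the patch grid to obtain soft per-patch labels. A rough sketch, assuming a full (C, H, W) score map; the stored ReLabel format keeps only the top-k classes per location, which this sketch ignores for simplicity:

```python
import torch.nn.functional as F

def dense_to_patch_labels(score_map, num_patches=14):
    """Pool a dense teacher score map to per-patch soft labels.
    (Illustrative sketch; not the repository's exact pipeline.)
    score_map: (C, H, W) class-score map from the teacher.
    Returns: (num_patches, num_patches, C) soft label per patch.
    """
    pooled = F.adaptive_avg_pool2d(score_map, num_patches)  # (C, P, P)
    return pooled.permute(1, 2, 0).softmax(dim=-1)          # (P, P, C)
```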
Train a baseline model.
python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py \
--model deit_t --batch-size 64 --data-set imagenet --output_dir [OUTPUT_PATH]
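With `--nproc_per_node=4` processes and `--batch-size 64` per process, the effective batch size is 4 × 64 = 256.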
Apply OAMixer to the baseline model.
[BASE_CODE_ABOVE] --mask-attention --patch-label relabel
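Conceptually, the `--mask-attention` flag turns patch-label similarity into a mask that reweights the token-mixing matrix, so that patches belonging to the same object mix more strongly. The sketch below conveys the idea only; the temperature and renormalization details are assumptions, not the repository's exact implementation:

```python
import torch

def object_aware_reweight(attn, patch_labels, tau=1.0):
    """Reweight an attention matrix so that patches with similar labels
    mix more strongly. (Conceptual sketch of object-aware mixing.)
    attn: (B, heads, N, N) attention weights over N patches.
    patch_labels: (B, N, C) soft label distribution per patch.
    """
    # Pairwise label similarity: high when two patches share an object.
    sim = torch.einsum('bnc,bmc->bnm', patch_labels, patch_labels)
    mask = torch.exp(sim / tau).unsqueeze(1)          # (B, 1, N, N)
    attn = attn * mask
    return attn / attn.sum(dim=-1, keepdim=True)      # renormalize rows
```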
Apply TokenLabeling (for both the baseline model and OAMixer).
[BASE_CODE_ABOVE] --token-label
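Token labeling adds an auxiliary per-patch classification loss alongside the usual image-level loss. A minimal sketch of such a loss, assuming (B, N, C) per-patch logits and soft patch labels; the exact weighting and aggregation in this repository may differ:

```python
import torch.nn.functional as F

def token_labeling_loss(patch_logits, patch_labels):
    """Auxiliary per-patch classification loss for --token-label.
    (Illustrative sketch only.)
    patch_logits: (B, N, C) per-patch predictions from the model.
    patch_labels: (B, N, C) soft patch labels (e.g., from ReLabel).
    """
    log_probs = F.log_softmax(patch_logits, dim=-1)
    # Soft cross-entropy averaged over patches and the batch.
    return -(patch_labels * log_probs).sum(dim=-1).mean()
```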
Evaluate a trained model.
python main.py --eval --model deit_t --data-set imagenet --resume [OUTPUT_PATH]/checkpoint.pth