News: We have added adversarial training results for RVT here!
This repository contains PyTorch code for Robust Vision Transformers.
Note: Since the models were trained on our private platform, this ported code has not been tested and may contain bugs. If you run into any problems, feel free to open an issue!
For details, see our paper "Towards Robust Vision Transformer".
First, clone the repository locally:
git clone https://github.com/vtddggg/Robust-Vision-Transformer.git
Install PyTorch 1.7.0+, torchvision 0.8.1+, and pytorch-image-models (timm) 0.3.2:
conda install -c pytorch pytorch torchvision
pip install timm==0.3.2
In addition, `einops` and `kornia` are required to use this implementation:
pip install einops
pip install kornia
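The version requirements above can be checked before starting a long training run. The following is a minimal sketch that uses only the standard library; `parse_version`, `meets_minimum` and `check_environment` are illustrative helpers, not part of this repository, and the parsing only handles simple numeric versions such as `1.7.0` or `1.10.0+cu113`.

```python
import re
from importlib import metadata

# Requirements taken from the installation notes above.
MINIMUMS = {"torch": "1.7.0", "torchvision": "0.8.1"}
PINNED = {"timm": "0.3.2"}
REQUIRED = ["einops", "kornia"]


def parse_version(text):
    """Turn '1.10.0+cu113' into (1, 10, 0); local build tags are ignored."""
    core = text.split("+")[0]
    parts = []
    for piece in core.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    # Pad so that '1.7' compares equal to '1.7.0'.
    parts += [0] * (3 - len(parts))
    return tuple(parts)


def meets_minimum(installed, minimum):
    return parse_version(installed) >= parse_version(minimum)


def installed_version(name):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def check_environment():
    """Return a list of problems; an empty list means every check passed."""
    problems = []
    for name in list(MINIMUMS) + list(PINNED) + REQUIRED:
        found = installed_version(name)
        if found is None:
            problems.append(f"{name} is not installed")
        elif name in MINIMUMS and not meets_minimum(found, MINIMUMS[name]):
            problems.append(f"{name} {found} is older than {MINIMUMS[name]}")
        elif name in PINNED and parse_version(found) != parse_version(PINNED[name]):
            problems.append(f"{name} {found} differs from pinned {PINNED[name]}")
    return problems


if __name__ == "__main__":
    for line in check_environment() or ["environment looks fine"]:
        print(line)
```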
We use 4 nodes with 8 GPUs each to train `RVT-Ti`, `RVT-S` and `RVT-B`:
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=4 main.py --model rvt_tiny --data-path /path/to/imagenet --output_dir output --dist-eval
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=4 main.py --model rvt_small --data-path /path/to/imagenet --output_dir output --dist-eval
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=4 main.py --model rvt_base --data-path /path/to/imagenet --output_dir output --batch-size 32 --dist-eval
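With `--nnodes=4`, the launcher has to be started once on every node. Below is a minimal sketch of how the per-node commands could be generated. `--node_rank`, `--master_addr` and `--master_port` are standard `torch.distributed.launch` options that the commands above do not show; the address `10.0.0.1` and port `29500` are placeholders. Treating `--batch-size` as a per-GPU value is an assumption (it is common in DeiT-style training scripts, but it is not stated here).

```python
NODES = 4          # --nnodes in the commands above
GPUS_PER_NODE = 8  # --nproc_per_node in the commands above


def launch_command(model, node_rank, master_addr, master_port=29500, extra=()):
    """Build the launch command for one node of a multi-node run."""
    parts = [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={GPUS_PER_NODE}",
        f"--nnodes={NODES}",
        f"--node_rank={node_rank}",          # differs on every node
        f"--master_addr={master_addr}",      # placeholder address of node 0
        f"--master_port={master_port}",      # placeholder port
        "main.py",
        "--model", model,
        "--data-path", "/path/to/imagenet",
        "--output_dir", "output",
        *extra,
        "--dist-eval",
    ]
    return " ".join(parts)


def global_batch_size(per_gpu_batch):
    """Total samples per step, assuming --batch-size is counted per GPU."""
    return NODES * GPUS_PER_NODE * per_gpu_batch


if __name__ == "__main__":
    for rank in range(NODES):
        print(launch_command("rvt_base", rank, "10.0.0.1",
                             extra=("--batch-size", "32")))
    print(global_batch_size(32))  # 4 * 8 * 32 = 1024
```

Under that assumption, the `rvt_base` command with `--batch-size 32` corresponds to 1024 images per optimization step across the 32 GPUs.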
You can also fine-tune the pretrained model by adding `--pretrained`.
If you want to train `RVT-Ti*`, `RVT-S*` or `RVT-B*`, simply set `--model` to `rvt_tiny_plus`, `rvt_small_plus` or `rvt_base_plus`, then add `--use_patch_aug` to enable patch-wise augmentation.
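The naming rule above can be summarized in a small lookup. This sketch is illustrative only: `MODEL_NAMES` and `training_args` are hypothetical helpers, while the `--model` values and the `--use_patch_aug` flag come from the text above.

```python
# Paper name -> value passed to --model, as listed in this README.
MODEL_NAMES = {
    "RVT-Ti": "rvt_tiny",
    "RVT-S": "rvt_small",
    "RVT-B": "rvt_base",
    "RVT-Ti*": "rvt_tiny_plus",
    "RVT-S*": "rvt_small_plus",
    "RVT-B*": "rvt_base_plus",
}


def training_args(paper_name):
    """Return the model-related arguments for main.py."""
    args = ["--model", MODEL_NAMES[paper_name]]
    # The starred variants are trained with patch-wise augmentation.
    if paper_name.endswith("*"):
        args.append("--use_patch_aug")
    return args


if __name__ == "__main__":
    for name in MODEL_NAMES:
        print(name, " ".join(training_args(name)))
```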
News: Robustness evaluation is now supported! Because of environmental differences, the robustness results may fluctuate by ±0.1 to 0.3% compared with the results in the paper.
python main.py --eval --pretrained --model rvt_tiny --data-path /path/to/imagenet
python main.py --eval --pretrained --model rvt_tiny_plus --data-path /path/to/imagenet
python main.py --eval --pretrained --model rvt_small --data-path /path/to/imagenet
python main.py --eval --pretrained --model rvt_small_plus --data-path /path/to/imagenet
python main.py --eval --pretrained --model rvt_base --data-path /path/to/imagenet
python main.py --eval --pretrained --model rvt_base_plus --data-path /path/to/imagenet
To enable robustness evaluation, add one of `--inc_path /path/to/imagenet-c`, `--ina_path /path/to/imagenet-a`, `--inr_path /path/to/imagenet-r` or `--insk_path /path/to/imagenet-sketch` to test on ImageNet-C, ImageNet-A, ImageNet-R or ImageNet-Sketch, respectively.
If you want to test accuracy under adversarial attacks, add `--fgsm_test` or `--pgd_test`.
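The options above combine with the evaluation commands as sketched below. `eval_command` is a hypothetical helper that assembles a command string from the flags documented in this README. It conservatively allows one robustness dataset or one attack per run, following the "add one of" wording; the text does not say whether the options can be combined.

```python
import shlex

# Flags as documented above.
DATASET_FLAGS = {
    "imagenet-c": "--inc_path",
    "imagenet-a": "--ina_path",
    "imagenet-r": "--inr_path",
    "imagenet-sketch": "--insk_path",
}
ATTACK_FLAGS = {"fgsm": "--fgsm_test", "pgd": "--pgd_test"}


def eval_command(model, data_path, dataset=None, dataset_path=None, attack=None):
    """Build an evaluation command for main.py as a single shell string."""
    if dataset is not None and attack is not None:
        raise ValueError("this sketch allows one dataset or one attack per run")
    parts = ["python", "main.py", "--eval", "--pretrained",
             "--model", model, "--data-path", data_path]
    if dataset is not None:
        if dataset_path is None:
            raise ValueError("dataset_path is required when dataset is set")
        parts += [DATASET_FLAGS[dataset], dataset_path]
    if attack is not None:
        parts.append(ATTACK_FLAGS[attack])
    return shlex.join(parts)


if __name__ == "__main__":
    print(eval_command("rvt_tiny", "/path/to/imagenet"))
    print(eval_command("rvt_small", "/path/to/imagenet",
                       dataset="imagenet-c", dataset_path="/path/to/imagenet-c"))
    print(eval_command("rvt_base_plus", "/path/to/imagenet", attack="pgd"))
```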
Model name | FLOPs | accuracy | weights |
---|---|---|---|
rvt_tiny | 1.3 G | 78.4 | link |
rvt_small | 4.7 G | 81.7 | link |
rvt_base (ImageNet-22k) | 17.7 G | 83.4 | link |
rvt_tiny* | 1.3 G | 79.3 | link |
rvt_small* | 4.7 G | 81.8 | link |
rvt_base* (ImageNet-22k) | 17.7 G | 83.6 | link |
Model name | clean accuracy | PGD accuracy | weights |
---|---|---|---|
adv_deit_tiny | 52.05 | 26.55 | link |
adv_rvt_tiny | 54.91 | 28.1 | link |
To test these models, run the following commands:
python adv_test.py --model deit_tiny_patch16_224 --ckpt_path adv_deit_tiny.pth
python adv_test.py --model rvt_tiny --ckpt_path adv_rvt_tiny.pth