diff --git a/readme.md b/readme.md
index f86c987..9628a80 100644
--- a/readme.md
+++ b/readme.md
@@ -1,50 +1,117 @@
-## Environment setup:
-
-conda create -n DynamicVit python=3.6
-
-conda activate DynamicVit
-
-conda deactivate
-
-conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 -c pytorch
-
-pip3 install timm==0.4.5
+# SPViT: Enabling Faster Vision Transformers via Latency-aware Soft Token Pruning
+This repository contains the PyTorch training code for the ECCV 2022 paper.
-## Download the pretrained model
-See download_pretrain.sh
+[arXiv](https://arxiv.org/abs/2211.10801) | [ECCV 2022](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136710618.pdf)
+

+*Comparison of different models under various accuracy vs. training-time trade-offs.*

+## Usage
-## Commands
+### Environment Settings
-Example: run deit-small with the 3keep+senet code.
-
-python3 -u -m torch.distributed.launch --nproc_per_node=8 --use_env main_l2_vit_3keep_senet.py --output_dir logs/3keep_senet --arch deit_small --input-size 224 --batch-size 256 --data-path /data/ImageNet_new/ --epochs 30 --dist-eval --distill --base_rate 0.7 2>&1 | tee -i 3keep_senet.log
+```
+conda create -n DynamicVit python=3.6
-Run deit-base
+conda activate DynamicVit
-python3 -u -m torch.distributed.launch --nproc_per_node=8 --use_env main_l2_vit_3keep_senet.py --output_dir logs/deit_base_3keep_senet_256_60_5e-4 --arch deit_base --input-size 224 --batch-size 256 --data-path /data/ImageNet_new/ --epochs 60 --dist-eval --distill --base_rate 0.7 2>&1 | tee -i deit_base_3keep_senet_256_60_5e-4.log
+conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 -c pytorch
-## Hyperparameter tuning link
+pip3 install timm==0.4.5
+conda deactivate  # once you are done working in the environment
+```
+
+### Data preparation
+Download and extract the ImageNet images from http://image-net.org/. The directory structure should be:
+
+```
+│ILSVRC2012/
+├──train/
+│  ├── n01440764
+│  │   ├── n01440764_10026.JPEG
+│  │   ├── n01440764_10027.JPEG
+│  │   ├── ......
+│  ├── ......
+├──val/
+│  ├── n01440764
+│  │   ├── ILSVRC2012_val_00000293.JPEG
+│  │   ├── ILSVRC2012_val_00002138.JPEG
+│  │   ├── ......
+│  ├── ......
+```
+
+A quick sanity check of this layout is sketched below, after the Full precision training section.
+
+### Download pretrained model for training
+
+```
+sh download_pretrain.sh
+```
+
+### Training
+
+**DeiT-S**
+
+```
+CUDA_VISIBLE_DEVICES="0,1,2,3" python3 -u -m torch.distributed.launch --nproc_per_node=4 --use_env main_l2_vit_3keep_senet.py \
+    --output_dir logs/3keep_senet \
+    --arch deit_small \
+    --input-size 224 \
+    --batch-size 256 \
+    --data-path /data/ImageNet_new/ \
+    --epochs 30 \
+    --dist-eval \
+    --distill \
+    --base_rate 0.7
+```
+
+**DeiT-B**
+
+```
+CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python3 -u -m torch.distributed.launch --nproc_per_node=8 --use_env main_l2_vit_3keep_senet.py \
+    --output_dir logs/deit_base_3keep_senet_256_60_5e-4 \
+    --arch deit_base \
+    --input-size 224 \
+    --batch-size 256 \
+    --data-path /data/ImageNet_new/ \
+    --epochs 60 \
+    --dist-eval \
+    --distill \
+    --base_rate 0.7
+```
+
+### Hyperparameter tuning results
 https://docs.google.com/spreadsheets/d/1k25sS_-mmQyIvpIrn32GUw3eRuYcCy0cN0OSOq0QGFI/edit?usp=sharing
-## Score records
+### Score records
 https://drive.google.com/drive/folders/1diICKopeYL7H84Wsr0Xxh30e9xh6RX2d?usp=sharing
-## Optionally train in full precision: turn off amp in train_one_epoch() and evaluate() of engine_l2.py
-![](fig/1.png)
+### Full precision training
+To turn off the amp function (i.e. to train in full fp32 precision), edit `train_one_epoch()` and `evaluate()` in `engine_l2.py`, as sketched below.
-## Changes to vit.py, producing vit_l2.py [for one keep token]
+
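+A minimal sketch of the idea, assuming the loops follow the usual DeiT-style
+pattern of wrapping the forward pass in `torch.cuda.amp.autocast`; the tiny
+stand-in model and data below are hypothetical, not the repo's API:
+
+```
+import torch
+import torch.nn as nn
+
+use_amp = False  # False -> the whole step runs in full fp32 precision
+
+model = nn.Linear(8, 2)               # stand-in for the ViT model
+samples = torch.randn(4, 8)           # stand-in for a batch of images
+targets = torch.randint(0, 2, (4,))
+criterion = nn.CrossEntropyLoss()
+
+# autocast(enabled=False) makes the context a no-op, so nothing is cast to
+# fp16; with enabled=True the forward pass runs in mixed precision instead.
+with torch.cuda.amp.autocast(enabled=use_amp):
+    loss = criterion(model(samples), targets)
+loss.backward()
+```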

+![](fig/1.png)
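+
+And the layout sanity check promised in the Data preparation section: a
+sketch that only assumes the `torchvision` installed during environment
+setup, with the path mirroring the `--data-path` of the training commands:
+
+```
+import torchvision.datasets as datasets
+
+data_path = '/data/ImageNet_new/'
+
+# ImageFolder expects exactly the <split>/<wnid>/<image>.JPEG layout shown above.
+train_set = datasets.ImageFolder(data_path + 'train')
+val_set = datasets.ImageFolder(data_path + 'val')
+print(len(train_set.classes), 'train classes')  # expect 1000 for ILSVRC2012
+print(len(val_set.classes), 'val classes')      # expect 1000 as well
+```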

-### Create the multihead-predictor class
-### forward() of VisionTransformerDiffPruning
-## Changes to vit.py, producing vit_l2_3keep.py [for three keep tokens]
+## Acknowledgements
+Our code is based on [pytorch-image-models](https://github.com/rwightman/pytorch-image-models), [DeiT](https://github.com/facebookresearch/deit) and [DynamicViT](https://github.com/raoyongming/DynamicViT).
-### Create the multihead-predictor class
-### forward() of VisionTransformerDiffPruning
+## Citation
+If you find our work useful in your research, please consider citing:
+```
+@inproceedings{kong2022spvit,
+  title={SPViT: Enabling Faster Vision Transformers via Latency-Aware Soft Token Pruning},
+  author={Kong, Zhenglun and Dong, Peiyan and Ma, Xiaolong and Meng, Xin and Niu, Wei and Sun, Mengshu and Shen, Xuan and Yuan, Geng and Ren, Bin and Tang, Hao and others},
+  booktitle={Computer Vision--ECCV 2022: 17th European Conference, Tel Aviv, Israel, October 23--27, 2022, Proceedings, Part XI},
+  pages={620--640},
+  year={2022},
+  organization={Springer}
+}
+```
-### Environment note: torch.nan_to_num(x, nan=4.0) is used, so torch 1.8.0 is required
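+
+The note above refers to `torch.nan_to_num`, which first appeared in torch
+1.8.0, hence the version pinned in Environment Settings. A minimal
+illustration of the call:
+
+```
+import torch
+
+x = torch.tensor([1.0, float('nan'), float('inf'), -float('inf')])
+# NaNs are replaced by 4.0 (the value used in this repo); +/-inf default to
+# the largest/smallest finite float32 values.
+print(torch.nan_to_num(x, nan=4.0))
+# tensor([ 1.0000e+00,  4.0000e+00,  3.4028e+38, -3.4028e+38])
+```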