# SPViT: Enabling Faster Vision Transformers via Latency-aware Soft Token Pruning

This repository contains PyTorch training code for the ECCV 2022 paper.

[arXiv](https://arxiv.org/abs/2211.10801) | [ECCV 2022](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136710618.pdf)

<p align="center">
  <img src="./plot_time_acc_table.png" width=60%> <br>
  Comparison of different models under various accuracy vs. training-time trade-offs.
</p>

## Usage

### Environment Settings

```
conda create -n DynamicVit python=3.6
conda activate DynamicVit
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 -c pytorch
pip3 install timm==0.4.5
# leave the environment when finished
conda deactivate
```
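
Optionally, a quick sanity check that the pinned versions are active (a minimal sketch, not part of the repo's scripts):

```
import timm
import torch
import torchvision

# Versions pinned by the commands above.
assert torch.__version__.startswith("1.8"), torch.__version__
assert torchvision.__version__.startswith("0.9"), torchvision.__version__
assert timm.__version__ == "0.4.5", timm.__version__
print("CUDA available:", torch.cuda.is_available())
```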

### Data preparation

Download and extract the ImageNet images from http://image-net.org/. The directory structure should be:

```
│ILSVRC2012/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
```
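
With this layout in place, DeiT-style pipelines read the data with torchvision's `ImageFolder`. A minimal sketch for spot-checking the extraction (the transform uses the standard ImageNet values; the repo's own dataset builder may differ):

```
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Standard ImageNet eval transform; used here only to spot-check the layout.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

val_set = datasets.ImageFolder("/data/ImageNet_new/val", transform=transform)
print(len(val_set), "images across", len(val_set.classes), "classes")
```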

### Download pretrained model for training

```
sh download_pretrain.sh
```

### Training

**DeiT-S**

```
CUDA_VISIBLE_DEVICES="0,1,2,3" python3 -u -m torch.distributed.launch --nproc_per_node=4 --use_env main_l2_vit_3keep_senet.py \
    --output_dir logs/3keep_senet \
    --arch deit_small \
    --input-size 224 \
    --batch-size 256 \
    --data-path /data/ImageNet_new/ \
    --epochs 30 \
    --dist-eval \
    --distill \
    --base_rate 0.7
```
**DeiT-B**

```
CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python3 -u -m torch.distributed.launch --nproc_per_node=8 --use_env main_l2_vit_3keep_senet.py \
    --output_dir logs/deit_base_3keep_senet_256_60_5e-4 \
    --arch deit_base \
    --input-size 224 \
    --batch-size 256 \
    --data-path /data/ImageNet_new/ \
    --epochs 60 \
    --dist-eval \
    --distill \
    --base_rate 0.7
```
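
For intuition on `--base_rate`: in the DynamicViT codebase that this repo builds on, a single base rate r is expanded into per-stage keep ratios [r, r^2, r^3] at the three pruning locations. A minimal sketch of that convention (an assumption carried over from DynamicViT, not verified against `main_l2_vit_3keep_senet.py`):

```
def keep_ratios(base_rate, num_stages=3):
    """Per-stage token keep ratios under the DynamicViT convention."""
    return [base_rate ** (i + 1) for i in range(num_stages)]

print(keep_ratios(0.7))  # roughly [0.7, 0.49, 0.343]: ~34% of tokens survive the last stage
```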

### Some hyperparameter tuning results

https://docs.google.com/spreadsheets/d/1k25sS_-mmQyIvpIrn32GUw3eRuYcCy0cN0OSOq0QGFI/edit?usp=sharing

### Score records

https://drive.google.com/drive/folders/1diICKopeYL7H84Wsr0Xxh30e9xh6RX2d?usp=sharing

### Full precision training

Full precision training is optional. To turn off the amp function, edit `train_one_epoch()` and `evaluate()` in `engine_l2.py`:

<p align="center">
  <img src="fig/1.png" width=60%>
</p>
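
For reference, a minimal sketch of the usual way to make such a toggle in a DeiT-style loop; `use_amp` is a hypothetical flag, and the actual code in `engine_l2.py` may be structured differently:

```
import torch

def forward_step(model, criterion, samples, targets, use_amp=False):
    # With enabled=False, autocast is a no-op and the forward pass
    # runs entirely in full fp32 precision.
    with torch.cuda.amp.autocast(enabled=use_amp):
        outputs = model(samples)
        loss = criterion(outputs, targets)
    return outputs, loss
```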

## Acknowledgements

Our code is based on [pytorch-image-models](https://github.com/rwightman/pytorch-image-models), [DeiT](https://github.com/facebookresearch/deit), and [DynamicViT](https://github.com/raoyongming/DynamicViT).

## Citation

If you find our work useful in your research, please consider citing:

```
@inproceedings{kong2022spvit,
  title={SPViT: Enabling Faster Vision Transformers via Latency-Aware Soft Token Pruning},
  author={Kong, Zhenglun and Dong, Peiyan and Ma, Xiaolong and Meng, Xin and Niu, Wei and Sun, Mengshu and Shen, Xuan and Yuan, Geng and Ren, Bin and Tang, Hao and others},
  booktitle={Computer Vision--ECCV 2022: 17th European Conference, Tel Aviv, Israel, October 23--27, 2022, Proceedings, Part XI},
  pages={620--640},
  year={2022},
  organization={Springer}
}
```

### Environment note

Because the code uses `torch.nan_to_num(x, nan=4.0)`, torch 1.8.0 is required.
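
A quick illustration of the call behind this version pin (a minimal sketch):

```
import torch

# torch.nan_to_num first appeared in torch 1.8.0; NaN entries become 4.0 here.
x = torch.tensor([float("nan"), 1.0])
print(torch.nan_to_num(x, nan=4.0))  # tensor([4., 1.])
```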