# SPViT: Enabling Faster Vision Transformers via Latency-aware Soft Token Pruning

This repository contains PyTorch training code for the ECCV 2022 paper.

[arXiv](https://arxiv.org/abs/2211.10801) | [ECCV 2022](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136710618.pdf)

<p align="center">
<img src="./plot_time_acc_table.png" width=60%> <br>
Comparison of different models under various accuracy and training-time trade-offs.
</p>

## Usage

### Environment Settings

```
conda create -n DynamicVit python=3.6
conda activate DynamicVit
conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 -c pytorch
pip3 install timm==0.4.5
```

Note: the code uses `torch.nan_to_num(x, nan=4.0)`, which was added in torch 1.8.0, so torch 1.8.0 is required.
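
As a quick sanity check of the version pin (a minimal snippet, not part of the repo), `torch.nan_to_num` is only available from torch 1.8.0 onward:

```
import torch

# torch.nan_to_num was added in torch 1.8.0; older versions raise AttributeError here.
x = torch.tensor([float("nan"), 1.0])
print(torch.nan_to_num(x, nan=4.0))  # tensor([4., 1.])
```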

### Data preparation
Download and extract the ImageNet images from http://image-net.org/. The directory structure should be:

```
│ILSVRC2012/
├──train/
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── n01440764
│ │ ├── ILSVRC2012_val_00000293.JPEG
│ │ ├── ILSVRC2012_val_00002138.JPEG
│ │ ├── ......
│ ├── ......
```
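
Before launching training, you can sanity-check the layout with a minimal snippet (this assumes the DeiT-style `torchvision` `ImageFolder` loader; the path below is an example):

```
from torchvision import datasets

# Example path; point this at your ILSVRC2012 / ImageNet root.
for split in ("train", "val"):
    ds = datasets.ImageFolder(f"/data/ImageNet_new/{split}")
    print(split, len(ds.classes), "classes,", len(ds), "images")
```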

### Download pretrained models for training

```
sh download_pretrain.sh
```

### Training


**DeiT-S**

```
CUDA_VISIBLE_DEVICES="0,1,2,3" python3 -u -m torch.distributed.launch --nproc_per_node=4 --use_env main_l2_vit_3keep_senet.py \
    --output_dir logs/3keep_senet \
    --arch deit_small \
    --input-size 224 \
    --batch-size 256 \
    --data-path /data/ImageNet_new/ \
    --epochs 30 \
    --dist-eval \
    --distill \
    --base_rate 0.7
```
**DeiT-B**

```
CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python3 -u -m torch.distributed.launch --nproc_per_node=8 --use_env main_l2_vit_3keep_senet.py \
    --output_dir logs/deit_base_3keep_senet_256_60_5e-4 \
    --arch deit_base \
    --input-size 224 \
    --batch-size 256 \
    --data-path /data/ImageNet_new/ \
    --epochs 60 \
    --dist-eval \
    --distill \
    --base_rate 0.7
```
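
`--base_rate` sets the fraction of tokens kept at each pruning stage. As a rough illustration of the soft token pruning idea from the paper, low-score tokens are fused into a single package token instead of being discarded outright. The sketch below is illustrative only, not the repo's implementation (see `vit_l2.py` / `vit_l2_3keep.py` for the real code):

```
import torch

def soft_prune(tokens, scores, keep_rate=0.7):
    # tokens: (B, N, C) patch tokens; scores: (B, N) token-importance scores.
    # Illustrative sketch of soft pruning, not the repository's actual code.
    B, N, C = tokens.shape
    n_keep = int(N * keep_rate)
    order = scores.argsort(dim=1, descending=True)
    keep_idx, drop_idx = order[:, :n_keep], order[:, n_keep:]
    keep = tokens.gather(1, keep_idx.unsqueeze(-1).expand(-1, -1, C))
    drop = tokens.gather(1, drop_idx.unsqueeze(-1).expand(-1, -1, C))
    w = scores.gather(1, drop_idx).softmax(dim=1).unsqueeze(-1)
    package = (w * drop).sum(dim=1, keepdim=True)   # fuse pruned tokens into one
    return torch.cat([keep, package], dim=1)        # (B, n_keep + 1, C)
```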


### Some hyperparameter tuning results
https://docs.google.com/spreadsheets/d/1k25sS_-mmQyIvpIrn32GUw3eRuYcCy0cN0OSOq0QGFI/edit?usp=sharing

### Score records
https://drive.google.com/drive/folders/1diICKopeYL7H84Wsr0Xxh30e9xh6RX2d?usp=sharing

### Full precision training
To train in full precision, disable AMP (automatic mixed precision) in `train_one_epoch()` and `evaluate()` in `engine_l2.py`, as shown in the figure and the sketch below.

<p align="center">
<img src="fig/1.png" width=60%> <br>
</p>
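
A minimal sketch of the mechanism, assuming the DeiT-style loop wraps the forward pass in `torch.cuda.amp.autocast` (check the actual context managers in `engine_l2.py`): passing `enabled=False` makes the wrapped region run in full fp32 precision.

```
import torch

model = torch.nn.Linear(8, 2).cuda()
x = torch.randn(4, 8, device="cuda")

# enabled=False turns autocast into a no-op, so the forward pass stays in fp32.
with torch.cuda.amp.autocast(enabled=False):
    y = model(x)
print(y.dtype)  # torch.float32
```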

### Code structure
- `vit.py` was modified into `vit_l2.py` (one keep token) and `vit_l2_3keep.py` (three keep tokens).
- Both variants add a multihead-predictor class and change the `forward()` of `VisionTransformerDiffPruning`.
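
As a rough sketch of what a multihead predictor of per-token keep scores could look like (all class and layer names here are hypothetical, not taken from the repo):

```
import torch
import torch.nn as nn

class MultiheadPredictor(nn.Module):
    # Hypothetical multihead token-importance predictor; see vit_l2*.py for the real classes.
    def __init__(self, dim, num_heads=2):
        super().__init__()
        self.heads = nn.ModuleList(
            nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim // 2),
                          nn.GELU(), nn.Linear(dim // 2, 1))
            for _ in range(num_heads)
        )

    def forward(self, x):            # x: (B, N, C) patch tokens
        scores = torch.stack([h(x).squeeze(-1) for h in self.heads], dim=0)
        return scores.mean(dim=0)    # (B, N) averaged keep scores
```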
## Acknowledgements

Our code is based on [pytorch-image-models](https://github.com/rwightman/pytorch-image-models), [DeiT](https://github.com/facebookresearch/deit), [DynamicViT](https://github.com/raoyongming/DynamicViT).

## Citation
If you find our work useful in your research, please consider citing:
```
@inproceedings{kong2022spvit,
  title={SPViT: Enabling Faster Vision Transformers via Latency-Aware Soft Token Pruning},
  author={Kong, Zhenglun and Dong, Peiyan and Ma, Xiaolong and Meng, Xin and Niu, Wei and Sun, Mengshu and Shen, Xuan and Yuan, Geng and Ren, Bin and Tang, Hao and others},
  booktitle={Computer Vision--ECCV 2022: 17th European Conference, Tel Aviv, Israel, October 23--27, 2022, Proceedings, Part XI},
  pages={620--640},
  year={2022},
  organization={Springer}
}
```

