This folder contains the implementation of the STViT-R for image classification.
- Clone this repo:
git clone https://github.com/changsn/STViT-R.git
cd STViT-R
- Create a conda virtual environment and activate it:
conda create -n stvit-r python=3.7 -y
conda activate stvit-r
- Install
CUDA==10.1
withcudnn7
following the official installation instructions - Install
PyTorch==1.7.1
andtorchvision==0.8.2
withCUDA==10.1
:
conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch
- Install
timm==0.3.2
:
pip install timm==0.3.2
- Install
Apex
:
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
- Install other requirements:
pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8
We use standard ImageNet dataset, you can download it from http://image-net.org/. We provide the following two ways to load data:
The file structure should look like:
$ tree data
imagenet
├── train
│ ├── class1
│ │ ├── img1.jpeg
│ │ ├── img2.jpeg
│ │ └── ...
│ ├── class2
│ │ ├── img3.jpeg
│ │ └── ...
│ └── ...
└── val
├── class1
│ ├── img4.jpeg
│ ├── img5.jpeg
│ └── ...
├── class2
│ ├── img6.jpeg
│ └── ...
└── ...