STViT-R

This folder contains the implementation of STViT-R for image classification.

Install

  • Clone this repo:
git clone https://github.com/changsn/STViT-R.git
cd STViT-R
  • Create a conda virtual environment, activate it, and install PyTorch 1.7.1 with torchvision 0.8.2 and CUDA toolkit 10.1:
conda create -n stvit-r python=3.7 -y
conda activate stvit-r
conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch
  • Install timm==0.3.2:
pip install timm==0.3.2
  • Install Apex:
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
  • Install other requirements:
pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8
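
After installation, it can help to confirm that the pinned versions are the ones Python actually imports. The following is a minimal sketch, not part of the repository: the helper names `installed_version` and `find_mismatches` are hypothetical, and it assumes each package exposes a `__version__` attribute (as `torch`, `torchvision`, and `timm` do). Build suffixes such as `+cu101` are ignored when comparing.

```python
import importlib

# Versions pinned in the install steps above.
EXPECTED = {"torch": "1.7.1", "torchvision": "0.8.2", "timm": "0.3.2"}


def installed_version(module_name):
    """Return the module's __version__ string, or None if it cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, "__version__", None)


def find_mismatches(expected, lookup=installed_version):
    """Map each mismatching package to a (wanted, found) pair.

    `lookup` is injectable so the comparison can be exercised without
    the real packages being installed (hypothetical helper).
    """
    mismatches = {}
    for name, wanted in expected.items():
        found = lookup(name)
        # Some builds append a local tag such as "+cu101"; compare the base version only.
        if found is None or found.split("+")[0] != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches(EXPECTED).items():
        print(f"{name}: expected {wanted}, found {found}")
```

An empty result from `find_mismatches(EXPECTED)` suggests the environment matches the pins; any printed line points to a package worth reinstalling.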

Data preparation

We use the standard ImageNet dataset, which you can download from http://image-net.org/. We provide the following two ways to load data:

The file structure should look like:

$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
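
Before training, the folder layout above can be checked with a short script. This is a sketch using only the standard library, not code from the repository: the function names `summarize_split` and `check_layout`, the accepted image suffixes, and the `data/imagenet` path in the usage line are all assumptions made for illustration.

```python
from pathlib import Path

# Assumed set of image extensions; adjust to match your copy of the dataset.
IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}


def summarize_split(split_dir):
    """Return {class_name: image_count} for one split directory (train or val)."""
    split_dir = Path(split_dir)
    if not split_dir.is_dir():
        raise FileNotFoundError(f"missing split directory: {split_dir}")
    counts = {}
    for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


def check_layout(root):
    """Return a list of human-readable problems found under root/train and root/val."""
    root = Path(root)
    problems = []
    summaries = {}
    for split in ("train", "val"):
        try:
            summaries[split] = summarize_split(root / split)
        except FileNotFoundError as exc:
            problems.append(str(exc))
            continue
        if not summaries[split]:
            problems.append(f"{split}: no class folders found")
        for name, count in summaries[split].items():
            if count == 0:
                problems.append(f"{split}/{name}: no images found")
    # Both splits are expected to contain the same set of class folders.
    if len(summaries) == 2 and set(summaries["train"]) != set(summaries["val"]):
        problems.append("train and val contain different class folders")
    return problems


if __name__ == "__main__":
    # Hypothetical location; point this at your own dataset root.
    for problem in check_layout("data/imagenet"):
        print(problem)
```

An empty list from `check_layout` means both splits exist, every class folder holds at least one image, and the class names match across `train` and `val`.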