This repository holds the companion code for our paper "Jointly-learned Exit and Inference for Dynamic Neural Networks". We explore ways of training an efficient dynamic neural network by augmenting a frozen off-the-shelf neural network used as a backbone.
Out of the box this codebase supports CIFAR10, CIFAR100, SVHN and CIFAR100-LT.
Out of the box we support T2T-ViT-7 and T2T-ViT-14. You can add other models using timm.
First, install all requirements:
pip install -r requirements.txt
(Preferably, create a virtual environment with conda or venv and install the requirements there.)
You can then either run JEI-DNN directly on one of the already supported datasets, or transfer-learn an ImageNet-pretrained model onto a new dataset.
- Download the weights for the 7-layer vision transformer T2T-ViT-7 trained on ImageNet from https://github.com/yitu-opensource/T2T-ViT and store them locally.
- Run the following, making sure you update the --weights-path parameter:
python transfer_learning.py --lr 0.05 --b 64 --dataset svhn --weights-path model_weights/71.7_T2T_ViT_7.pth.tar
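A wrong --weights-path is an easy mistake at this step, so it can help to check the file before launching. The sketch below is a hypothetical helper, not part of this repository: `build_transfer_command` is a name introduced here for illustration, and it only reuses the flags shown in the command above (`--lr`, `--b`, `--dataset`, `--weights-path`).

```python
from pathlib import Path


def build_transfer_command(weights_path, dataset="svhn", lr=0.05, batch_size=64):
    """Return the transfer-learning command as an argument list.

    Hypothetical helper (not part of the repository). It raises
    FileNotFoundError when the weights file is missing, so the problem
    surfaces before the training script starts.
    """
    weights = Path(weights_path)
    if not weights.is_file():
        raise FileNotFoundError(f"Backbone weights not found: {weights}")
    return [
        "python", "transfer_learning.py",
        "--lr", str(lr),
        "--b", str(batch_size),
        "--dataset", dataset,
        "--weights-path", str(weights),
    ]
```

The returned list can be passed to `subprocess.run(cmd, check=True)` from the project root.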
- Download the model checkpoints from this Google Drive and place each in the appropriate checkpoint folder, organized by dataset and architecture (checkpoint/checkpoint_DATASET_ARCH/).
- Make sure the checkpoint path matches the path in train_dynn.py for that dataset (for example, for SVHN).
- You can now train JEI-DNN for dynamic inference. You can specify the dataset, the architecture, the number of epochs and ce_ic_tradeoff. Higher values of ce_ic_tradeoff mean the model is trained to exit earlier, at the cost of some accuracy. A full list of arguments can be found in src/train_dynn.py.
python train_dynn.py --ce_ic_tradeoff 0.15 --dataset cifar10 --arch t2t_vit_7 --num_epoch 15;
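Since ce_ic_tradeoff controls how early the model exits, comparing several values is a natural next step. The following is a minimal sketch of how such a sweep could be scripted. The helpers `build_train_command` and `sweep_commands` are hypothetical names, and the tradeoff values other than 0.15 are arbitrary placeholders rather than settings from the paper. Only the flags shown in the command above are used.

```python
import shlex


def build_train_command(ce_ic_tradeoff, dataset="cifar10", arch="t2t_vit_7", num_epoch=15):
    """Return one train_dynn.py invocation as an argument list (hypothetical helper)."""
    return [
        "python", "train_dynn.py",
        "--ce_ic_tradeoff", str(ce_ic_tradeoff),
        "--dataset", dataset,
        "--arch", arch,
        "--num_epoch", str(num_epoch),
    ]


def sweep_commands(tradeoffs, **kwargs):
    """Build one command per tradeoff value; extra kwargs go to build_train_command."""
    return [build_train_command(t, **kwargs) for t in tradeoffs]


if __name__ == "__main__":
    # Placeholder values chosen for illustration only.
    for cmd in sweep_commands([0.05, 0.15, 0.5]):
        # Print each command; replace with subprocess.run(cmd, check=True) to execute.
        print(shlex.join(cmd))
```

Printing the commands first makes it easy to inspect them before committing compute. Each run still expects the matching checkpoint to be in place as described above.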
We used mlflow to monitor our running scripts.
- Install mlflow (it is listed in our requirements.txt).
- To start the UI, run the following from the root of the project:
mlflow ui
The starting code was taken from the original repository for T2T-ViT since we used T2T-ViT-7/14 models as backbone for our dynamic neural network.