Sanity-checking pruning methods: Random tickets can win the jackpot. Jingtong Su, Yihang Chen, Tianle Cai, Tianhao Wu, Ruiqi Gao, Liwei Wang, and Jason D. Lee.
Run the GraSP experiments from the GraSP directory:
cd GraSP
Several examples:
- CIFAR-10, VGG19, Pruning ratio = 98%
python main.py --config configs/cifar10/vgg19/GraSP_98.json --seed 0
- CIFAR-10, VGG19, Pruning ratio = 98%, corrupted data
python main.py --config configs/cifar10/vgg19/GraSP_98_corrupt.json
- CIFAR-10, VGG19, Pruning ratio = 98%, rearranged layers
python main.py --config configs/cifar10/vgg19/GraSP_98_rearrange.json
Check out the experiment logs in GraSP/runs/pruning/cifar10/vgg19/...
The code is based on the original GraSP repository.
GraSP paper: Picking Winning Tickets Before Training by Preserving Gradient Flow
It is recommended to run the SNIP code from its directory:
cd SNIP_and_partially_trained_tickets
This directory contains the file main_snip.py, which prunes the network according to SNIP and then trains it with the supplied parameters. It can be run using the following command:
python main_snip.py --dataset <cifar10/cifar100> --architecture <vgg19/resnet32> --epochs [number of training epochs] --pruning_ratio [fraction of weights to prune, between 0 and 1 (inclusive)] --seed [seed value] -sc [names of sanity checks separated by whitespace]
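To make the meaning of --pruning_ratio concrete, here is a minimal sketch of SNIP-style mask selection. It assumes the usual saliency score |weight * gradient| and a single global threshold over all weights; the function name snip_mask and the toy numbers are hypothetical, and the code in main_snip.py may differ in detail.

```python
def snip_mask(weights, grads, pruning_ratio):
    """Return a 0/1 mask that keeps the weights with the largest |w * g|.

    weights, grads: flat lists of floats of equal length (all layers concatenated).
    pruning_ratio: fraction of weights to remove, between 0 and 1 (inclusive).
    """
    if not 0.0 <= pruning_ratio <= 1.0:
        raise ValueError("pruning_ratio must lie in [0, 1]")
    # Saliency of each connection (assumed SNIP-style score).
    scores = [abs(w * g) for w, g in zip(weights, grads)]
    n_keep = len(scores) - int(round(pruning_ratio * len(scores)))
    # Indices sorted by descending saliency; the first n_keep survive.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = set(order[:n_keep])
    return [1 if i in keep else 0 for i in range(len(scores))]


# Toy example: prune 60% of five weights, so two weights remain.
weights = [0.5, -1.0, 0.1, 2.0, -0.3]
grads = [0.4, 0.1, -3.0, 0.01, 0.0]
mask = snip_mask(weights, grads, pruning_ratio=0.6)
print(mask)
```

With --pruning_ratio 0.98, the same rule would leave 2% of the weights unmasked.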
Supported sanity checks for SNIP are: random_labels, random_pixels, layerwise_rearrange.
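As a rough illustration of the two data-corruption checks, the sketch below shows one plausible reading: random_labels permutes the labels across the dataset, and random_pixels permutes the pixels of each image independently. Both function bodies are assumptions made for illustration, not the repository's implementation, and images are simplified to flat lists of pixel values.

```python
import random


def random_labels(labels, rng):
    """Return the labels in a random order (assumed: a permutation of the label list)."""
    shuffled = list(labels)
    rng.shuffle(shuffled)
    return shuffled


def random_pixels(image, rng):
    """Return one image with its pixel positions shuffled (image is a flat list)."""
    shuffled = list(image)
    rng.shuffle(shuffled)
    return shuffled


rng = random.Random(2020)
labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
image = [10, 20, 30, 40, 50, 60]
corrupted_labels = random_labels(labels, rng)
corrupted_image = random_pixels(image, rng)
```

Either corruption removes the relationship between inputs and targets that a data-dependent pruning method could exploit, while leaving the amount of data unchanged.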
- CIFAR-10, VGG19, Pruning ratio = 98%, training for 160 epochs after pruning
python main_snip.py --dataset cifar10 --architecture vgg19 --epochs 160 --pruning_ratio 0.98 --seed 2020
- CIFAR-10, VGG19, Pruning ratio = 98%, training for 160 epochs after pruning, corrupted data
python main_snip.py --dataset cifar10 --architecture vgg19 --epochs 160 --pruning_ratio 0.98 --seed 2020 -sc random_pixels random_labels
- CIFAR-10, VGG19, Pruning ratio = 98%, training for 160 epochs after pruning, layerwise rearrange
python main_snip.py --dataset cifar10 --architecture vgg19 --epochs 160 --pruning_ratio 0.98 --seed 2020 -sc layerwise_rearrange
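The layerwise_rearrange check used in the last example can be pictured as shuffling the pruning mask inside each layer, so that every layer keeps the same number of weights but at random positions. The sketch below assumes masks are stored as flat 0/1 lists per layer; the function name and the layer names are hypothetical, and the repository's own implementation may differ.

```python
import random


def layerwise_rearrange(masks, rng):
    """Shuffle each layer's 0/1 mask independently, preserving its kept count.

    masks: dict mapping a layer name to a flat list of 0/1 entries.
    """
    rearranged = {}
    for name, mask in masks.items():
        shuffled = list(mask)
        rng.shuffle(shuffled)  # positions change, the number of ones does not
        rearranged[name] = shuffled
    return rearranged


rng = random.Random(2020)
masks = {
    "conv1": [1, 1, 0, 0, 0, 0],
    "conv2": [1, 0, 0, 0],
}
new_masks = layerwise_rearrange(masks, rng)
```

If accuracy is unchanged after this rearrangement, the per-layer keep ratio matters more than which individual weights were selected.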
Training logs for SNIP can be found in the directory SNIP_and_partially_trained_tickets/training_logs/SNIP/<cifar10|cifar100>
Network\Sparsity | CIFAR-10, 90% | CIFAR-10, 95% | CIFAR-10, 98% | CIFAR-100, 90% | CIFAR-100, 95% | CIFAR-100, 98% |
---|---|---|---|---|---|---|
VGG19 | 93.65 (-0.12) | 93.22 (-0.20) | 92.41 (-0.04) | 72.75 (0.20) | 72.00 (0.63) | 68.77 (-0.21) |
ResNet32 | 92.84 (-0.13) | 91.64 (0.04) | 88.68 (-0.42) | 68.50 (-1.20) | 65.99 (-0.83) | 59.67 (-0.44) |
To reproduce the results for Random Tickets, change to the random_tickets directory (cd random_tickets) and run the following command:
python train.py --dataset <cifar10/cifar100> --network <vgg/resnet> --pruning random --ratio <0.9/0.95/0.98> --wd 0.0005 --seed 0
For example, for VGG19 on CIFAR-10 with 90% sparsity, run the following command:
python train.py --dataset cifar10 --network vgg --pruning random --ratio 0.9 --wd 0.0005 --seed 0
You can then find the logs at ./logs/myexman-train.py/runs/<id>/logs.csv
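The usage line above lists two datasets, two networks, and three ratios. The following sketch only builds the command strings for every combination and does not launch anything. It assumes one run per combination with the same --wd 0.0005 and a single seed; the helper name random_ticket_commands is hypothetical.

```python
from itertools import product

# Values taken from the usage line of train.py shown above.
DATASETS = ["cifar10", "cifar100"]
NETWORKS = ["vgg", "resnet"]
RATIOS = ["0.9", "0.95", "0.98"]


def random_ticket_commands(seed=0):
    """Return one train.py command line per dataset/network/ratio combination."""
    return [
        f"python train.py --dataset {dataset} --network {network} "
        f"--pruning random --ratio {ratio} --wd 0.0005 --seed {seed}"
        for dataset, network, ratio in product(DATASETS, NETWORKS, RATIOS)
    ]


commands = random_ticket_commands(seed=0)
for command in commands:
    print(command)
```

The printed lines can be pasted into a shell or a job script and run from the random_tickets directory.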
It is recommended to run the code for partially-trained tickets from its directory:
cd SNIP_and_partially_trained_tickets
This directory contains the file main_partially_trained_tickets.py, which supports the partially-trained ticket pruning methods used in the original paper and their corresponding sanity checks. It can be run using the following command:
python main_partially_trained_tickets.py --dataset <cifar10/cifar100> --architecture <vgg19/resnet32> --epochs [number of training epochs] --fine_tuning_epochs [number of fine-tuning epochs] --pruning_ratio [fraction of weights to prune, between 0 and 1 (inclusive)] --seed [seed value] --rewinding_type <weights/learning_rate> --rewind_epoch [epoch to which the scheduler or weights should be rewound after pruning] -sc [names of sanity checks separated by whitespace]
The command also accepts an optional argument --hybrid_tickets, which specifies that the hybrid tickets method should be used when pruning weights. According to the original paper, learning_rate should be used as the rewinding_type with hybrid tickets, but the code also works with weights as the rewinding_type.
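The difference between the two --rewinding_type values can be sketched as follows. The sketch assumes the common definitions: weight rewinding restarts the surviving weights from their values at the rewind epoch, while learning rate rewinding keeps the fully trained weights and only restarts the learning rate schedule from that epoch. The function name and arguments are hypothetical, and weights are simplified to flat lists.

```python
def starting_weights(trained_weights, checkpoint_weights, mask, rewinding_type):
    """Return the weights from which retraining starts after pruning.

    trained_weights:    weights at the end of pre-training.
    checkpoint_weights: weights saved at the rewind epoch.
    mask:               0/1 pruning mask of the same length.
    rewinding_type:     "weights" or "learning_rate".
    In both cases the learning rate schedule is assumed to restart from the rewind epoch.
    """
    if rewinding_type == "weights":
        source = checkpoint_weights  # weights go back to the rewind epoch
    elif rewinding_type == "learning_rate":
        source = trained_weights  # weights stay as trained; only the schedule rewinds
    else:
        raise ValueError(f"unknown rewinding_type: {rewinding_type}")
    # Pruned connections are zeroed in either case.
    return [w * m for w, m in zip(source, mask)]


trained = [1.0, 2.0, 3.0]
checkpoint = [0.1, 0.2, 0.3]
mask = [1, 0, 1]
from_weights = starting_weights(trained, checkpoint, mask, "weights")
from_lr = starting_weights(trained, checkpoint, mask, "learning_rate")
```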
Supported sanity checks for partially-trained tickets are: half_dataset, layerwise_weights_shuffling.
All the results for partially-trained tickets are obtained using seed 2020.
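The two sanity checks named above can be pictured with the sketch below. It assumes that half_dataset pre-trains on a random half of the training samples and that layerwise_weights_shuffling permutes the surviving (unmasked) weights within a layer. Both readings are assumptions made for illustration, and the function names are hypothetical.

```python
import random


def half_dataset(samples, rng):
    """Return a random half of the samples, preserving their original order."""
    indices = list(range(len(samples)))
    rng.shuffle(indices)
    chosen = sorted(indices[: len(samples) // 2])
    return [samples[i] for i in chosen]


def shuffle_kept_weights(weights, mask, rng):
    """Permute the unmasked weights of one layer among the unmasked positions."""
    kept_positions = [i for i, m in enumerate(mask) if m]
    values = [weights[i] for i in kept_positions]
    rng.shuffle(values)
    shuffled = list(weights)
    for position, value in zip(kept_positions, values):
        shuffled[position] = value
    return shuffled


rng = random.Random(2020)
samples = list(range(10))
subset = half_dataset(samples, rng)

layer_weights = [0.5, 0.0, -1.2, 0.0, 0.7, 0.3]
layer_mask = [1, 0, 1, 0, 1, 1]
shuffled_layer = shuffle_kept_weights(layer_weights, layer_mask, rng)
```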
- CIFAR-10, VGG19, pre-trained for 160 epochs on half of the dataset, pruned 98% of weights, rewound to the 40th epoch using weight rewinding, and retrained for an additional 160 epochs
python main_partially_trained_tickets.py --dataset cifar10 --architecture vgg19 --epochs 160 --rewind_epoch 40 --fine_tuning_epochs 160 --pruning_ratio 0.98 --seed 2020 --rewinding_type weights -sc half_dataset
- CIFAR-100, ResNet32, pre-trained for 160 epochs, pruned 95% of weights using hybrid tickets, rewound to the 40th epoch using learning rate rewinding, and retrained for an additional 160 epochs
python main_partially_trained_tickets.py --dataset cifar100 --architecture resnet32 --epochs 160 --rewind_epoch 40 --fine_tuning_epochs 160 --pruning_ratio 0.95 --hybrid_tickets --seed 2020 --rewinding_type learning_rate
- CIFAR-100, ResNet32, pre-trained for 160 epochs, pruned 90% of weights, applied layerwise weight shuffling, rewound to the 40th epoch using learning rate rewinding, and retrained for an additional 160 epochs
python main_partially_trained_tickets.py --dataset cifar100 --architecture resnet32 --epochs 160 --rewind_epoch 40 --fine_tuning_epochs 160 --pruning_ratio 0.9 --seed 2020 --rewinding_type learning_rate -sc layerwise_weights_shuffling
Training logs for partially-trained tickets can be found in the directory SNIP_and_partially_trained_tickets/training_logs/partially_trained_tickets/<hybrid_tickets|learning_rate_rewinding>
- Andrei Atanov
- Valentina Shumovskaia
- Miloš Vujasinović