import argparse
import json
import os
import random
from Experiments import singleshot
from Experiments import multishot
from Experiments import multishot_ddp_prune
from Experiments import multishot_ddp_train
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Network Compression')
    # Training Hyperparameters
    training_args = parser.add_argument_group('training')
    training_args.add_argument('--dataset', type=str, default='mnist',
                               choices=['mnist','cifar10','cifar100','tiny-imagenet','imagenet'],
                               help='dataset (default: mnist)')
    training_args.add_argument('--model', type=str, default='fc', choices=['fc','conv',
                               'vgg11','vgg11-bn','vgg13','vgg13-bn','vgg16','vgg16-bn','vgg19','vgg19-bn',
                               'resnet18','resnet20','resnet32','resnet34','resnet44','resnet50',
                               'resnet56','resnet101','resnet110','resnet152','resnet1202',
                               'wide-resnet18','wide-resnet20','wide-resnet32','wide-resnet34','wide-resnet44','wide-resnet50',
                               'wide-resnet56','wide-resnet101','wide-resnet110','wide-resnet152','wide-resnet1202'],
                               help='model architecture (default: fc)')
    training_args.add_argument('--model-class', type=str, default='default', choices=['default','lottery','tinyimagenet','imagenet'],
                               help='model class (default: default)')
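    # NOTE: argparse's type=bool converts any non-empty string to True, so
    # "--pretrained False" still yields True; omit these boolean flags
    # entirely to keep their False defaults.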
    training_args.add_argument('--dense-classifier', type=bool, default=False,
                               help='ensure last layer of model is dense (default: False)')
    training_args.add_argument('--pretrained', type=bool, default=False,
                               help='load pretrained weights (default: False)')
    training_args.add_argument('--optimizer', type=str, default='adam', choices=['sgd','momentum','adam','rms'],
                               help='optimizer (default: adam)')
    training_args.add_argument('--train-batch-size', type=int, default=64,
                               help='input batch size for training (default: 64)')
    training_args.add_argument('--test-batch-size', type=int, default=256,
                               help='input batch size for testing (default: 256)')
    training_args.add_argument('--pre-epochs', type=int, default=0,
                               help='number of epochs to train before pruning (default: 0)')
    training_args.add_argument('--post-epochs', type=int, default=10,
                               help='number of epochs to train after pruning (default: 10)')
    training_args.add_argument('--lr', type=float, default=0.001,
                               help='learning rate (default: 0.001)')
    training_args.add_argument('--lr-drops', type=int, nargs='*', default=[],
                               help='list of epochs at which to drop the learning rate (default: [])')
    training_args.add_argument('--lr-drop-rate', type=float, default=0.1,
                               help='multiplicative factor applied at each learning rate drop (default: 0.1)')
    training_args.add_argument('--weight-decay', type=float, default=0.0,
                               help='weight decay (default: 0.0)')
    training_args.add_argument('--resume', default='', type=str,
                               help='checkpoint path for resuming training (default: "")')
    training_args.add_argument('--saved_best_test_epoch', default=0, type=int,
                               help='saved epoch with the best top-1 test accuracy (default: 0)')
    training_args.add_argument('--saved_best_test_acc1', default=0.0, type=float,
                               help='saved best top-1 test accuracy (default: 0.0)')
    training_args.add_argument('--start_epoch', default=0, type=int,
                               help='epoch at which training starts (default: 0)')
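    # Illustrative schedule: --lr 0.1 --lr-drops 80 120 --lr-drop-rate 0.1
    # multiplies the learning rate by 0.1 at epochs 80 and 120.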
    # Pruning Hyperparameters
    pruning_args = parser.add_argument_group('pruning')
    pruning_args.add_argument('--pruner', type=str, default='rand',
                              choices=['rand','mag','snip','grasp','synflow','NTKSAP','itersnip'],
                              help='pruning strategy (default: rand)')
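    # 'rand' and 'mag' are random- and magnitude-based baselines; 'snip',
    # 'grasp', 'synflow', 'itersnip', and 'NTKSAP' score weights at or near
    # initialization.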
    pruning_args.add_argument('--compression', type=float, default=0.0,
                              help='compression exponent; the target density of kept weights is 0.8**compression (default: 0.0)')
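    # Worked example: --compression 10 keeps a fraction 0.8**10 ≈ 0.107 of the
    # prunable weights, i.e. roughly 10.7% density.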
    pruning_args.add_argument('--prune-epochs', type=int, default=1,
                              help='number of iterations for scoring (default: 1)')
    pruning_args.add_argument('--compression-schedule', type=str, default='exponential', choices=['linear','exponential','expinv'],
                              help='whether to use a linear or exponential compression schedule (default: exponential)')
    pruning_args.add_argument('--mask-scope', type=str, default='global', choices=['global','local'],
                              help='masking scope (global or layer-wise) (default: global)')
    pruning_args.add_argument('--prune-dataset-ratio', type=int, default=10,
                              help='ratio of prune dataset size to number of classes (default: 10)')
    pruning_args.add_argument('--prune-batch-size', type=int, default=256,
                              help='input batch size for pruning (default: 256)')
    pruning_args.add_argument('--prune-bias', type=bool, default=False,
                              help='whether to prune bias parameters (default: False)')
    pruning_args.add_argument('--prune-batchnorm', type=bool, default=False,
                              help='whether to prune batchnorm layers (default: False)')
    pruning_args.add_argument('--prune-residual', type=bool, default=False,
                              help='whether to prune residual connections (default: False)')
    pruning_args.add_argument('--prune-train-mode', type=bool, default=False,
                              help='whether to prune in train mode (default: False)')
    pruning_args.add_argument('--reinitialize', type=bool, default=False,
                              help='whether to reinitialize weight parameters after pruning (default: False)')
    pruning_args.add_argument('--shuffle', type=bool, default=False,
                              help='whether to shuffle masks after pruning (default: False)')
    pruning_args.add_argument('--invert', type=bool, default=False,
                              help='whether to invert scores during pruning (default: False)')
    pruning_args.add_argument('--pruner-list', type=str, nargs='*', default=[],
                              help='list of pruning strategies for singleshot (default: [])')
    pruning_args.add_argument('--prune-epoch-list', type=int, nargs='*', default=[],
                              help='list of prune epochs for singleshot (default: [])')
    pruning_args.add_argument('--compression-list', type=float, nargs='*', default=[],
                              help='list of compression exponents; each target density is 0.8**compression (default: [])')
    pruning_args.add_argument('--level-list', type=int, nargs='*', default=[],
                              help='list of numbers of prune-train cycles (levels) for multishot (default: [])')
    pruning_args.add_argument('--ntksap_R', type=int, default=1,
                              help='number of sampling rounds for initial weight configurations (default: 1)')
    pruning_args.add_argument('--ntksap_epsilon', type=float, default=0.01,
                              help='perturbation hyperparameter used in NTK-SAP (default: 0.01)')
    ## Experiment Hyperparameters ##
    parser.add_argument('--experiment', type=str, default='singleshot',
                        choices=['singleshot','multishot','multishot_ddp_prune','multishot_ddp_train'],
                        help='experiment name; the ddp variants use multi-GPU distributed data parallel training (default: singleshot)')
    parser.add_argument('--expid', type=str, default='',
                        help='name used to save results (default: "")')
    parser.add_argument('--result-dir', type=str, default='Results/data',
                        help='path to directory to save results (default: "Results/data")')
    parser.add_argument('--gpu', type=str, default='all',
                        help='id(s) of GPU device(s) to use (default: all GPUs)')
    parser.add_argument('--workers', type=int, default=4,
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--no-cuda', action='store_true',
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=-1,
                        help='random seed; negative values draw a fresh seed (default: -1)')
    parser.add_argument('--verbose', action='store_true',
                        help='print statistics during training and testing')
    ## For distributed data parallel (DDP) training
    parser.add_argument('--dist-url', default='tcp://127.0.0.1:23456', type=str,
                        help='url used to set up distributed training; this should be '
                             'the IP address and an open port on the master node')
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=0, type=int,
                        help='node rank for distributed training')
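    # A sketch of a two-node DDP launch (the master address and port are
    # illustrative; each node runs one process):
    #   node 0: python main.py --experiment multishot_ddp_train --world-size 2 --rank 0 --dist-url tcp://192.168.0.1:23456
    #   node 1: python main.py --experiment multishot_ddp_train --world-size 2 --rank 1 --dist-url tcp://192.168.0.1:23456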
    args = parser.parse_args()

    # Construct random seed
    if args.seed is None or args.seed < 0:
        args.seed = random.randint(1, 100000)
    ## Construct Result Directory ##
    if args.expid == "":
        print("WARNING: this experiment is not being saved.")
        setattr(args, 'save', False)
    else:
        result_dir = '{}/{}/{}'.format(args.result_dir, args.experiment, args.expid)
        setattr(args, 'save', True)
        setattr(args, 'result_dir', result_dir)
        try:
            os.makedirs(result_dir)
        except FileExistsError:
            print('Overwriting existing results directory.')

    print('All arguments:\n')
    for arg_name in vars(args):
        print('{}: {}'.format(arg_name, getattr(args, arg_name)))
    ## Save Args ##
    if args.save:
        with open(args.result_dir + '/args.json', 'w') as f:
            json.dump(args.__dict__, f, sort_keys=True, indent=4)
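    # args.json records the full configuration of the run, so a finished
    # experiment can be inspected or reproduced later.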
    ## Run Experiment ##
    if args.experiment == 'singleshot':
        singleshot.run(args)
    elif args.experiment == 'multishot':
        multishot.run(args)
    elif args.experiment == 'multishot_ddp_prune':
        multishot_ddp_prune.run(args)
    elif args.experiment == 'multishot_ddp_train':
        multishot_ddp_train.run(args)