-
Notifications
You must be signed in to change notification settings - Fork 35
/
train.py
319 lines (257 loc) · 13.2 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
#import needed library
import os
import logging
import random
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
from utils import net_builder, get_logger, count_parameters
from train_utils import TBLog, get_SGD, get_cosine_schedule_with_warmup
from models.fixmatch.fixmatch import FixMatch
from datasets.ssl_dataset import SSL_Dataset
from datasets.data_utils import get_data_loader
def main(args):
    '''
    Validate run options, then launch training on this node.

    With --multiprocessing-distributed, ``mp.spawn`` starts one
    ``main_worker`` process per visible GPU on this node; otherwise
    ``main_worker`` runs once in the current process with ``args.gpu``.

    Args:
        args: parsed command-line namespace. Mutated in place:
            ``distributed`` and ``batch_size`` are always (re)computed;
            ``world_size`` is rewritten to the total process count for
            multiprocessing; ``rank`` defaults to 0 when no --world-size is
            given and --dist-url is not ``env://``.

    Raises:
        Exception: if the save directory already exists without --overwrite,
            if --resume is given without --load_path, if resuming would save
            onto the load path without --overwrite, or if no CUDA device is
            visible.
    '''
    save_path = os.path.join(args.save_dir, args.save_name)
    if os.path.exists(save_path) and not args.overwrite:
        raise Exception('already existing model: {}'.format(save_path))
    if args.resume:
        if args.load_path is None:
            raise Exception('Resume of training requires --load_path in the args')
        if os.path.abspath(save_path) == os.path.abspath(args.load_path) and not args.overwrite:
            # Implicit concatenation: a backslash continuation inside the
            # literal would embed the next line's indentation in the message.
            raise Exception('Saving & Loading pathes are same. '
                            'If you want over-write, give --overwrite in the argument.')

    if args.seed is not None:
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    # distributed: true if manually selected or if world_size > 1
    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()  # number of GPUs on this node
    if ngpus_per_node == 0:
        # Fail fast: otherwise mp.spawn(nprocs=0) silently does nothing and
        # main_worker would compute `rank % ngpus_per_node` with a zero divisor.
        raise Exception('ONLY GPU TRAINING IS SUPPORTED')

    # --world-size defaults to -1 ("not given"). Treat that as a single node
    # (node rank 0) so the batch size and process counts below stay positive.
    if args.world_size < 1:
        num_nodes = 1
        if args.rank == -1 and args.dist_url != "env://":
            args.rank = 0
    else:
        num_nodes = args.world_size

    # divide the batch_size according to the number of nodes
    args.batch_size = args.batch_size // num_nodes

    if args.multiprocessing_distributed:
        # from here on, args.world_size is the total number of processes
        # across all nodes (one per GPU)
        args.world_size = ngpus_per_node * num_nodes
        # args=(...) are the trailing arguments handed to main_worker
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)
def _place_model_on_devices(model, ngpus_per_node, args):
    '''
    Move FixMatch's train/eval networks onto GPU(s), wrapping for (D)DP.

    Mutates ``model.train_model`` / ``model.eval_model`` in place. In the
    per-GPU DDP case it also rescales ``args.batch_size`` from per-node to
    per-GPU. ``args.num_workers`` is NOT rescaled.
    '''
    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            # batch_size: batch_size per node -> batch_size per gpu
            args.batch_size = int(args.batch_size / ngpus_per_node)
            model.train_model.cuda(args.gpu)
            model.train_model = torch.nn.parallel.DistributedDataParallel(
                model.train_model, device_ids=[args.gpu])
            model.eval_model.cuda(args.gpu)
        else:
            # No device pinned: DDP spreads the batch over all visible GPUs
            # when device_ids is not set. Wrap the network rather than the
            # FixMatch wrapper so FixMatch's own methods (set_data_loader,
            # load_model, train, ...) remain callable afterwards.
            model.train_model.cuda()
            model.train_model = torch.nn.parallel.DistributedDataParallel(model.train_model)
            model.eval_model.cuda()
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model.train_model = model.train_model.cuda(args.gpu)
        model.eval_model = model.eval_model.cuda(args.gpu)
    else:
        model.train_model = torch.nn.DataParallel(model.train_model).cuda()
        model.eval_model = torch.nn.DataParallel(model.eval_model).cuda()


def _build_loader_dict(args):
    '''
    Build the labeled-train, unlabeled-train and eval DataLoaders.

    The unlabeled loader uses ``args.uratio`` times the labeled batch size
    and four times as many workers.
    '''
    train_dset = SSL_Dataset(name=args.dataset, train=True,
                             num_classes=args.num_classes, data_dir=args.data_dir)
    lb_dset, ulb_dset = train_dset.get_ssl_dset(args.num_labels)

    _eval_dset = SSL_Dataset(name=args.dataset, train=False,
                             num_classes=args.num_classes, data_dir=args.data_dir)
    eval_dset = _eval_dset.get_dset()

    return {
        'train_lb': get_data_loader(lb_dset,
                                    args.batch_size,
                                    data_sampler=args.train_sampler,
                                    num_iters=args.num_train_iter,
                                    num_workers=args.num_workers,
                                    distributed=args.distributed),
        'train_ulb': get_data_loader(ulb_dset,
                                     args.batch_size * args.uratio,
                                     data_sampler=args.train_sampler,
                                     num_iters=args.num_train_iter,
                                     num_workers=4 * args.num_workers,
                                     distributed=args.distributed),
        'eval': get_data_loader(eval_dset,
                                args.eval_batch_size,
                                num_workers=args.num_workers),
    }


def main_worker(gpu, ngpus_per_node, args):
    '''
    Build, place and train a FixMatch model in the current process.

    Args:
        gpu: GPU index for this process, or None to use every visible GPU.
        ngpus_per_node: number of CUDA devices visible on this node.
        args: parsed command-line namespace. Mutated in place: ``gpu``,
            possibly ``rank``, ``bn_momentum`` and possibly ``batch_size``.

    Raises:
        Exception: if CUDA is not available (only GPU training is supported).
        ValueError: if ``args.seed`` is None.
    '''
    args.gpu = gpu

    # Checked before anything divides by ngpus_per_node.
    if not torch.cuda.is_available():
        raise Exception('ONLY GPU TRAINING IS SUPPORTED')

    # The random seed must be set so labeled-data sampling stays
    # synchronized across processes.
    if args.seed is None:
        raise ValueError('main_worker requires args.seed to be set')
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    cudnn.deterministic = True

    # SET UP FOR DISTRIBUTED TRAINING
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu  # compute global rank
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # SET save_path and logger. The main process writes TensorBoard logs and
    # logs at INFO. A non-distributed run is always the main process, even
    # when rank is left at its -1 default.
    save_path = os.path.join(args.save_dir, args.save_name)
    is_main_process = (not args.distributed) or args.rank % ngpus_per_node == 0
    tb_log = TBLog(save_path, 'tensorboard') if is_main_process else None
    logger_level = "INFO" if is_main_process else "WARNING"
    logger = get_logger(args.save_name, save_path, logger_level)
    logger.warning(f"USE GPU: {args.gpu} for training")

    # SET FixMatch: class FixMatch in models.fixmatch
    args.bn_momentum = 1.0 - args.ema_m
    _net_builder = net_builder(args.net,
                               args.net_from_name,
                               {'depth': args.depth,
                                'widen_factor': args.widen_factor,
                                'leaky_slope': args.leaky_slope,
                                'bn_momentum': args.bn_momentum,
                                'dropRate': args.dropout})

    model = FixMatch(_net_builder,
                     args.num_classes,
                     args.ema_m,
                     args.T,
                     args.p_cutoff,
                     args.ulb_loss_ratio,
                     args.hard_label,
                     num_eval_iter=args.num_eval_iter,
                     tb_log=tb_log,
                     logger=logger)
    logger.info(f'Number of Trainable Params: {count_parameters(model.train_model)}')

    # SET Optimizer & LR Scheduler: SGD + cosine schedule, no warmup steps.
    optimizer = get_SGD(model.train_model, 'SGD', args.lr, args.momentum, args.weight_decay)
    scheduler = get_cosine_schedule_with_warmup(optimizer,
                                                args.num_train_iter,
                                                num_warmup_steps=0)
    model.set_optimizer(optimizer, scheduler)

    # SET Devices for (Distributed) DataParallel
    _place_model_on_devices(model, ngpus_per_node, args)

    logger.info(f"model_arch: {model}")
    logger.info(f"Arguments: {args}")

    # NOTE(review): benchmark=True lets cuDNN auto-tune kernel selection for
    # speed, which can make runs non-reproducible despite
    # cudnn.deterministic above.
    cudnn.benchmark = True

    # Construct Dataset & DataLoader (after device placement, which may have
    # rescaled args.batch_size) and hand them to FixMatch.
    model.set_data_loader(_build_loader_dict(args))

    # If args.resume, load checkpoints from args.load_path
    if args.resume:
        model.load_model(args.load_path)

    # START TRAINING of FixMatch
    for _ in range(args.epoch):
        model.train(args, logger=logger)
        # Under multiprocessing only one process per node saves the checkpoint.
        if not args.multiprocessing_distributed or args.rank % ngpus_per_node == 0:
            model.save_model('latest_model.pth', save_path)

    logging.warning(f"GPU {args.rank} training is FINISHED")
def str2bool(value):
    '''
    Parse a command-line boolean ("true"/"false", "yes"/"no", "1"/"0", ...).

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    "False", as True, so boolean options need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    '''
    import argparse
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'boolean value expected, got {value!r}')


def build_parser():
    '''Return the argparse parser holding every FixMatch training option.'''
    import argparse
    parser = argparse.ArgumentParser(description='')

    # Saving & loading of the model.
    parser.add_argument('--save_dir', type=str, default='./saved_models')
    parser.add_argument('--save_name', type=str, default='fixmatch')
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--load_path', type=str, default=None)
    parser.add_argument('--overwrite', action='store_true')

    # Training configuration of FixMatch.
    parser.add_argument('--epoch', type=int, default=1)
    parser.add_argument('--num_train_iter', type=int, default=2**20,
                        help='total number of training iterations')
    parser.add_argument('--num_eval_iter', type=int, default=10000,
                        help='evaluation frequency')
    parser.add_argument('--num_labels', type=int, default=4000)
    parser.add_argument('--batch_size', type=int, default=64,
                        help='total number of batch size of labeled data')
    parser.add_argument('--uratio', type=int, default=7,
                        help='the ratio of unlabeled data to labeld data in each mini-batch')
    parser.add_argument('--eval_batch_size', type=int, default=1024,
                        help='batch size of evaluation data loader (it does not affect the accuracy)')
    parser.add_argument('--hard_label', type=str2bool, default=True)
    parser.add_argument('--T', type=float, default=0.5)
    parser.add_argument('--p_cutoff', type=float, default=0.95)
    parser.add_argument('--ema_m', type=float, default=0.999, help='ema momentum for eval_model')
    parser.add_argument('--ulb_loss_ratio', type=float, default=1.0)

    # Optimizer configurations.
    parser.add_argument('--lr', type=float, default=0.03)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--amp', action='store_true', help='use mixed precision training or not')

    # Backbone net configurations.
    parser.add_argument('--net', type=str, default='WideResNet')
    parser.add_argument('--net_from_name', type=str2bool, default=False)
    parser.add_argument('--depth', type=int, default=28)
    parser.add_argument('--widen_factor', type=int, default=2)
    parser.add_argument('--leaky_slope', type=float, default=0.1)
    parser.add_argument('--dropout', type=float, default=0.0)

    # Data configurations.
    parser.add_argument('--data_dir', type=str, default='./data')
    parser.add_argument('--dataset', type=str, default='cifar10')
    parser.add_argument('--train_sampler', type=str, default='RandomSampler')
    parser.add_argument('--num_classes', type=int, default=10)
    parser.add_argument('--num_workers', type=int, default=1)

    # Multi-GPU & distributed training
    # (args from https://github.com/pytorch/examples/blob/master/imagenet/main.py)
    parser.add_argument('--world-size', default=-1, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=-1, type=int,
                        help='**node rank** for distributed training')
    parser.add_argument('--dist-url', default='tcp://127.0.0.1:10001', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--seed', default=0, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use.')
    parser.add_argument('--multiprocessing-distributed', action='store_true',
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs. This is the '
                             'fastest way to use PyTorch for either single node or '
                             'multi node data parallel training')
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(args)