-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathtrain.py
691 lines (551 loc) · 27 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
import os
import sys
import time
import visdom
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data.dataset import Subset
from torch.optim.lr_scheduler import CosineAnnealingLR
from models.controller import Controller
from models.shared_cnn import SharedCNN
from utils.utils import AverageMeter, Logger
from utils.cutout import Cutout
parser = argparse.ArgumentParser(description='ENAS')
_add = parser.add_argument

# Search space, data location, output naming and general run settings.
_add('--search_for', default='macro', choices=['macro'])
_add('--data_path', default='/export/mlrg/terrance/Projects/data/', type=str)
_add('--output_filename', default='ENAS', type=str)
_add('--resume', default='', type=str)
_add('--batch_size', type=int, default=128)
_add('--num_epochs', type=int, default=310)
_add('--log_every', type=int, default=50)
_add('--eval_every_epochs', type=int, default=1)
_add('--seed', type=int, default=0)
_add('--cutout', type=int, default=0)
_add('--fixed_arc', action='store_true', default=False)

# Child network (shared_cnn / fixed_cnn) hyper-parameters.
_add('--child_num_layers', type=int, default=12)
_add('--child_out_filters', type=int, default=36)
_add('--child_grad_bound', type=float, default=5.0)
_add('--child_l2_reg', type=float, default=0.00025)
_add('--child_num_branches', type=int, default=6)
_add('--child_keep_prob', type=float, default=0.9)
_add('--child_lr_max', type=float, default=0.05)
_add('--child_lr_min', type=float, default=0.0005)
_add('--child_lr_T', type=float, default=10)

# Controller hyper-parameters.
_add('--controller_lstm_size', type=int, default=64)
_add('--controller_lstm_num_layers', type=int, default=1)
_add('--controller_entropy_weight', type=float, default=0.0001)
_add('--controller_train_every', type=int, default=1)
_add('--controller_num_aggregate', type=int, default=20)
_add('--controller_train_steps', type=int, default=50)
_add('--controller_lr', type=float, default=0.001)
_add('--controller_tanh_constant', type=float, default=1.5)
_add('--controller_op_tanh_reduce', type=float, default=2.5)
_add('--controller_skip_target', type=float, default=0.4)
_add('--controller_skip_weight', type=float, default=0.8)
_add('--controller_bl_dec', type=float, default=0.99)
args = parser.parse_args()

# Visdom dashboard for this run. Every window handle starts as None and is
# replaced by the value returned from the first vis.line call for that plot.
vis = visdom.Visdom()
vis.env = 'ENAS_' + args.output_filename
vis_win = dict.fromkeys(['shared_cnn_acc', 'shared_cnn_loss', 'controller_reward',
                         'controller_acc', 'controller_loss'])
def load_datasets():
    """Build the CIFAR-10 data loaders used throughout this script.

    The official 50k-image training set is split by index into a 45k-image
    training subset and a 5k-image validation subset; the full training set and
    the 10k-image test set are exposed as well. Training images get random crop
    plus horizontal flip (and Cutout when ``args.cutout > 0``); the validation
    subset also uses crop plus flip, while test images are only converted to
    tensors and normalized.

    Returns:
        Dict mapping 'train_subset', 'valid_subset', 'train_dataset' and
        'test_dataset' to their DataLoader.
    """
    normalize = transforms.Normalize(
        mean=[channel / 255.0 for channel in [125.3, 123.0, 113.9]],
        std=[channel / 255.0 for channel in [63.0, 62.1, 66.7]])

    def crop_flip_normalize():
        # A fresh list on every call, so extending one pipeline never touches another.
        return [transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize]

    train_steps = crop_flip_normalize()
    if args.cutout > 0:
        train_steps.append(Cutout(length=args.cutout))
    train_transform = transforms.Compose(train_steps)
    valid_transform = transforms.Compose(crop_flip_normalize())
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    def cifar10(is_train, transform):
        return datasets.CIFAR10(root=args.data_path,
                                train=is_train,
                                transform=transform,
                                download=True)

    train_dataset = cifar10(True, train_transform)
    valid_dataset = cifar10(True, valid_transform)
    test_dataset = cifar10(False, test_transform)

    def make_loader(dataset, shuffle, **extra):
        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=args.batch_size,
                                           shuffle=shuffle,
                                           pin_memory=True,
                                           num_workers=2,
                                           **extra)

    train_subset = Subset(train_dataset, list(range(0, 45000)))
    valid_subset = Subset(valid_dataset, list(range(45000, 50000)))
    return {
        'train_subset': make_loader(train_subset, True),
        'valid_subset': make_loader(valid_subset, True, drop_last=True),
        'train_dataset': make_loader(train_dataset, True),
        'test_dataset': make_loader(test_dataset, False),
    }
def train_shared_cnn(epoch,
                     controller,
                     shared_cnn,
                     data_loaders,
                     shared_cnn_optimizer,
                     fixed_arc=None):
    """Train shared_cnn for one epoch, sampling architectures from the controller.

    Args:
        epoch: Current epoch.
        controller: Controller module that generates architectures to be trained.
        shared_cnn: CNN that contains all possible architectures, with shared weights.
        data_loaders: Dict containing data loaders.
        shared_cnn_optimizer: Optimizer for the shared_cnn.
        fixed_arc: Architecture to train. When given it overrides the controller
            sample, and the full training set is used instead of the 45k subset.

    Returns: Nothing.

    The controller is switched to eval mode while it samples architectures and
    back to train mode before returning. The mode of shared_cnn is left to the
    caller.
    """
    global vis_win

    controller.eval()

    if fixed_arc is None:
        # Use a subset of the training set when searching for an architecture
        train_loader = data_loaders['train_subset']
    else:
        # Use the full training set when training a fixed architecture
        train_loader = data_loaders['train_dataset']

    train_acc_meter = AverageMeter()
    loss_meter = AverageMeter()
    for i, (images, labels) in enumerate(train_loader):
        start = time.time()
        images = images.cuda()
        labels = labels.cuda()

        if fixed_arc is None:
            with torch.no_grad():
                controller()  # perform forward pass to generate a new architecture
            sample_arc = controller.sample_arc
        else:
            sample_arc = fixed_arc

        shared_cnn.zero_grad()
        pred = shared_cnn(images, sample_arc)
        loss = nn.CrossEntropyLoss()(pred, labels)
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), args.child_grad_bound)
        shared_cnn_optimizer.step()

        train_acc = torch.mean((torch.max(pred, 1)[1] == labels).type(torch.float))
        train_acc_meter.update(train_acc.item())
        loss_meter.update(loss.item())

        end = time.time()

        if i % args.log_every == 0:
            learning_rate = shared_cnn_optimizer.param_groups[0]['lr']
            display = 'epoch=' + str(epoch) + \
                      '\tch_step=' + str(i) + \
                      '\tloss=%.6f' % (loss_meter.val) + \
                      '\tlr=%.4f' % (learning_rate) + \
                      '\t|g|=%.4f' % (grad_norm.item()) + \
                      '\tacc=%.4f' % (train_acc_meter.val) + \
                      '\ttime=%.2fit/s' % (1. / (end - start))
            print(display)

    # One point per epoch. Append only once the window exists. Keying this on
    # "epoch > 0" broke after --resume: the first resumed epoch is > 0 while
    # vis_win still holds None, so it tried to append to a window never created.
    vis_win['shared_cnn_acc'] = vis.line(
        X=np.array([epoch]),
        Y=np.array([train_acc_meter.avg]),
        win=vis_win['shared_cnn_acc'],
        opts=dict(title='shared_cnn_acc', xlabel='Epoch', ylabel='Accuracy'),
        update='append' if vis_win['shared_cnn_acc'] is not None else None)

    vis_win['shared_cnn_loss'] = vis.line(
        X=np.array([epoch]),
        Y=np.array([loss_meter.avg]),
        win=vis_win['shared_cnn_loss'],
        opts=dict(title='shared_cnn_loss', xlabel='Epoch', ylabel='Loss'),
        update='append' if vis_win['shared_cnn_loss'] is not None else None)

    controller.train()
def train_controller(epoch,
                     controller,
                     shared_cnn,
                     data_loaders,
                     controller_optimizer,
                     baseline=None):
    """Train the controller to maximize validation accuracy using REINFORCE.

    Args:
        epoch: Current epoch.
        controller: Controller module that generates architectures to be trained.
        shared_cnn: CNN that contains all possible architectures, with shared weights.
        data_loaders: Dict containing data loaders.
        controller_optimizer: Optimizer for the controller.
        baseline: The REINFORCE baseline carried over from the previous epoch, or
            None to initialise it from the first sampled validation accuracy.

    Returns:
        baseline: The updated baseline, i.e. an exponential moving average
            (decay args.controller_bl_dec) of the reward, detached from the graph.

    For more stable training we perform weight updates using the average of
    many gradient estimates. controller_num_aggregate indicates how many samples
    we want to average over (default = 20). By default PyTorch will sum gradients
    each time .backward() is called (as long as an optimizer step is not taken),
    so each iteration we divide the loss by controller_num_aggregate to get the
    average.

    https://github.com/melodyguan/enas/blob/master/src/cifar10/general_controller.py#L270
    """
    print('Epoch ' + str(epoch) + ': Training controller')

    global vis_win

    shared_cnn.eval()
    valid_loader = data_loaders['valid_subset']
    # Keep one iterator over the validation subset and restart it when it runs
    # out. Calling iter(valid_loader) on every step re-shuffles the data and
    # restarts the loader's worker processes each time, just to read one batch.
    valid_iter = iter(valid_loader)

    reward_meter = AverageMeter()
    baseline_meter = AverageMeter()
    val_acc_meter = AverageMeter()
    loss_meter = AverageMeter()

    controller.zero_grad()
    for i in range(args.controller_train_steps * args.controller_num_aggregate):
        start = time.time()
        try:
            images, labels = next(valid_iter)
        except StopIteration:
            valid_iter = iter(valid_loader)
            images, labels = next(valid_iter)
        images = images.cuda()
        labels = labels.cuda()

        controller()  # perform forward pass to generate a new architecture
        sample_arc = controller.sample_arc

        with torch.no_grad():
            pred = shared_cnn(images, sample_arc)
        val_acc = torch.mean((torch.max(pred, 1)[1] == labels).type(torch.float))

        # The accuracy term carries no gradient (no_grad forward pass, detached).
        # The entropy bonus is added out of place, so gradients still reach the
        # controller through sample_entropy exactly as with the former in-place add.
        reward = val_acc.detach() + args.controller_entropy_weight * controller.sample_entropy

        if baseline is None:
            baseline = val_acc
        else:
            baseline = baseline - (1 - args.controller_bl_dec) * (baseline - reward)
        # detach to make sure that gradients are not backpropped through the baseline
        baseline = baseline.detach()

        loss = -1 * controller.sample_log_prob * (reward - baseline)

        if args.controller_skip_weight is not None:
            loss += args.controller_skip_weight * controller.skip_penaltys

        reward_meter.update(reward.item())
        baseline_meter.update(baseline.item())
        val_acc_meter.update(val_acc.item())
        loss_meter.update(loss.item())

        # Average gradient over controller_num_aggregate samples
        loss = loss / args.controller_num_aggregate

        # retain_graph kept from the original implementation.
        # NOTE(review): presumably some controller tensors are shared between
        # sampled graphs; confirm whether it is still required.
        loss.backward(retain_graph=True)

        end = time.time()

        # Aggregate gradients for controller_num_aggregate iterations, then update weights
        if (i + 1) % args.controller_num_aggregate == 0:
            # NOTE: the controller reuses the child's gradient-clipping bound.
            grad_norm = torch.nn.utils.clip_grad_norm_(controller.parameters(), args.child_grad_bound)
            controller_optimizer.step()
            controller.zero_grad()

            # Logging only happens on update steps, so grad_norm is always defined here.
            if (i + 1) % (2 * args.controller_num_aggregate) == 0:
                learning_rate = controller_optimizer.param_groups[0]['lr']
                display = 'ctrl_step=' + str(i // args.controller_num_aggregate) + \
                          '\tloss=%.3f' % (loss_meter.val) + \
                          '\tent=%.2f' % (controller.sample_entropy.item()) + \
                          '\tlr=%.4f' % (learning_rate) + \
                          '\t|g|=%.4f' % (grad_norm.item()) + \
                          '\tacc=%.4f' % (val_acc_meter.val) + \
                          '\tbl=%.2f' % (baseline_meter.val) + \
                          '\ttime=%.2fit/s' % (1. / (end - start))
                print(display)

    # Append only to windows that already exist. This also covers the first
    # epoch after --resume, when epoch > 0 but no window has been created yet.
    vis_win['controller_reward'] = vis.line(
        X=np.column_stack([epoch] * 2),
        Y=np.column_stack([reward_meter.avg, baseline_meter.avg]),
        win=vis_win['controller_reward'],
        opts=dict(title='controller_reward', xlabel='Epoch', ylabel='Reward'),
        update='append' if vis_win['controller_reward'] is not None else None)

    vis_win['controller_acc'] = vis.line(
        X=np.array([epoch]),
        Y=np.array([val_acc_meter.avg]),
        win=vis_win['controller_acc'],
        opts=dict(title='controller_acc', xlabel='Epoch', ylabel='Accuracy'),
        update='append' if vis_win['controller_acc'] is not None else None)

    vis_win['controller_loss'] = vis.line(
        X=np.array([epoch]),
        Y=np.array([loss_meter.avg]),
        win=vis_win['controller_loss'],
        opts=dict(title='controller_loss', xlabel='Epoch', ylabel='Loss'),
        update='append' if vis_win['controller_loss'] is not None else None)

    shared_cnn.train()
    return baseline
def evaluate_model(epoch, controller, shared_cnn, data_loaders, n_samples=10):
    """Print the validation and test accuracy for a controller and shared_cnn.

    Samples n_samples architectures and picks the one with the best accuracy on
    a single validation minibatch. It then reports that architecture's accuracy
    over the whole validation subset and the test set.

    Args:
        epoch: Current epoch.
        controller: Controller module that generates architectures to be trained.
        shared_cnn: CNN that contains all possible architectures, with shared weights.
        data_loaders: Dict containing data loaders.
        n_samples: Number of architectures to test when looking for the best one.

    Returns: Nothing. Both modules are left in train mode.
    """
    controller.eval()
    shared_cnn.eval()

    print('Here are ' + str(n_samples) + ' architectures:')
    best_arc, _ = get_best_arc(controller, shared_cnn, data_loaders, n_samples, verbose=True)

    # get_best_arc may switch both modules back to train mode before returning.
    # Re-enter eval mode so the accuracies below are measured with inference
    # behaviour (e.g. dropout, given SharedCNN's keep_prob) rather than training
    # behaviour.
    controller.eval()
    shared_cnn.eval()

    valid_loader = data_loaders['valid_subset']
    test_loader = data_loaders['test_dataset']

    valid_acc = get_eval_accuracy(valid_loader, shared_cnn, best_arc)
    test_acc = get_eval_accuracy(test_loader, shared_cnn, best_arc)

    print('Epoch ' + str(epoch) + ': Eval')
    print('valid_accuracy: %.4f' % (valid_acc))
    print('test_accuracy: %.4f' % (test_acc))

    controller.train()
    shared_cnn.train()
def get_best_arc(controller, shared_cnn, data_loaders, n_samples=10, verbose=False):
    """Evaluate several architectures and return the best performing one.

    Args:
        controller: Controller module that generates architectures to be trained.
        shared_cnn: CNN that contains all possible architectures, with shared weights.
        data_loaders: Dict containing data loaders.
        n_samples: Number of architectures to test when looking for the best one.
        verbose: If True, display the architecture and resulting validation accuracy.

    Returns:
        best_arc: The best performing architecture.
        best_val_acc: Accuracy achieved on the best performing architecture.

    Raises:
        ValueError: If n_samples < 1.

    All architectures are evaluated on the same minibatch from the validation set.
    Both modules run in eval mode during the search and are then restored to the
    training/eval mode they were in when this function was called.
    """
    if n_samples < 1:
        raise ValueError('n_samples must be at least 1, got %r' % (n_samples,))

    # Remember the callers' modes instead of unconditionally switching both modules
    # to train mode on exit. That used to make callers that had set eval mode
    # (e.g. evaluate_model) silently evaluate in train mode afterwards.
    controller_was_training = controller.training
    shared_cnn_was_training = shared_cnn.training
    controller.eval()
    shared_cnn.eval()
    try:
        valid_loader = data_loaders['valid_subset']
        images, labels = next(iter(valid_loader))
        images = images.cuda()
        labels = labels.cuda()

        arcs = []
        val_accs = []
        for _ in range(n_samples):
            with torch.no_grad():
                controller()  # perform forward pass to generate a new architecture
            sample_arc = controller.sample_arc
            arcs.append(sample_arc)

            with torch.no_grad():
                pred = shared_cnn(images, sample_arc)
            val_acc = torch.mean((torch.max(pred, 1)[1] == labels).type(torch.float))
            val_accs.append(val_acc.item())

            if verbose:
                print_arc(sample_arc)
                print('val_acc=' + str(val_acc.item()))
                print('-' * 80)
    finally:
        controller.train(controller_was_training)
        shared_cnn.train(shared_cnn_was_training)

    best_iter = np.argmax(val_accs)
    best_arc = arcs[best_iter]
    best_val_acc = val_accs[best_iter]
    return best_arc, best_val_acc
def get_eval_accuracy(loader, shared_cnn, sample_arc):
    """Compute the accuracy of one architecture over every batch of a loader.

    Args:
        loader: A single data loader yielding (images, labels) batches.
        shared_cnn: CNN that contains all possible architectures, with shared weights.
        sample_arc: The architecture to use for the evaluation.

    Returns:
        acc: Fraction of correctly classified examples, as a Python float.

    The module's train/eval mode is not changed here; callers set it.
    """
    n_correct = 0.
    n_seen = 0.
    for batch_images, batch_labels in loader:
        batch_images, batch_labels = batch_images.cuda(), batch_labels.cuda()
        with torch.no_grad():
            logits = shared_cnn(batch_images, sample_arc)
        predicted = torch.max(logits, 1)[1]
        n_correct += torch.sum((predicted == batch_labels).type(torch.float))
        n_seen += logits.shape[0]
    return (n_correct / n_seen).item()
def print_arc(sample_arc):
    """Print a sampled architecture, one bracketed line per entry.

    Args:
        sample_arc: Dict whose values are sequences of tensors. Element 0 holds
            the branch id(s). If the value has more than one element, element 1
            holds the skip entries, which are printed after the branch id(s).

    Returns: Nothing.
    """
    for layer_spec in sample_arc.values():
        has_skips = len(layer_spec) != 1
        tokens = layer_spec[0].cpu().numpy().tolist()
        if has_skips:
            tokens = tokens + layer_spec[1].cpu().numpy().tolist()
        print('[' + ' '.join(map(str, tokens)) + ']')
def train_enas(start_epoch,
               controller,
               shared_cnn,
               data_loaders,
               shared_cnn_optimizer,
               controller_optimizer,
               shared_cnn_scheduler):
    """Perform architecture search by alternately training shared_cnn and the controller.

    Each epoch does the following:
      1. Train shared_cnn on architectures sampled from the controller.
      2. Train the controller with REINFORCE, carrying its baseline from epoch to epoch.
      3. Evaluate every args.eval_every_epochs epochs.
      4. Step the shared_cnn learning-rate schedule.
      5. Overwrite the checkpoint at checkpoints/<output_filename>.pth.tar.

    Args:
        start_epoch: Epoch to begin on.
        controller: Controller module that generates architectures to be trained.
        shared_cnn: CNN that contains all possible architectures, with shared weights.
        data_loaders: Dict containing data loaders.
        shared_cnn_optimizer: Optimizer for the shared_cnn.
        controller_optimizer: Optimizer for the controller.
        shared_cnn_scheduler: Learning rate scheduler for shared_cnn_optimizer.

    Returns: Nothing.

    The controller baseline is not saved in the checkpoint, so it restarts from
    None whenever this function is (re)started.
    """
    filename = 'checkpoints/' + args.output_filename + '.pth.tar'
    # Make sure the checkpoint directory exists before the first save. Otherwise
    # torch.save fails after an epoch of training has already been spent.
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    baseline = None
    for epoch in range(start_epoch, args.num_epochs):

        train_shared_cnn(epoch,
                         controller,
                         shared_cnn,
                         data_loaders,
                         shared_cnn_optimizer)

        baseline = train_controller(epoch,
                                    controller,
                                    shared_cnn,
                                    data_loaders,
                                    controller_optimizer,
                                    baseline)

        if epoch % args.eval_every_epochs == 0:
            evaluate_model(epoch, controller, shared_cnn, data_loaders)

        # The epoch is passed explicitly. The scheduler itself is not part of the
        # checkpoint below, so after a resume this argument is its only source of
        # position.
        shared_cnn_scheduler.step(epoch)

        state = {'epoch': epoch + 1,
                 'args': args,
                 'shared_cnn_state_dict': shared_cnn.state_dict(),
                 'controller_state_dict': controller.state_dict(),
                 'shared_cnn_optimizer': shared_cnn_optimizer.state_dict(),
                 'controller_optimizer': controller_optimizer.state_dict()}
        torch.save(state, filename)
def train_fixed(start_epoch,
                controller,
                shared_cnn,
                data_loaders):
    """Train a fixed cnn architecture.

    Given a fully trained controller and shared_cnn, we sample 100 architectures
    (scored on one validation minibatch via get_best_arc). We then train a new cnn
    from scratch using the best architecture found. We change the number of
    filters in the new cnn such that the final layer has 512 channels.

    Args:
        start_epoch: Unused. The fixed cnn is always trained for epochs
            0 .. args.num_epochs - 1. (main() passes the epoch stored in the
            search checkpoint here, which does not apply to the new network.)
        controller: Controller module that generates architectures to be trained.
        shared_cnn: CNN that contains all possible architectures, with shared weights.
        data_loaders: Dict containing data loaders.

    Returns: Nothing.

    A checkpoint of the fixed cnn is written to
    checkpoints/<output_filename>_fixed.pth.tar after every epoch.
    """
    best_arc, best_val_acc = get_best_arc(controller, shared_cnn, data_loaders, n_samples=100, verbose=False)
    print('Best architecture:')
    print_arc(best_arc)
    print('Validation accuracy: ' + str(best_val_acc))

    fixed_cnn = SharedCNN(num_layers=args.child_num_layers,
                          num_branches=args.child_num_branches,
                          out_filters=512 // 4,  # instead of args.child_out_filters
                          keep_prob=args.child_keep_prob,
                          fixed_arc=best_arc)
    fixed_cnn = fixed_cnn.cuda()

    fixed_cnn_optimizer = torch.optim.SGD(params=fixed_cnn.parameters(),
                                          lr=args.child_lr_max,
                                          momentum=0.9,
                                          nesterov=True,
                                          weight_decay=args.child_l2_reg)

    fixed_cnn_scheduler = CosineAnnealingLR(optimizer=fixed_cnn_optimizer,
                                            T_max=args.child_lr_T,
                                            eta_min=args.child_lr_min)

    test_loader = data_loaders['test_dataset']

    filename = 'checkpoints/' + args.output_filename + '_fixed.pth.tar'
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    for epoch in range(args.num_epochs):

        train_shared_cnn(epoch,
                         controller,  # not actually used in training the fixed_cnn
                         fixed_cnn,
                         data_loaders,
                         fixed_cnn_optimizer,
                         best_arc)

        if epoch % args.eval_every_epochs == 0:
            # Measure test accuracy in eval mode, then go back to training.
            fixed_cnn.eval()
            test_acc = get_eval_accuracy(test_loader, fixed_cnn, best_arc)
            fixed_cnn.train()
            print('Epoch ' + str(epoch) + ': Eval')
            print('test_accuracy: %.4f' % (test_acc))

        fixed_cnn_scheduler.step(epoch)

        # Save the network that is actually being trained here (previously the
        # search-phase shared_cnn was saved under this key by mistake).
        state = {'epoch': epoch + 1,
                 'args': args,
                 'best_arc': best_arc,
                 'fixed_cnn_state_dict': fixed_cnn.state_dict(),
                 'fixed_cnn_optimizer': fixed_cnn_optimizer.state_dict()}
        torch.save(state, filename)
def main():
    """Run the ENAS architecture search, or train a fixed architecture with --fixed_arc.

    Steps:
      1. Seed the RNGs and redirect stdout to a log file under logs/.
      2. Build the data loaders, the controller, shared_cnn and their optimizers.
      3. Optionally restore a checkpoint given by --resume.
      4. Dispatch to train_enas, or to train_fixed when --fixed_arc is set.

    Raises:
        ValueError: If --fixed_arc is set without --resume, or if the --resume
            path does not exist.
    """
    # Seed the CPU generator too: the models below are constructed on the CPU
    # before being moved to the GPU, and DataLoader shuffling uses it by default.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    log_path = 'logs/' + args.output_filename + ('_fixed.log' if args.fixed_arc else '.log')
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    sys.stdout = Logger(filename=log_path)

    print(args)

    # Validate CLI arguments before downloading data or building models
    # (and with an exception that is not stripped under `python -O`).
    if args.fixed_arc and args.resume == '':
        raise ValueError('A pretrained model should be used when training a fixed architecture.')

    data_loaders = load_datasets()

    controller = Controller(search_for=args.search_for,
                            search_whole_channels=True,
                            num_layers=args.child_num_layers,
                            num_branches=args.child_num_branches,
                            out_filters=args.child_out_filters,
                            lstm_size=args.controller_lstm_size,
                            lstm_num_layers=args.controller_lstm_num_layers,
                            tanh_constant=args.controller_tanh_constant,
                            temperature=None,
                            skip_target=args.controller_skip_target,
                            skip_weight=args.controller_skip_weight)
    controller = controller.cuda()

    shared_cnn = SharedCNN(num_layers=args.child_num_layers,
                           num_branches=args.child_num_branches,
                           out_filters=args.child_out_filters,
                           keep_prob=args.child_keep_prob)
    shared_cnn = shared_cnn.cuda()

    # https://github.com/melodyguan/enas/blob/master/src/utils.py#L218
    controller_optimizer = torch.optim.Adam(params=controller.parameters(),
                                            lr=args.controller_lr,
                                            betas=(0.0, 0.999),
                                            eps=1e-3)

    # https://github.com/melodyguan/enas/blob/master/src/utils.py#L213
    shared_cnn_optimizer = torch.optim.SGD(params=shared_cnn.parameters(),
                                           lr=args.child_lr_max,
                                           momentum=0.9,
                                           nesterov=True,
                                           weight_decay=args.child_l2_reg)

    # https://github.com/melodyguan/enas/blob/master/src/utils.py#L154
    shared_cnn_scheduler = CosineAnnealingLR(optimizer=shared_cnn_optimizer,
                                             T_max=args.child_lr_T,
                                             eta_min=args.child_lr_min)

    if args.resume:
        if not os.path.isfile(args.resume):
            raise ValueError("No checkpoint found at '{}'".format(args.resume))
        print("Loading checkpoint '{}'".format(args.resume))
        try:
            # Checkpoints written by this script pickle `args` (an argparse.Namespace).
            # PyTorch releases that default torch.load to weights_only=True reject
            # that, so full unpickling is requested explicitly.
            # Only resume from checkpoints you trust.
            checkpoint = torch.load(args.resume, weights_only=False)
        except TypeError:
            # Older PyTorch releases do not accept the weights_only argument.
            checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        shared_cnn.load_state_dict(checkpoint['shared_cnn_state_dict'])
        controller.load_state_dict(checkpoint['controller_state_dict'])
        shared_cnn_optimizer.load_state_dict(checkpoint['shared_cnn_optimizer'])
        controller_optimizer.load_state_dict(checkpoint['controller_optimizer'])
        # shared_cnn_scheduler was built around this same optimizer object, so it
        # already sees the restored optimizer state; no re-binding is needed.
        print("Loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
    else:
        start_epoch = 0

    if not args.fixed_arc:
        train_enas(start_epoch,
                   controller,
                   shared_cnn,
                   data_loaders,
                   shared_cnn_optimizer,
                   controller_optimizer,
                   shared_cnn_scheduler)
    else:
        train_fixed(start_epoch,
                    controller,
                    shared_cnn,
                    data_loaders)


if __name__ == "__main__":
    main()