-
Notifications
You must be signed in to change notification settings - Fork 40
/
Copy pathmain.py
94 lines (68 loc) · 2.86 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import matlab.engine
import os
import logging
import argparse
import numpy as np
from train_and_evaluate import evaluate, train
from net import Generator
import utils
import torch
# Start a MATLAB engine. The handle `eng` is passed to train() and evaluate()
# in the __main__ block, which presumably run the MATLAB-side solvers through it.
eng = matlab.engine.start_matlab()
# Put the RCWA package (reticolo) and the project-local `solvers` folder, with
# all their subfolders (genpath), on the MATLAB path. The reticolo location
# can be overridden through the RETICOLO_PATH environment variable. The
# default is the original hard-coded location, so existing setups keep working.
RETICOLO_PATH = os.environ.get('RETICOLO_PATH',
                               '/home/users/jiangjq/Desktop/reticolo_allege')
eng.addpath(eng.genpath(RETICOLO_PATH))
# Relative path; resolved by MATLAB (presumably against its working folder).
eng.addpath(eng.genpath('solvers'))
# Command-line interface. When --wavelength / --angle are given, the __main__
# block converts them with int() and uses them in place of the values loaded
# from <output_dir>/Params.json.
parser = argparse.ArgumentParser()
parser.add_argument('--output_dir', default='results',
                    help="Results folder")
parser.add_argument('--wavelength', default=None,
                    help="Optional, overrides params.wavelength from Params.json (parsed as int)")
parser.add_argument('--angle', default=None,
                    help="Optional, overrides params.angle from Params.json (parsed as int)")
parser.add_argument('--restore_from', default=None,
                    help="Optional, directory or file containing weights to reload before training")
if __name__ == '__main__':
    # Parse the command-line arguments.
    args = parser.parse_args()

    # Log to <output_dir>/train.log.
    utils.set_logger(os.path.join(args.output_dir, 'train.log'))

    # Load parameters from <output_dir>/Params.json. Use an explicit check
    # rather than `assert`, which is stripped when Python runs with -O.
    json_path = os.path.join(args.output_dir, 'Params.json')
    if not os.path.isfile(json_path):
        raise FileNotFoundError("No json file found at {}".format(json_path))
    params = utils.Params(json_path)

    # Attach runtime settings to params and coerce integer-valued fields
    # (presumably they may be stored as floats/strings in the JSON file).
    params.output_dir = args.output_dir
    params.cuda = torch.cuda.is_available()
    params.restore_from = args.restore_from
    params.numIter = int(params.numIter)
    params.noise_dims = int(params.noise_dims)
    params.gkernlen = int(params.gkernlen)
    params.step_size = int(params.step_size)
    # Command-line values, when given, replace the ones from Params.json.
    if args.wavelength is not None:
        params.wavelength = int(args.wavelength)
    if args.angle is not None:
        params.angle = int(args.angle)

    # Create the output directory tree (no error if it already exists).
    for subdir in ('outputs',
                   'model',
                   os.path.join('figures', 'histogram'),
                   os.path.join('figures', 'deviceSamples')):
        os.makedirs(os.path.join(args.output_dir, subdir), exist_ok=True)

    # Build the generator and move it to the GPU when one is available.
    generator = Generator(params)
    if params.cuda:
        generator.cuda()

    # Adam optimizer over the generator's parameters.
    optimizer = torch.optim.Adam(generator.parameters(), lr=params.lr,
                                 betas=(params.beta1, params.beta2))
    # StepLR: multiply the learning rate by `gamma` every `step_size`
    # scheduler steps.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=params.step_size, gamma=params.gamma)

    # Optionally reload saved state. The generator, optimizer and scheduler
    # are handed to utils.load_checkpoint (presumably it restores their
    # state). Must use args.restore_from: a bare `restore_from` is undefined.
    if args.restore_from is not None:
        params.checkpoint = utils.load_checkpoint(
            args.restore_from, generator, optimizer, scheduler)
        logging.info('Model data loaded')

    # numIter == 0 skips training and goes straight to device generation.
    if params.numIter != 0:
        logging.info('Start training')
        train(generator, optimizer, scheduler, eng, params)

    # Generate 500 devices with the generator and save them.
    logging.info('Start generating devices')
    evaluate(generator, eng, numImgs=500, params=params)