base_model.py (forked from junyanz/pytorch-CycleGAN-and-pix2pix)
import os
import torch
from collections import OrderedDict
from . import networks


class BaseModel():
# modify parser to add command line options,
# and also change the default values if needed
@staticmethod
def modify_commandline_options(parser, is_train):
return parser
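
    # return the model's name (subclasses override this)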
def name(self):
return 'BaseModel'
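
    # store options, pick the device, set the checkpoint directory, and
    # initialize the name lists that subclasses fill in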
def initialize(self, opt):
self.opt = opt
self.gpu_ids = opt.gpu_ids
self.isTrain = opt.isTrain
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
if opt.resize_or_crop != 'scale_width':
torch.backends.cudnn.benchmark = True
self.loss_names = []
self.model_names = []
self.visual_names = []
self.image_paths = []
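
    # store the input data from the dataloader; subclasses override this to
    # unpack and preprocess it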
def set_input(self, input):
self.input = input
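
    # run the forward pass; implemented by subclasses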
def forward(self):
pass

    # load and print networks; create schedulers
def setup(self, opt, parser=None):
if self.isTrain:
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
if not self.isTrain or opt.continue_train:
self.load_networks(opt.epoch)
self.print_networks(opt.verbose)

    # switch models to eval mode at test time
def eval(self):
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
net.eval()

    # used in test time, wrapping `forward` in no_grad() so we don't save
# intermediate steps for backprop
def test(self):
with torch.no_grad():
self.forward()

    # get image paths
def get_image_paths(self):
return self.image_paths
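
    # calculate losses and gradients and update network weights; implemented
    # by subclasses and called in every training iteration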
def optimize_parameters(self):
pass

    # update learning rate (called once every epoch)
def update_learning_rate(self):
for scheduler in self.schedulers:
scheduler.step()
lr = self.optimizers[0].param_groups[0]['lr']
print('learning rate = %.7f' % lr)

    # return visualization images. train.py will display these images and save them to an HTML file
def get_current_visuals(self):
visual_ret = OrderedDict()
for name in self.visual_names:
if isinstance(name, str):
visual_ret[name] = getattr(self, name)
return visual_ret

    # return training losses/errors. train.py will print out these errors as debugging information
def get_current_losses(self):
errors_ret = OrderedDict()
for name in self.loss_names:
if isinstance(name, str):
# float(...) works for both scalar tensor and float number
errors_ret[name] = float(getattr(self, 'loss_' + name))
return errors_ret

    # save models to the disk
def save_networks(self, epoch):
for name in self.model_names:
if isinstance(name, str):
save_filename = '%s_net_%s.pth' % (epoch, name)
save_path = os.path.join(self.save_dir, save_filename)
net = getattr(self, 'net' + name)
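                # on GPU the net is wrapped in torch.nn.DataParallel, so save the
                # unwrapped module's weights from CPU, then move the net back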
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
torch.save(net.module.cpu().state_dict(), save_path)
net.cuda(self.gpu_ids[0])
else:
torch.save(net.cpu().state_dict(), save_path)
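
    # fix InstanceNorm checkpoints saved by PyTorch < 0.4: recursively follow
    # the dotted key path and drop running-stats entries that newer versions
    # of load_state_dict reject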
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
key = keys[i]
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'running_mean' or key == 'running_var'):
if getattr(module, key) is None:
state_dict.pop('.'.join(keys))
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'num_batches_tracked'):
state_dict.pop('.'.join(keys))
else:
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    # load models from the disk
def load_networks(self, epoch):
for name in self.model_names:
if isinstance(name, str):
load_filename = '%s_net_%s.pth' % (epoch, name)
load_path = os.path.join(self.save_dir, load_filename)
net = getattr(self, 'net' + name)
if isinstance(net, torch.nn.DataParallel):
net = net.module
print('loading the model from %s' % load_path)
# if you are using PyTorch newer than 0.4 (e.g., built from
# GitHub source), you can remove str() on self.device
state_dict = torch.load(load_path, map_location=str(self.device))
if hasattr(state_dict, '_metadata'):
del state_dict._metadata
# patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # copy the keys because the loop mutates state_dict
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
net.load_state_dict(state_dict)

    # print network information
def print_networks(self, verbose):
print('---------- Networks initialized -------------')
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
num_params = 0
for param in net.parameters():
num_params += param.numel()
if verbose:
print(net)
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
print('-----------------------------------------------')

    # set requires_grad=False for all given networks to avoid unnecessary computations
def set_requires_grad(self, nets, requires_grad=False):
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
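

# A minimal usage sketch, assuming a hypothetical subclass `MyModel` and the
# repo's TrainOptions/dataloader; it mirrors how train.py drives a model:
#
#     opt = TrainOptions().parse()
#     model = MyModel()
#     model.initialize(opt)
#     model.setup(opt)                     # schedulers + load/print networks
#     for epoch in range(n_epochs):        # n_epochs derived from opt
#         for data in dataset:
#             model.set_input(data)
#             model.optimize_parameters()  # forward, backward, optimizer steps
#         model.update_learning_rate()
#         model.save_networks(epoch)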