-
Notifications
You must be signed in to change notification settings - Fork 5
/
options.py
143 lines (121 loc) · 7.65 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
""" Options

Command-line option definitions used for training and testing.
This script is largely based on junyanz/pytorch-CycleGAN-and-pix2pix.

Provides:
    Options: Class containing the argparse parser
"""
import argparse
import os
import torch
# pylint: disable=C0103,C0301,R0903,W0622
class Options():
    """Options class

    Declares every command-line argument (base and training) on an
    ``argparse.ArgumentParser`` and, through ``parse``, turns the command
    line into a namespace, normalises ``gpu_ids`` into a list of ints,
    creates the experiment folders and stores the options in ``opt.txt``.

    Returns:
        [argparse]: argparse containing train and test options
    """

    def __init__(self):
        """Create the parser and register all base and training arguments."""
        ##
        #
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

        ##
        # Base
        self.parser.add_argument('--dataset', default='cifar10', help='folder | cifar10 | mnist ')
        self.parser.add_argument('--dataroot', default='', help='path to dataset')
        self.parser.add_argument('--path', default='', help='path to the folder or image to be predicted.')
        self.parser.add_argument('--batchsize', type=int, default=1, help='input batch size')
        self.parser.add_argument('--workers', type=int, help='number of data loading workers', default=8)
        # NOTE(review): store_true combined with default=True means this flag
        # can never be switched off from the command line.
        self.parser.add_argument('--droplast', action='store_true', default=True, help='Drop last batch size.')
        self.parser.add_argument('--isize', type=int, default=64, help='input image size.')
        self.parser.add_argument('--nc', type=int, default=3, help='input image channels')
        self.parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
        self.parser.add_argument('--ngf', type=int, default=64)
        self.parser.add_argument('--ndf', type=int, default=64)
        self.parser.add_argument('--extralayers', type=int, default=0, help='Number of extra layers on gen and disc')
        self.parser.add_argument('--device', type=str, default='gpu', help='Device: gpu | cpu')
        self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
        self.parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
        self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment')
        self.parser.add_argument('--model', type=str, default='skipganomaly',
                                 help='chooses which model to use. skipganomaly')
        self.parser.add_argument('--display_server', type=str, default="http://localhost",
                                 help='visdom server of the web display')
        self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
        self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
        self.parser.add_argument('--display', action='store_true', help='Use visdom.')
        self.parser.add_argument('--verbose', action='store_true', help='Print the training and model details.')
        self.parser.add_argument('--outf', default='./output', help='folder to output images and model checkpoints')
        self.parser.add_argument('--manualseed', default=1337, type=int, help='manual seed')
        self.parser.add_argument('--abnormal_class', default='cat', help='Anomaly class idx for mnist and cifar datasets')
        self.parser.add_argument('--metric', type=str, default='roc', help='Evaluation metric.')

        ##
        # Train
        self.parser.add_argument('--print_freq', type=int, default=100,
                                 help='frequency of showing training results on console')
        self.parser.add_argument('--save_image_freq', type=int, default=100,
                                 help='frequency of saving real and fake images')
        # NOTE(review): the default is the *string* 'true' (truthy) and the
        # action is store_true, so the value is never False. Kept unchanged
        # because callers outside this file may rely on it.
        self.parser.add_argument('--save_test_images', default='true', action='store_true',
                                 help='Save test images for demo.')
        self.parser.add_argument('--load_weights', action='store_true', help='Load the pretrained weights')
        self.parser.add_argument('--resume', default='', help="path to checkpoints (to continue training)")
        self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        self.parser.add_argument('--iter', type=int, default=0, help='Start from iteration i')
        self.parser.add_argument('--niter', type=int, default=15, help='number of epochs to train for')
        self.parser.add_argument('--niter_decay', type=int, default=100,
                                 help='# of iter to linearly decay learning rate to zero')
        self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        self.parser.add_argument('--w_adv', type=float, default=1, help='Weight for adversarial loss. default=1')
        self.parser.add_argument('--w_con', type=float, default=50, help='Weight for reconstruction loss. default=50')
        self.parser.add_argument('--w_lat', type=float, default=1, help='Weight for latent space loss. default=1')
        self.parser.add_argument('--lr_policy', type=str, default='lambda', help='lambda|step|plateau')
        self.parser.add_argument('--gen', type=str, default='fca', help='Unet|Unet_att|Unet_att_CBAM|UnetDS_att_CBAM|transformer|SwinUnet')
        self.parser.add_argument('--lr_decay_iters', type=int, default=50,
                                 help='multiply by a gamma every lr_decay_iters iterations')
        self.parser.add_argument('--uselat', type=int, default=1)
        self.parser.add_argument('--patchsize', type=int, default=4)
        self.parser.add_argument('--windowsize', type=int, default=2)
        self.parser.add_argument('--num_layers', type=int, default=10)
        self.parser.add_argument('--dim', type=int, default=64)
        self.parser.add_argument('--gap', type=int, default=3)
        self.parser.add_argument('--depth', type=int, default=3)
        self.parser.add_argument('--drop_rate', type=float, default=0.0)

        # for mvtec
        self.parser.add_argument('--attention', type=str, default='swintf')

        self.isTrain = True
        self.opt = None

    @staticmethod
    def _parse_gpu_ids(gpu_ids):
        """Turn a comma separated id string into a list of non-negative ints."""
        parsed = []
        for token in gpu_ids.split(','):
            token = token.strip()
            if not token:
                # Tolerate trailing or doubled commas such as '0,1,2,'.
                continue
            gpu_id = int(token)
            # Negative ids (e.g. -1) mean "no GPU" and are dropped.
            if gpu_id >= 0:
                parsed.append(gpu_id)
        return parsed

    @staticmethod
    def _save_options(file_name, options):
        """Write the sorted ``key: value`` option pairs to ``file_name``."""
        with open(file_name, 'wt') as opt_file:
            opt_file.write('------------ Options -------------\n')
            for k, v in sorted(options.items()):
                opt_file.write('%s: %s\n' % (str(k), str(v)))
            opt_file.write('-------------- End ----------------\n')

    def parse(self, args=None):
        """ Parse Arguments.

        Args:
            args (list, optional): Argument strings to parse. When ``None``
                (the default) argparse reads the process command line, which
                is what a call without arguments has always done.

        Returns:
            [argparse.Namespace]: The parsed options, with ``isTrain`` added,
            ``gpu_ids`` converted to a list of ints and ``name`` replaced by
            ``<model>/<dataset>`` when it was left at 'experiment_name'.

        Raises:
            ValueError: If ``gpu_ids`` contains a non-integer token.
        """
        self.opt = self.parser.parse_args(args)
        self.opt.isTrain = self.isTrain   # train or test

        # get gpu ids
        self.opt.gpu_ids = self._parse_gpu_ids(self.opt.gpu_ids)

        # set gpu ids -- never touch CUDA when the cpu device was requested,
        # otherwise the default gpu_ids='0' breaks CPU-only runs.
        if self.opt.gpu_ids and self.opt.device != 'cpu':
            torch.cuda.set_device(self.opt.gpu_ids[0])

        # vars() exposes the namespace's own dict, so the name substitution
        # below is visible in what gets written to opt.txt.
        options = vars(self.opt)

        if self.opt.verbose:
            print('------------ Options -------------')
            for k, v in sorted(options.items()):
                print('%s: %s' % (str(k), str(v)))
            print('-------------- End ----------------')

        # save to the disk
        if self.opt.name == 'experiment_name':
            self.opt.name = "%s/%s" % (self.opt.model, self.opt.dataset)
        expr_dir = os.path.join(self.opt.outf, self.opt.name, 'train')
        test_dir = os.path.join(self.opt.outf, self.opt.name, 'test')

        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(expr_dir, exist_ok=True)
        os.makedirs(test_dir, exist_ok=True)

        self._save_options(os.path.join(expr_dir, 'opt.txt'), options)
        return self.opt