-
Notifications
You must be signed in to change notification settings - Fork 3
/
options.py
50 lines (40 loc) · 4.31 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# coding: utf-8
from argparse import ArgumentParser
def _str_to_bool(value):
    """Parse a command-line string into a real bool.

    argparse's ``type=bool`` is a well-known trap: any non-empty string is
    truthy, so ``--augment_data False`` would still yield ``True``.  This
    parser accepts the usual spellings and raises ValueError otherwise
    (argparse turns ValueError into a normal usage error).
    """
    if isinstance(value, bool):
        # the default value is passed through by argparse untouched
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 'yes', '1'):
        return True
    if lowered in ('false', 'no', '0'):
        return False
    raise ValueError('expected a boolean, got %r' % value)


def _str_to_int_tuple(value):
    """Parse a comma-separated string such as ``3,256,256`` into an int tuple.

    The previous ``type=tuple`` split the argument into a tuple of single
    characters, silently producing a useless image shape.
    """
    if isinstance(value, tuple):
        # the default value is already a tuple; leave it alone
        return value
    return tuple(int(item) for item in value.split(','))


def get_options(args=None):
    """Build and parse the command-line options for training.

    Args:
        args: optional list of argument strings.  ``None`` (the default)
            parses ``sys.argv[1:]`` exactly as before, so existing callers
            are unaffected; passing an explicit list makes this testable.

    Returns:
        ``argparse.Namespace`` holding every option defined below.
    """
    parser = ArgumentParser()
    parser.add_argument('--use_cpu', action='store_true', help='for debug')
    # dataset info
    parser.add_argument('--dataset_dir', type=str, default='datasets/labels', help='directory of dataset for normal training')
    parser.add_argument('--unlabel_dataset_dir', type=str, default='datasets/unlables', help='directory of dataset for semi-supervised training')
    parser.add_argument('--augment_data', type=_str_to_bool, default=True, help='if this flag is true, annotated data is augmented.')
    parser.add_argument('--img_shape', type=_str_to_int_tuple, default=(3, 256, 256), help='comma-separated ints; the order is (channels, heights, widths)')
    parser.add_argument('--class_num', type=int, default=5, help='target object class of semantic segmentation')
    # training hyper-parameters
    parser.add_argument('--batch_size', type=int, default=1, help='the number of training samples utilized in one iteration')
    parser.add_argument('--g_lr', type=float, default=2.5e-4, help='learning rate of adam optimizer in order to train generator')
    parser.add_argument('--g_beta1', type=float, default=0.9, help='beta1 of adam optimizer in order to train generator')
    parser.add_argument('--g_beta2', type=float, default=0.99, help='beta2 of adam optimizer in order to train generator')
    parser.add_argument('--g_weight_decay', type=float, default=1e-4, help='it penalizes complexity of generator loss function in order to prevent overfitting. if this option is zero, weight decay is not applied.')
    parser.add_argument('--d_lr', type=float, default=1e-4, help='learning rate of adam optimizer in order to train discriminator')
    parser.add_argument('--d_beta1', type=float, default=0.9, help='beta1 of adam optimizer in order to train discriminator')
    parser.add_argument('--d_beta2', type=float, default=0.99, help='beta2 of adam optimizer in order to train discriminator')
    parser.add_argument('--max_epoch', type=int, default=100, help='maximum loop number over the dataset. by the way, one cycle of the dataset is one epoch.')
    parser.add_argument('--snap_interval_epoch', type=int, default=25, help='interval saving model parameter')
    parser.add_argument('--img_interval_iteration', type=int, default=100, help='interval saving image sample')
    parser.add_argument('--semi_ignit_iteration', type=int, default=10000, help='ignition iteration number of semi-supervised learning')
    parser.add_argument('--lr_poly_train_period', type=float, default=0.4, help='if current iteration is larger than this, start learning rate poly.')
    parser.add_argument('--lr_poly_power', type=float, default=0.9, help='strength of learning rate poly')
    parser.add_argument('--out_dir', type=str, default='result', help='directory of outputs')
    # model architecture hyper-parameters
    parser.add_argument('--conv_norm', type=str, default='spectral_norm_hook',
                        help='convolution weight normalization type. [original] is typical convolution. [spectral_norm] is only used chainer.funcstions. [spectral_norm_hook] is based on chainer.link_hooks. there is details in spectral_norms.py.')
    parser.add_argument('--ngf', type=int, default=64, help='dimension of hidden feature map at generator')
    parser.add_argument('--ndf', type=int, default=64, help='dimension of hidden feature map at discriminator')
    parser.add_argument('--aspp_nf', type=int, default=256, help='dimension of hidden feature map at ASPP architecture')
    # loss hyper-parameters
    parser.add_argument('--adv_loss_mode', type=str, default='hinge', help='adversarial loss approach. [bce] is binary-cross-entropy or softplus-loss. [mse] is mean-squared-error. [hinge] is hinge-loss')
    parser.add_argument('--adv_coef', type=float, default=0.01, help='adversarial loss coefficient for annotated data')
    parser.add_argument('--semi_adv_coef', type=float, default=0.001, help='adversarial loss coefficient for unlabeled data')
    parser.add_argument('--semi_st_coef', type=float, default=0.1, help='self-teach loss coefficient for unlabeled data. this loss is used with the virtual ground truth generated from discriminator')
    parser.add_argument('--semi_threshold', type=float, default=0.2, help='this ratio is accuracy of confidence map generated by discriminator for self-teach loss')
    return parser.parse_args(args)