-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_pascal_base.py
68 lines (55 loc) · 2.02 KB
/
train_pascal_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
Reference: https://github.com/scaelles/OSVOS-TensorFlow
"""
import os
import tensorflow as tf
import random
import numpy as np
from utils import models
from utils.load_data_pascal_base import Dataset
from utils.logger import create_logger
# Seed every RNG this script touches. A fresh seed is drawn on every run and
# logged below, so a run can be reproduced by hard-coding the logged value
# (e.g. replace the randint call with `seed = 0`).
seed = random.randint(1, 100000)
tf.random.set_seed(seed)
random.seed(seed)
np.random.seed(seed)

# User defined parameters
gpu_id = 0

# Training parameters
imagenet_ckpt = 'weights/imagenet_pretrain_weights/xception_ckpt_new'
logs_path = os.path.join('weights', 'pascal_base_train_weights')
# makedirs also creates a missing 'weights' parent directory, and exist_ok
# avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(logs_path, exist_ok=True)
store_memory = True  # forwarded to Dataset(store_memory=...)
data_aug = True  # forwarded to Dataset(data_aug=...)
pretrained_model = True
supervision = 1  # forwarded to models.pre_train; its meaning is defined there
iter_mean_grad = 10  # NOTE(review): presumably #iterations whose gradients are averaged per update — confirm in models.pre_train
max_training_iters = 25000
save_step = 5000
test_image = None
display_step = 10

# Learning rate setting
learning_rate = 1e-6
batch_size = 1

# Log the run configuration needed to reproduce it.
logger = create_logger(logs_path)
logger.info('The random seed is {}'.format(seed))
logger.info('The max training iteration is {}'.format(max_training_iters))
logger.info('The supervision mode is {}'.format(supervision))
logger.info('Data augmentation is {}'.format(data_aug))

# Define Dataset
# use this one for training
train_file = 'datasets/pretrain_benchmark_reduced.txt'
# # small dataset txt file for fast debugging
# train_file = 'datasets/test_algorithm_pretrain_benchmark_reduced.txt'
dataset = Dataset(train_file, None, './datasets/pascal_extension_dataset',
                  store_memory=store_memory, data_aug=data_aug)

# Train the network
with tf.Graph().as_default():
    # tf.random.set_seed only seeds the eager context / graph that is current
    # when it is called. This freshly created graph starts without a
    # graph-level seed, so set it again for the ops built below.
    tf.random.set_seed(seed)
    with tf.device('/gpu:' + str(gpu_id)):
        global_step = tf.Variable(0, name='global_step', trainable=False)
        models.pre_train(dataset, imagenet_ckpt, supervision, learning_rate, logs_path, max_training_iters,
                         save_step, display_step, global_step, logger, finetune=0,
                         iter_mean_grad=iter_mean_grad, test_image_path=test_image,
                         ckpt_name='pascal_train', dropout_rate=0.5, batch_size=batch_size,
                         pretrained_model=pretrained_model)