-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
62 lines (51 loc) · 2.67 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import tensorflow as tf
import json
import model
import csv_io
# Suffixes of the per-sample single-channel PNG files; gen_fn stacks the
# channels in exactly this order.
img_color = [
    'blue',
    'green',
    'red',
    'yellow',
]
# Examples per training batch (passed to train_input_fn by main).
batch_size = 32
def gen_fn(id, label, img_dir, src_size=512, out_size=224):
    """Load one multi-channel sample image and pair it with its label.

    Every sample is stored as one single-channel PNG per entry of the
    module-level ``img_color`` list, at ``img_dir + id + "_" + color + ".png"``.
    The channels are stacked in ``img_color`` order, resized to
    ``out_size x out_size`` and scaled by 1/255.

    Args:
        id: Sample id (string tensor when called from ``Dataset.map``). The
            name shadows the ``id`` builtin but is kept for existing callers.
        label: Label for the sample; returned unchanged.
        img_dir: Directory prefix. It is concatenated directly with the file
            name, so it must already end with a path separator.
        src_size: Side length of every source PNG (previously fixed at 512).
            The reshape below errors if a decoded PNG does not contain
            ``src_size * src_size`` pixels.
        out_size: Side length of the returned image (previously fixed at 224).

    Returns:
        ``(normalized_img, label)`` where ``normalized_img`` is a float32
        tensor of shape ``[out_size, out_size, len(img_color)]``; values are
        in [0, 1] assuming 8-bit PNG input.
    """
    channel_imgs = []
    for color in img_color:
        png_bytes = tf.read_file(img_dir + id + "_" + color + ".png")
        channel = tf.image.decode_png(png_bytes, dtype=tf.uint8, channels=1)
        # decode_png cannot know the image size statically; pin it so the
        # stacked image has a fully defined static shape.
        channel.set_shape([src_size, src_size, 1])
        channel_imgs.append(tf.reshape(channel, [src_size, src_size]))
    # [src_size, src_size, len(img_color)], one channel per color.
    img = tf.stack(channel_imgs, axis=2)
    img = tf.image.resize_images(img, [out_size, out_size])
    # resize_images may return its uint8 input untouched when no resize is
    # needed (src_size == out_size), so the explicit cast is still required.
    normalized_img = tf.cast(img, dtype=tf.float32) / 255.0
    return (normalized_img, label)
def train_input_fn(img_id, img_dir, labels, batch_size, shuffle_buffer=512,
                   num_parallel_calls=None):
    """Build the endless, shuffled, batched training dataset for the Estimator.

    Args:
        img_id: Sample ids, one per example (anything accepted by
            ``Dataset.from_tensor_slices``).
        img_dir: Directory prefix forwarded to ``gen_fn``.
        labels: Labels aligned element-wise with ``img_id``.
        batch_size: Number of examples per batch.
        shuffle_buffer: Shuffle buffer size (default matches the previous
            hard-coded 512).
        num_parallel_calls: Forwarded to ``Dataset.map``. ``None`` decodes
            sequentially, as before; pass an int to decode in parallel.

    Returns:
        A ``tf.data.Dataset`` that repeats forever and yields
        ``(images, labels)`` batches.
    """
    dataset = tf.data.Dataset.from_tensor_slices((img_id, labels))
    # Shuffle the cheap (id, label) pairs *before* decoding. map() is
    # one-to-one, so the distribution of element orders is the same as
    # shuffling after decoding, but the buffer now holds ids instead of
    # shuffle_buffer fully decoded float images.
    dataset = dataset.shuffle(shuffle_buffer).repeat()
    dataset = dataset.map(
        lambda sample_id, label: gen_fn(sample_id, label, img_dir),
        num_parallel_calls=num_parallel_calls)
    dataset = dataset.batch(batch_size)
    # Prepare the next batch while the current training step runs.
    return dataset.prefetch(1)
def _load_train_config(config_file_name):
    """Read the JSON training config; print a message and return None on failure.

    Only a JSON object (dict) is accepted, since the caller looks keys up in it.
    """
    try:
        with open(config_file_name, 'r') as config_file:
            config_data = json.load(config_file)
    except (IOError, ValueError):
        # IOError: missing/unreadable file. ValueError (JSONDecodeError on
        # Python 3): malformed JSON, which previously escaped as a traceback.
        print("Fail to load configuration file")
        return None
    if not isinstance(config_data, dict):
        print("Fail to load configuration file")
        return None
    return config_data


def main(argv):
    """Train the Estimator built from ``model.resnet_50_model_fn``.

    Reads ``train_config.json`` from the current working directory. It must
    be a JSON object with the keys ``csv_file``, ``img_dir``, ``model_dir``
    and ``ckpt_dir``. Sample ids and labels come from
    ``csv_io.csv_read(csv_file)``. Checkpoints go to ``ckpt_dir`` every 2000
    steps, keeping at most one. ``model_dir`` is currently only referenced by
    the disabled export below.

    Args:
        argv: Unused; supplied by ``tf.app.run``.

    Returns:
        ``None`` after training, or ``1`` if the config cannot be loaded or
        lacks a required key. ``tf.app.run`` hands this value to
        ``sys.exit``, so failures now give a non-zero exit status (the old
        code returned 0).
    """
    config_data = _load_train_config('train_config.json')
    if config_data is None:
        return 1

    # Each required key with the message printed when it is absent.
    required_keys = (
        ('csv_file', "csv file not available"),
        ('img_dir', "image directory not available"),
        ('model_dir', "model directory not available"),
        ('ckpt_dir', "checkpoint directory not available"),
    )
    missing_key = False
    for key, message in required_keys:
        if key not in config_data:
            print(message)  # report every missing key, not just the first
            missing_key = True
    if missing_key:
        return 1

    img_dir = config_data['img_dir']
    # gen_fn builds file paths by plain concatenation, so the directory needs
    # a trailing "/". The old test (`if img_dir[-1]`) was true for every
    # non-empty string: it doubled an existing "/" and raised IndexError on "".
    # An empty img_dir is left empty so paths stay relative to the cwd.
    if img_dir and not img_dir.endswith('/'):
        img_dir += '/'

    img_id, label = csv_io.csv_read(config_data['csv_file'])

    # The Estimator creates its own sessions from RunConfig.session_config;
    # request on-demand GPU memory growth there explicitly.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    run_config = tf.estimator.RunConfig(
        save_checkpoints_steps=2000,
        save_checkpoints_secs=None,
        keep_checkpoint_max=1,
        session_config=session_config,
    )
    classifier = tf.estimator.Estimator(
        model_fn=model.resnet_50_model_fn,
        config=run_config,
        model_dir=config_data['ckpt_dir'],
    )
    classifier.train(
        input_fn=lambda: train_input_fn(img_id, img_dir, label, batch_size),
        steps=2000,
    )
    # NOTE: SavedModel export to config_data['model_dir'] is currently disabled:
    # classifier.export_saved_model(export_dir_base=config_data['model_dir'],
    # serving_input_receiver_fn=tf.estimator.export.build_raw_serving_input_receiver_fn({"features" : tf.placeholder(dtype=tf.float32)}))
if __name__ == "__main__":
    # Ask TensorFlow to grow GPU memory on demand rather than reserving it
    # all up front. NOTE(review): nothing in this file references `session`
    # afterwards; it is created only for this configuration side effect.
    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    session = tf.Session(config=config)
    tf.logging.set_verbosity(tf.logging.INFO)
    # Parses flags and calls main(argv).
    tf.app.run()