-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.py
128 lines (99 loc) · 5.03 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import csv
import time
import json
import argparse
import numpy as np
import tensorflow as tf
import mnist_data
from model import MNISTcnn
#from tensorflow.examples.tutorials.mnist import input_data
def predict(sess, x, keep_prob, pred, images, output_file):
    """Evaluate the prediction tensor on *images* and dump the labels to CSV.

    Args:
        sess: active ``tf.Session`` holding the trained graph.
        x: input placeholder to feed with *images*.
        keep_prob: dropout keep-probability placeholder; fixed to 1.0 for inference.
        pred: tensor of per-example predicted labels to evaluate.
        images: batch of input images compatible with the shape of *x*.
        output_file: path of the CSV written as an ``id,label`` header plus
            one row per example.
    """
    feed_dict = {x: images, keep_prob: 1.0}
    prediction = sess.run(pred, feed_dict=feed_dict)
    # newline="" is required by the csv module; without it Windows output
    # contains a blank line after every row.
    with open(output_file, "w", newline="") as f:
        writer = csv.writer(f, delimiter=",")
        writer.writerow(["id", "label"])
        # enumerate instead of range(len(...)); row id is the example index.
        for i, label in enumerate(prediction):
            writer.writerow([str(i), str(label)])
    print("Output prediction: {0}".format(output_file))
def train(args, data):
    """Build the MNIST CNN graph, train it, checkpoint it, and write test predictions.

    Args:
        args: parsed argparse namespace; reads ckpt_dir, load_params, batch_size,
            val_size, epochs, and output.
        data: dataset object exposing .train/.validation/.test splits, each with
            next_batch(), num_examples, labels, and images
            (see mnist_data.read_data_sets — assumed API, confirm against that module).
    """
    obs_shape = data.train.get_observation_size()  # e.g. a tuple (28,28,1)
    assert len(obs_shape) == 3, 'assumed right now'
    #obs_shape = (28,28,1)
    # One-hot labels: the second dimension is the number of classes.
    num_class = data.train.labels.shape[1]

    # TF1 placeholders: batch dimension left unspecified (None).
    x = tf.placeholder(tf.float32, shape=(None,) + obs_shape)
    y = tf.placeholder(tf.float32, (None, num_class))
    model = MNISTcnn(x, y, args)

    # Fixed learning rate 1e-4; minimizes the model's cross-entropy loss.
    optimizer = tf.train.AdamOptimizer(1e-4).minimize(model.loss)
    saver = tf.train.Saver(tf.trainable_variables())

    with tf.Session() as sess:
        print('Starting training')
        sess.run(tf.global_variables_initializer())

        # Optionally resume from a previous checkpoint before training.
        if args.load_params:
            ckpt_file = os.path.join(args.ckpt_dir, 'mnist_model.ckpt')
            print('Restoring parameters from', ckpt_file)
            saver.restore(sess, ckpt_file)

        # Integer division: a trailing partial batch is dropped each epoch.
        num_batches = data.train.num_examples // args.batch_size

        if args.val_size > 0:
            validation = True
            val_num_batches = data.validation.num_examples // args.batch_size
        else:
            validation = False

        for epoch in range(args.epochs):
            begin = time.time()

            # train — dropout keep probability 0.5 during training steps.
            train_accuracies = []
            for i in range(num_batches):
                batch = data.train.next_batch(args.batch_size)
                feed_dict = {x: batch[0], y: batch[1], model.keep_prob: 0.5}
                _, acc = sess.run([optimizer, model.accuracy], feed_dict=feed_dict)
                train_accuracies.append(acc)
            train_acc_mean = np.mean(train_accuracies)

            # compute loss over validation data (dropout disabled: keep_prob 1.0)
            if validation:
                val_accuracies = []
                for i in range(val_num_batches):
                    batch = data.validation.next_batch(args.batch_size)
                    feed_dict = {x: batch[0], y: batch[1], model.keep_prob: 1.0}
                    acc = sess.run(model.accuracy, feed_dict=feed_dict)
                    val_accuracies.append(acc)
                val_acc_mean = np.mean(val_accuracies)

                # log progress to console
                print("Epoch %d, time = %ds, train accuracy = %.4f, validation accuracy = %.4f" % (epoch, time.time()-begin, train_acc_mean, val_acc_mean))
            else:
                print("Epoch %d, time = %ds, train accuracy = %.4f" % (epoch, time.time()-begin, train_acc_mean))
            sys.stdout.flush()

            # Periodic checkpoint every 10 epochs (same path each time, so it
            # is overwritten rather than versioned).
            if (epoch + 1) % 10 == 0:
                ckpt_file = os.path.join(args.ckpt_dir, 'mnist_model.ckpt')
                saver.save(sess, ckpt_file)

        # Final checkpoint after the last epoch.
        ckpt_file = os.path.join(args.ckpt_dir, 'mnist_model.ckpt')
        saver.save(sess, ckpt_file)

        # predict test data
        predict(sess, x, model.keep_prob, model.pred, data.test.images, args.output)

        # original test data from 'http://yann.lecun.com/exdb/mnist/'
        """
        acc = sess.run(model.accuracy, feed_dict={x: data.test.images, y: data.test.labels, model.keep_prob: 1.0})
        print("test accuracy %g"%acc)
        """
if __name__ == "__main__":
    # Command-line interface: data/checkpoint locations plus training knobs.
    cli = argparse.ArgumentParser()
    cli.add_argument('-d', '--data_dir', type=str, default='mnist_data/', help='Directory for storing input data')
    cli.add_argument('-c', '--ckpt_dir', type=str, default='ckpts/', help='Directory for parameter checkpoints')
    cli.add_argument('-l', '--load_params', dest='load_params', action='store_true', help='Restore training from previous model checkpoint?')
    cli.add_argument("-o", "--output", type=str, default='prediction.csv', help='Prediction filepath')
    cli.add_argument('-e', '--epochs', type=int, default=30, help='How many epochs to run in total?')
    cli.add_argument('-b', '--batch_size', type=int, default=50, help='Batch size during training per GPU')
    cli.add_argument('-v', '--val_size', type=int, default=5000)
    args = cli.parse_args()

    # pretty print args
    print('input args:\n', json.dumps(vars(args), indent=4, separators=(',',':')))

    # Load the dataset with images kept in (28, 28, 1) layout (reshape=False).
    data = mnist_data.read_data_sets(args.data_dir, one_hot=True, reshape=False, validation_size=args.val_size)
    #data = input_data.read_data_sets(args.data_dir, one_hot=True, reshape=False, validation_size=args.val_size)

    # Make sure the checkpoint directory exists before training saves into it.
    if not os.path.exists(args.ckpt_dir):
        os.makedirs(args.ckpt_dir)

    train(args, data)