forked from JackonYang/captcha-tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
softmax_with_log.py
115 lines (91 loc) · 4.11 KB
/
softmax_with_log.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# -*- coding:utf-8 -*-
import argparse
import datetime
import sys
import tensorflow as tf
import input_data
# Geometry of one input sample: main() reshapes each IMAGE_SIZE-long input row
# into an IMAGE_HEIGHT x IMAGE_WIDTH single-channel image for TensorBoard.
IMAGE_WIDTH = 60
IMAGE_HEIGHT = 100
IMAGE_SIZE = IMAGE_HEIGHT * IMAGE_WIDTH

# Width of each label vector: one slot per class in range(0, 10).
LABEL_SIZE = 10

# Training schedule: number of gradient steps and examples per mini-batch.
MAX_STEPS = 10000
BATCH_SIZE = 100

# Every run writes its summaries under a fresh, timestamped directory.
LOG_DIR = 'log/regression-run-%s' % (
    datetime.datetime.now().strftime('%Y%m%d_%H%M%S'),
)

# Parsed command-line flags; assigned in the __main__ block before main() runs.
FLAGS = None
def variable_summaries(var):
    """Record TensorBoard summaries describing the values of a tensor.

    Inside a 'summaries' name scope, this adds scalar summaries for the mean,
    standard deviation, maximum and minimum of ``var``, plus a histogram of
    all its values.

    Args:
        var: the tensor to summarize.
    """
    with tf.name_scope('summaries'):
        avg = tf.reduce_mean(var)
        tf.summary.scalar('mean', avg)
        with tf.name_scope('stddev'):
            # Population standard deviation: sqrt(mean((var - mean)^2)).
            deviation = var - avg
            spread = tf.sqrt(tf.reduce_mean(tf.square(deviation)))
        tf.summary.scalar('stddev', spread)
        # Extremes, in the same order as before: max first, then min.
        for tag, reducer in (('max', tf.reduce_max), ('min', tf.reduce_min)):
            tf.summary.scalar(tag, reducer(var))
        tf.summary.histogram('histogram', var)
def _build_graph():
    """Build the softmax-regression graph and attach TensorBoard summaries.

    Returns:
        A tuple ``(x, y_, train_step, accuracy)``:
          x          -- float32 placeholder of input rows, shape [None, IMAGE_SIZE].
          y_         -- float32 placeholder of label rows, shape [None, LABEL_SIZE].
          train_step -- op applying one gradient-descent update (learning rate 0.5)
                        to the mean softmax cross-entropy loss.
          accuracy   -- scalar: fraction of rows whose arg-max prediction equals
                        the arg-max of the label row.
    """
    # Graph inputs.
    with tf.name_scope('input'):
        x = tf.placeholder(tf.float32, [None, IMAGE_SIZE])
        y_ = tf.placeholder(tf.float32, [None, LABEL_SIZE])
        variable_summaries(x)
        variable_summaries(y_)

    # tf.summary.image needs a 4-D tensor [batch_size, height, width, channels].
    # NOTE(review): assumes each input row is stored height-major
    # (IMAGE_HEIGHT rows of IMAGE_WIDTH pixels) -- confirm against input_data.
    images_shaped_input = tf.reshape(x, [-1, IMAGE_HEIGHT, IMAGE_WIDTH, 1])
    # Shows up to LABEL_SIZE (10) sample images per summary.
    tf.summary.image('input', images_shaped_input, max_outputs=LABEL_SIZE)

    # Linear model y = xW + b. The name scopes group the layers in the graph view.
    with tf.name_scope('linear_model'):
        with tf.name_scope('W'):
            W = tf.Variable(tf.zeros([IMAGE_SIZE, LABEL_SIZE]))
            variable_summaries(W)
        with tf.name_scope('b'):
            b = tf.Variable(tf.zeros([LABEL_SIZE]))
            variable_summaries(b)
        with tf.name_scope('y'):
            y = tf.matmul(x, W) + b
            tf.summary.histogram('y', y)

    # Loss and optimizer. softmax_cross_entropy_with_logits returns one loss
    # per example (a 1-D tensor of length batch_size); we minimize its mean.
    with tf.name_scope('loss'):
        diff = tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)
        train_step = tf.train.GradientDescentOptimizer(0.5).minimize(
            tf.reduce_mean(diff))
        variable_summaries(diff)

    # Forward prop: predicted vs. expected class indices.
    predict = tf.argmax(y, axis=1)
    expect = tf.argmax(y_, axis=1)

    # Evaluate accuracy.
    with tf.name_scope('evaluate_accuracy'):
        correct_prediction = tf.equal(predict, expect)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        variable_summaries(accuracy)

    return x, y_, train_step, accuracy


def main(_):
    """Train the softmax-regression model and log progress for TensorBoard.

    Loads data with ``input_data.load_data_1char(FLAGS.data_dir)`` and runs
    MAX_STEPS mini-batch updates of BATCH_SIZE examples each. The full test set
    is evaluated every 100 steps and once more after training, and each
    accuracy is printed. Training summaries go to LOG_DIR/train and test-set
    summaries to LOG_DIR/test, so the two curves never share a tag and step.

    Args:
        _: unused; the argv list forwarded by tf.app.run.
    """
    # Load data.
    train_data, test_data = input_data.load_data_1char(FLAGS.data_dir)
    print('data loaded. train images: %s. test images: %s' % (
        train_data.images.shape[0], test_data.images.shape[0]))

    x, y_, train_step, accuracy = _build_graph()
    test_feed = {x: test_data.images, y_: test_data.labels}

    with tf.Session() as sess:
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(LOG_DIR + '/train', sess.graph)
        # A separate writer keeps test metrics from overwriting or interleaving
        # with training metrics logged under the same tag at the same step.
        test_writer = tf.summary.FileWriter(LOG_DIR + '/test')
        try:
            tf.global_variables_initializer().run()

            # Train.
            for i in range(MAX_STEPS):
                batch_xs, batch_ys = train_data.next_batch(BATCH_SIZE)
                train_summary, _ = sess.run(
                    [merged, train_step], feed_dict={x: batch_xs, y_: batch_ys})
                train_writer.add_summary(train_summary, i)
                if i % 100 == 0:
                    # Test the trained model so far.
                    test_summary, r = sess.run([merged, accuracy], feed_dict=test_feed)
                    test_writer.add_summary(test_summary, i)
                    print('step = %s, accuracy = %.2f%%' % (i, r * 100))

            # Final check after the loop. It is logged at step MAX_STEPS rather
            # than the leaked loop index, so it also works when MAX_STEPS == 0.
            # It is written BEFORE the writers are closed; a closed FileWriter
            # ignores add_summary.
            test_summary, r_test = sess.run([merged, accuracy], feed_dict=test_feed)
            test_writer.add_summary(test_summary, MAX_STEPS)
            print('testing accuracy = %.2f%%' % (r_test * 100, ))
        finally:
            # Flush and close the event files even if training fails.
            train_writer.close()
            test_writer.close()
if __name__ == '__main__':
    # Parse the flags this script owns. Anything unrecognized is handed on to
    # tf.app.run together with the program name.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--data_dir',
        type=str,
        default='images/one-char',
        help='Directory for storing input data',
    )
    FLAGS, remaining_args = arg_parser.parse_known_args()
    forwarded_argv = [sys.argv[0]] + remaining_args
    tf.app.run(main=main, argv=forwarded_argv)