-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference_classifier.py
65 lines (50 loc) · 1.94 KB
/
inference_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import tensorflow as tf
import numpy as np
from utils.helpers import get_proper_fn
from utils.dataset_loaders import load_dataset
from utils.classifier_fns import model_fn
# Don't put image manipulation code inside serving_input_fn() -- it will not work!
def serving_input_fn():
    """Describe the raw input tensors accepted by the exported/served model.

    A single float32 placeholder named 'images' with shape [batch, 28, 28, 1]
    serves both as the receiver tensor and as the feature handed to model_fn.
    Images must already be preprocessed by the caller before being fed.
    """
    feature_map = {
        'images': tf.placeholder(
            dtype=tf.float32,
            shape=[None, 28, 28, 1],
            name='images',
        ),
    }
    # Features and receiver tensors are the same dict: no transformation
    # happens between what the client feeds and what model_fn sees.
    return tf.estimator.export.ServingInputReceiver(
        features=feature_map,
        receiver_tensors=feature_map,
    )
def _preprocess_images(images):
    """Convert raw pixel values to float32 and rescale them from [0, 255] to [-1, 1].

    NOTE(review): presumably this matches the preprocessing used at training
    time -- confirm against the training script.
    """
    scaled = images.astype(np.float32) / 255.0
    return (scaled - 0.5) * 2.0


def main(n_examples=20):
    """Restore the trained classifier and print ground truth vs. prediction.

    Builds an Estimator pointing at the checkpoint directory for the chosen
    network/dataset, wraps it in a predictor via serving_input_fn, and runs the
    first ``n_examples`` test images through it one at a time.

    Args:
        n_examples: Number of test examples to evaluate (default 20, the
            previously hard-coded value). Clamped to the size of the test set,
            so a small dataset no longer raises IndexError.
    """
    # Network under test, resolved dynamically from module/function names.
    network_module = 'resnet.network_resnet'
    network_name = 'resnet83'
    network_fn = get_proper_fn(network_module, network_name)

    dataset_name = 'mnist'
    model_dir = os.path.join('./models', 'cnn', dataset_name, network_name)

    # Only the test split and the class count are needed for inference.
    _, testset, _, n_classes = load_dataset(dataset_name)

    # The Estimator restores weights from model_dir; params are what model_fn
    # reads to rebuild the graph.
    model = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=model_dir,
        config=None,
        params={
            'network_fn': network_fn,
            'n_classes': n_classes,
            'weight_decay': 1e-4,
        },
    )

    # Callable predictor whose feed keys/shapes come from serving_input_fn.
    estimator_predictor = tf.contrib.predictor.from_estimator(model, serving_input_fn)

    test_images = _preprocess_images(testset['images'])
    test_labels = testset['labels']

    # Guard against test sets smaller than the requested count (and negatives).
    n_examples = max(0, min(n_examples, len(test_images)))
    for ii in range(n_examples):
        # The serving placeholder expects a leading batch dimension.
        batch = np.expand_dims(test_images[ii], axis=0)
        prediction = estimator_predictor({'images': batch})
        print('GT: {}, predicted: {}'.format(test_labels[ii], prediction['predicted_class'][0]))
if __name__ == '__main__':
    # Run the inference demo only when executed as a script, not on import.
    main()