forked from protonx-tf-03-projects/highway-networks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
26 lines (21 loc) · 944 Bytes
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os
from argparse import ArgumentParser
import numpy as np
import tensorflow as tf
from data import build_dataset
def parse_args(argv=None):
    """Parse command-line options for the prediction script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``model_folder`` (str) and ``image_index`` (int).
    """
    parser = ArgumentParser(
        description="Print a saved highway model's prediction for one validation image."
    )
    # Not required: the original declared a cwd-based default together with
    # required=True, which made the default unreachable.
    parser.add_argument(
        "--model-folder",
        default=os.path.join(os.getcwd(), "model", "highway"),
        type=str,
        help="Directory of the saved Keras model (default: ./model/highway).",
    )
    parser.add_argument(
        "--image-index",
        default=0,
        type=int,
        help="Index into the validation set; negative values count from the end.",
    )
    return parser.parse_args(argv)


def resolve_index(index, size):
    """Return *index* as a non-negative position into a sequence of *size* items.

    Negative values count from the end, as with ordinary Python indexing.

    Raises:
        IndexError: if *index* lies outside ``[-size, size)``.
    """
    if not -size <= index < size:
        raise IndexError(
            "image index {} out of range for {} validation samples".format(index, size)
        )
    return index % size


def main(argv=None):
    """Load the saved model and print its predicted class and the true class
    for one validation image."""
    args = parse_args(argv)

    # Loading Model
    highway = tf.keras.models.load_model(args.model_folder)

    _, _, _, _, x_val, y_val = build_dataset()
    # Normalise to a non-negative index: a negative i would make the slice
    # x_val[i:i + 1] below empty.
    index = resolve_index(args.image_index, len(x_val))

    # Predict only the requested sample; the slice keeps the batch dimension
    # the model expects, and [0] takes that single row of outputs.
    prediction = highway.predict(x_val[index:index + 1])[0]

    num_digits = 10
    # One-hot encode just this label (the original encoded the whole set).
    true_label = tf.keras.utils.to_categorical(y_val[index], num_digits)

    print('---------------------Prediction Result: -------------------')
    # argmax over the flat output row; the original mistakenly passed axis=1
    # to str.format instead of np.argmax.
    print('Output Softmax: {}'.format(np.argmax(prediction)))
    print('This image belongs to class: {}'.format(np.argmax(true_label)))


if __name__ == "__main__":
    main()