-
Notifications
You must be signed in to change notification settings - Fork 9
/
elephant_in_the_freezer.py
42 lines (34 loc) · 1.11 KB
/
elephant_in_the_freezer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
import glob
import tensorflow as tf
def load_graph(pb_dir):
    """Load a frozen TensorFlow graph from a ``*.pb`` file in ``pb_dir``.

    The serialized ``GraphDef`` is parsed and imported into a fresh
    ``tf.Graph`` under the name scope ``model``. Three tensors are then
    looked up by name and returned as endpoints.

    Args:
        pb_dir: Directory containing the frozen ``.pb`` file. If several
            ``.pb`` files match, the lexicographically first one is used so
            the choice is deterministic (``glob`` order is arbitrary).

    Returns:
        A ``(graph, endpoints)`` tuple. ``endpoints`` maps:
          - ``'x'``: the ``model/input_x:0`` tensor (graph input).
          - ``'pred'``: the ``model/metrics/Sigmoid:0`` tensor (sigmoid
            output).
          - ``'bin_pred'``: the ``model/metrics/Cast:0`` tensor —
            presumably the thresholded/binarized prediction; confirm
            against the training graph.

    Raises:
        FileNotFoundError: if no ``*.pb`` file exists in ``pb_dir``.
    """
    pb_files = sorted(glob.glob(os.path.join(pb_dir, '*.pb')))
    if not pb_files:
        # Fail with an actionable message instead of a bare IndexError.
        raise FileNotFoundError(
            "no frozen graph (*.pb) found in {!r}".format(pb_dir))
    pb_file = pb_files[0]

    # Read and deserialize the binary GraphDef protobuf (TF1-style API).
    with tf.gfile.GFile(pb_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        # Every imported op is prefixed with "model/", which is why the
        # tensor names below start with that scope.
        tf.import_graph_def(graph_def, name="model")

    endpoints = {
        'x': graph.get_tensor_by_name('model/input_x:0'),
        'pred': graph.get_tensor_by_name('model/metrics/Sigmoid:0'),
        'bin_pred': graph.get_tensor_by_name('model/metrics/Cast:0'),
    }
    return graph, endpoints
def unit_test(pb_dir='./model/dt_0815_resume'):
    """Smoke-test ``load_graph`` by loading a model and printing its endpoints.

    Args:
        pb_dir: Directory holding the frozen ``.pb`` file. Defaults to the
            path this script was originally written against.
    """
    _, endpoints = load_graph(pb_dir)
    # Print each (name, tensor) pair. A plain loop replaces the original
    # list comprehension that was built only for its print side effects.
    for endpoint in endpoints.items():
        print(endpoint)
    print(endpoints['x'])
# Run the smoke test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    unit_test()