-
Notifications
You must be signed in to change notification settings - Fork 42
/
Copy pathinit_graph.py
30 lines (23 loc) · 939 Bytes
/
init_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import argparse
from os.path import dirname, isdir, join

import tensorflow as tf
def main(args):
    """Import a meta graph, initialize its variables and save a checkpoint.

    The checkpoint is written as ``model.ckpt`` in the same directory as the
    meta graph file.

    Args:
        args: parsed command-line namespace. Only ``args.graph`` is read; it
            is passed to ``tf.train.import_meta_graph``, so it is expected to
            be the path of a meta graph file.
    """
    meta_path = args.graph

    # NOTE(review): a throwaway TFRecordDataset iterator is built in the
    # default graph before the import. Presumably this makes the dataset /
    # iterator ops the meta graph relies on available -- confirm.
    tf.data.TFRecordDataset("").make_initializable_iterator()

    print("importing {}".format(meta_path))
    tf.train.import_meta_graph(meta_path)

    # Built after the import so the default var_list covers the variables
    # that now exist in the graph.
    ckpt_saver = tf.train.Saver()
    ckpt_path = join(dirname(meta_path), "model.ckpt")

    with tf.Session() as session:
        print("initializing global variables")
        session.run(tf.global_variables_initializer())
        print("writing {}".format(ckpt_path))
        ckpt_saver.save(session, ckpt_path)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='creates an initial checkpoint (all global variables '
                    'freshly initialized) from a meta graph')
    parser.add_argument('graph', type=str,
                        help='path to the source meta graph file, or a '
                             'directory containing graph.meta; model.ckpt '
                             'is written next to it')
    args = parser.parse_args()
    # main() hands args.graph straight to tf.train.import_meta_graph, which
    # needs a file path. Resolve a directory to the graph.meta inside it so
    # both invocation styles work.
    if isdir(args.graph):
        args.graph = join(args.graph, 'graph.meta')
    main(args)