-
Notifications
You must be signed in to change notification settings - Fork 9
/
dis-pu.py
58 lines (52 loc) · 1.6 KB
/
dis-pu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@author:liruihui
@file: dis-pu.py
@time: 2021/04/09
@contact: [email protected]
@github: https://liruihui.github.io/
@description: Command-line entry point that trains or tests the DisPU model, depending on FLAGS.phase.
"""
import tensorflow as tf
from DisPU.model import Model
from DisPU.configs import FLAGS
from datetime import datetime
import os
import logging
import pprint
# Module-level pretty-printer; run() uses it to dump the resolved FLAGS.
pp = pprint.PrettyPrinter()
def _ensure_dir(path):
    """Create directory ``path`` (with parents) unless it already exists.

    Raises:
        OSError: if creation fails and ``path`` is still not a directory
            afterwards (e.g. permission denied, or a regular file is in
            the way).
    """
    try:
        os.makedirs(path)
    except OSError:
        # An existing directory is fine (this also covers another process
        # creating it concurrently); any other failure must not be hidden,
        # otherwise the run continues with an unusable output location.
        if not os.path.isdir(path):
            raise


def run():
    """Resolve phase-specific paths on FLAGS, then train or test the model.

    In the 'train' phase, ``FLAGS.train_file`` is set below
    ``FLAGS.data_dir`` and, unless ``FLAGS.restore`` is set, a fresh
    timestamped sub-directory of ``FLAGS.log_dir`` is created and becomes the
    new ``FLAGS.log_dir``.

    In any other phase, ``FLAGS.log_dir`` is pointed at ``./model`` under the
    current working directory, and the test input glob and output folder are
    set below ``FLAGS.data_dir`` (the output folder is created if missing).

    The resolved FLAGS are printed, then a ``Model`` is built inside a
    TensorFlow session and ``train()`` or ``test()`` is invoked.

    Raises:
        OSError: if a required log/output directory cannot be created.
    """
    if FLAGS.phase == 'train':
        FLAGS.train_file = os.path.join(FLAGS.data_dir, 'train/PUGAN_poisson_256_poisson_1024.h5')
        print('train_file:', FLAGS.train_file)
        if not FLAGS.restore:
            # Fresh run: isolate its logs/checkpoints in a per-run folder.
            # The stamp has minute resolution, so two runs started within
            # the same minute share a folder; _ensure_dir tolerates that.
            current_time = datetime.now().strftime("%Y%m%d-%H%M")
            FLAGS.log_dir = os.path.join(FLAGS.log_dir, current_time)
            _ensure_dir(FLAGS.log_dir)
    else:
        # NOTE: the test phase replaces whatever log_dir was configured with
        # <cwd>/model.
        FLAGS.log_dir = os.path.join(os.getcwd(), 'model')
        FLAGS.test_data = os.path.join(FLAGS.data_dir, 'test/*.xyz')
        FLAGS.out_folder = os.path.join(FLAGS.data_dir, 'test/output')
        _ensure_dir(FLAGS.out_folder)
        print('test_data:', FLAGS.test_data)

    print('checkpoints:', FLAGS.log_dir)
    pp.pprint(FLAGS)

    # Open the session with allow_growth enabled, so TensorFlow grows its GPU
    # memory allocation on demand instead of reserving it all up front.
    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True
    with tf.Session(config=run_config) as sess:
        model = Model(FLAGS, sess)
        if FLAGS.phase == 'train':
            model.train()
        else:
            model.test()
def main(unused_argv):
    """Entry callback for tf.app.run; ignores its argument and calls run()."""
    # The parameter exists only to satisfy the callback signature.
    del unused_argv
    run()
if __name__ == '__main__':
    # Emit INFO-level (and higher) log records while the script runs.
    logging.basicConfig(level=logging.INFO)
    # Hand control to TensorFlow's app runner, which invokes main(argv).
    tf.app.run(main=main)