-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
66 lines (53 loc) · 2.19 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Main script.
Usage:
main.py <model> <dataset> <hparams>
"""
import os
import datetime
from docopt import docopt
from tools.JsonConfig import JsonConfig
from tools import code_examples as exp
import datasets
import models
if __name__ == "__main__":
    args = docopt(__doc__)
    model_name = args["<model>"]
    hparams_name = args["<hparams>"]
    dataset_name = args["<dataset>"]

    # Validate CLI arguments with explicit checks rather than `assert`, which
    # is stripped when Python runs with -O. SystemExit prints the message and
    # exits with a non-zero status.
    if model_name not in models.model_dict:
        raise SystemExit("`{}` is not supported, use `{}`".format(
            model_name, list(models.model_dict.keys())))
    if dataset_name not in datasets.dataset_dict:
        raise SystemExit("`{}` is not supported, use `{}`".format(
            dataset_name, list(datasets.dataset_dict.keys())))
    if not os.path.exists(hparams_name):
        raise SystemExit(
            "Failed to find hparams json `{}`".format(hparams_name))

    hparams = JsonConfig(hparams_name)
    dataset_class = datasets.dataset_dict[dataset_name]
    model_class = models.model_dict[model_name]

    # Timestamp with minute resolution, e.g. "20240102_0304". This matches the
    # previous str(now()) slicing: drop seconds, remove '-' and ':', and turn
    # the space between date and time into '_'.
    date = datetime.datetime.now().strftime("%Y%m%d_%H%M")
    log_dir = os.path.join(hparams.Dir.log, "log_" + date)
    print("log_dir:" + str(log_dir))

    # An empty Infer.pre_trained means no checkpoint was supplied, so train;
    # otherwise load that checkpoint and run inference.
    is_training = hparams.Infer.pre_trained == ""
    data = dataset_class(hparams, is_training)
    cond_dim, motion_dim = data.get_dims()
    model = model_class(cond_dim, motion_dim, hparams)

    if is_training:
        os.makedirs(log_dir, exist_ok=True)
        model.build()
        model.train(data, log_dir, hparams)
    else:
        model.build(chkpt_path=hparams.Infer.pre_trained)
        # ---------------------
        # Customize
        # Some samples can be found in ./tools/code_examples.py
        # ---------------------
        # Inference outputs go under a single relative directory. Previously
        # the synthesized results were passed an absolute '/log_dir/...' path
        # that differed from the directory actually created here.
        result_dir = "./log_dir"
        synth_dir = os.path.join(result_dir, "synthesized")

        # Generate result on test set
        os.makedirs(synth_dir, exist_ok=True)
        output_list, motion_list = exp.generate_result_on_test_set(
            model, data, synth_dir)

        # Evaluate KDE
        kde_mean, kde_se = exp.evaluate_kde(output_list, motion_list, data)
        with open(os.path.join(result_dir, "kde_result"), "w",
                  encoding="utf-8") as f:
            print(f"kde_mean: {kde_mean}\nkde_se: {kde_se}", file=f)