-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathval.py
64 lines (56 loc) · 2.28 KB
/
val.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from timecyclegan.models import get_task
from timecyclegan.util.argparser import parse_args
from timecyclegan.util.logging import log_command
from timecyclegan.evaluation.compute_metrics import compute_metrics
from test import test, define_output_dir
def validate(**kwargs):
    """Run inference with an already-trained model, then score its outputs.

    No training happens here. The model described by ``kwargs`` is first run
    on the test data via ``test``. The generated images are then compared
    against the real test images with ``compute_metrics``.

    If ``get_task`` reports the model's task as ``"unpaired"``, both steps are
    done in both directions (source->target and target->source). Otherwise a
    single paired pass is run.

    Args:
        **kwargs: Parsed command-line options (see ``parse_args(mode="val")``).
            Every option is forwarded to ``test``. The keys read directly here
            are ``model_type``, ``model_name``, ``test_output_dir``,
            ``test_source_dir``, ``test_target_dir``, ``gpu``,
            ``image_height`` and ``image_width``.
    """

    def get_fake_dir(dir_name_appendix=""):
        """Resolve an inference output directory via ``define_output_dir``.

        The result is used as the "fake" (generated) image directory when
        computing metrics. ``dir_name_appendix`` selects the direction
        sub-name for unpaired models.
        """
        return define_output_dir(
            model_name=kwargs["model_name"],
            output_root=kwargs["test_output_dir"],
            dir_name_appendix=dir_name_appendix,
        )

    def compute_val_metrics(unpaired=False, source_to_target=False):
        """Compute metrics for a single inference direction.

        Args:
            unpaired: True if the model's task is unpaired translation.
            source_to_target: Only used when ``unpaired`` is True. True scores
                the source->target outputs, False scores target->source.

        Returns:
            Whatever ``compute_metrics`` returns.
        """
        if not unpaired:
            real_dir = kwargs["test_target_dir"]
            fake_dir = get_fake_dir()
            unpaired_source_dir = None
        elif source_to_target:
            real_dir = kwargs["test_target_dir"]
            fake_dir = get_fake_dir("source_to_target")
            unpaired_source_dir = kwargs["test_source_dir"]
        else:
            real_dir = kwargs["test_source_dir"]
            fake_dir = get_fake_dir("target_to_source")
            unpaired_source_dir = kwargs["test_target_dir"]
        return compute_metrics(
            real_dir=real_dir,
            fake_dir=fake_dir,
            model_name=kwargs["model_name"],
            # A non-negative GPU index enables GPU usage.
            use_gpu=(kwargs["gpu"] >= 0),
            height=kwargs["image_height"],
            width=kwargs["image_width"],
            unpaired_source_dir=unpaired_source_dir,
        )

    # Resolve the task once; it drives both the inference and metric passes.
    is_unpaired = get_task(kwargs["model_type"]) == "unpaired"

    print("*** INFERENCING MODEL ***")
    if is_unpaired:
        # Run inference in both directions. Each call gets a copy of the
        # options with the direction flag overridden, so kwargs itself is
        # never mutated and a pre-existing flag cannot cause a duplicate
        # keyword error.
        for target_to_source in (False, True):
            test(**dict(kwargs, test_unpaired_target_to_source=target_to_source))
    else:
        test(**kwargs)

    print("*** COMPUTING METRICS ***")
    if is_unpaired:
        compute_val_metrics(unpaired=True, source_to_target=True)
        compute_val_metrics(unpaired=True, source_to_target=False)
    else:
        compute_val_metrics()
if __name__ == "__main__":
    # Parse the CLI options for validation mode first, so argument errors
    # (or --help) exit before anything is logged. Then record the invoked
    # command and run the inference + metrics pipeline.
    cli_options = parse_args(mode="val")
    log_command()
    validate(**cli_options)