-
Notifications
You must be signed in to change notification settings - Fork 22
/
predict.py
115 lines (86 loc) · 4.52 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import argparse
import os
import shlex
import shutil
import subprocess
import sys

from lib import (
    get_retriever_address,
    get_llm_server_address,
    infer_source_target_prefix,
    get_config_file_path_from_name_or_path,
)
def get_git_hash() -> str:
    """Return the commit identifier that ``git rev-parse HEAD`` reports for the current directory.

    Raises:
        subprocess.CalledProcessError: if the ``git`` command exits with a non-zero status.
    """
    raw_output = subprocess.check_output(["git", "rev-parse", "HEAD"])
    commit_hash = raw_output.decode("ascii")
    # Drop the trailing newline that git prints after the hash.
    return commit_hash.strip()
def _run_command(command: str, label: str, dry_run: bool) -> int:
    """Echo a shell command line and execute it unless this is a dry run.

    Args:
        command: Complete shell command line; interpolated values must already be quoted.
        label: Name shown in the echo line, e.g. ``"predict_command"``.
        dry_run: When true, only print the command.

    Returns:
        The command's exit status, or 0 when the command was not executed.
    """
    print(f"Run {label}: \n{command}\n")
    if dry_run:
        return 0
    # The command strings use shell syntax (leading ``KEY=value`` environment assignments).
    return subprocess.call(command, shell=True)


def main() -> None:
    """Run ``configurable_inference`` for one experiment config on one dataset, then evaluate it.

    Command-line arguments:
        experiment_name_or_path: experiment name or config path, resolved through
            ``get_config_file_path_from_name_or_path``.
        evaluation_path: dataset file passed to inference as ``--input``.
        --prediction-suffix: appended to the prediction directory name.
        --dry-run: print the commands without executing them.
        --skip-evaluation: takes a value; any non-empty value skips evaluation.
        --force: predict even when a complete prediction already exists.
        --variable-replacements: JSON string forwarded to ``configurable_inference``.
        --silent: forwarded to ``configurable_inference``.

    Outputs are written under ``predictions/<experiment_name><suffix>/``: the prediction file,
    the git hash, and a copy of the config file (the latter two for reproducibility).

    Exits with an error message if:
        - the prediction already exists and is complete (unless ``--force`` is given);
        - the prediction command fails;
        - either evaluation command fails.
    """
    parser = argparse.ArgumentParser(description="Run configurable_inference on given config and dataset.")
    parser.add_argument("experiment_name_or_path", type=str, help="experiment_name_or_path")
    parser.add_argument("evaluation_path", type=str, help="evaluation_path")
    parser.add_argument(
        "--prediction-suffix", type=str, help="optional suffix for the prediction directory.", default=""
    )
    parser.add_argument("--dry-run", action="store_true", default=False, help="dry-run")
    # NOTE: this option takes a value, and any non-empty string (even "false") skips evaluation.
    parser.add_argument("--skip-evaluation", type=str, default="", help="skip-evaluation")
    parser.add_argument("--force", action="store_true", default=False, help="force predict if it exists")
    parser.add_argument(
        "--variable-replacements",
        type=str,
        help="json string for jsonnet local variable replacements.",
        default="",
    )
    parser.add_argument("--silent", action="store_true", help="silent")
    args = parser.parse_args()

    config_filepath = get_config_file_path_from_name_or_path(args.experiment_name_or_path)
    experiment_name = os.path.splitext(os.path.basename(config_filepath))[0]
    prediction_directory = os.path.join("predictions", experiment_name + args.prediction_suffix)
    os.makedirs(prediction_directory, exist_ok=True)

    prediction_filename = os.path.splitext(os.path.basename(args.evaluation_path))[0]
    prediction_filename = infer_source_target_prefix(config_filepath, args.evaluation_path) + prediction_filename
    prediction_filepath = os.path.join(prediction_directory, "prediction__" + prediction_filename + ".json")

    if os.path.exists(prediction_filepath) and not args.force:
        # Imported only on this path; ``run`` is a sibling module of this script.
        from run import is_experiment_complete

        metrics_file_path = os.path.join(prediction_directory, "evaluation_metrics__" + prediction_filename + ".json")
        if is_experiment_complete(config_filepath, prediction_filepath, metrics_file_path, args.variable_replacements):
            sys.exit(f"The prediction_file_path {prediction_filepath} already exists and is complete. Pass --force.")

    # Service addresses are passed to the inference process as environment assignments.
    retriever_address = get_retriever_address()
    llm_server_address = get_llm_server_address()
    env_variables = {
        "RETRIEVER_HOST": str(retriever_address["host"]),
        "RETRIEVER_PORT": str(retriever_address["port"]),
        "LLM_SERVER_HOST": str(llm_server_address["host"]),
        "LLM_SERVER_PORT": str(llm_server_address["port"]),
    }

    # Shell-quote every interpolated value. Paths may contain spaces, and the JSON
    # replacements string may contain single quotes, which broke a plain '...' wrapping.
    quote = shlex.quote
    predict_parts = [f"{key}={quote(value)}" for key, value in env_variables.items()]
    predict_parts += [
        "python -m commaqa.inference.configurable_inference",
        f"--config {quote(str(config_filepath))}",
        f"--input {quote(args.evaluation_path)}",
        f"--output {quote(prediction_filepath)}",
    ]
    if args.silent:
        predict_parts.append("--silent")
    if args.variable_replacements:
        predict_parts.append(f"--variable-replacements {quote(args.variable_replacements)}")
    predict_command = " ".join(predict_parts)

    predict_status = _run_command(predict_command, "predict_command", args.dry_run)
    if predict_status != 0:
        # Do not record reproducibility info or evaluate a missing/partial prediction file.
        sys.exit(f"predict_command exited with status {predict_status}; skipping evaluation.")

    # To be able to reproduce the same result: record the code version. The hash is
    # computed before opening the file so a git failure does not leave an empty file.
    git_hash = get_git_hash()
    git_hash_filepath = os.path.join(prediction_directory, "git_hash__" + prediction_filename + ".txt")
    with open(git_hash_filepath, "w", encoding="utf-8") as file:
        file.write(git_hash)

    # Again for reproducibility: keep a copy of the exact config that was used.
    backup_config_filepath = os.path.join(prediction_directory, "config__" + prediction_filename + ".jsonnet")
    shutil.copyfile(config_filepath, backup_config_filepath)

    if args.skip_evaluation:
        return

    # Evaluate twice: once as-is and once with ``--official``.
    evaluate_base = ["python evaluate.py", quote(str(config_filepath)), quote(str(args.evaluation_path))]
    evaluation_failed = False
    for extra_args in ([], ["--official"]):
        evaluate_command = " ".join(evaluate_base + extra_args)
        if _run_command(evaluate_command, "evaluate_command", args.dry_run) != 0:
            evaluation_failed = True
    if evaluation_failed:
        sys.exit("evaluate_command failed; see the output above.")
# Script entry point: run the CLI only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()