evaluation.py
import argparse
import json
import os
from collections import defaultdict

import wandb

from analysis.f1_score import compute_score
from utils.data_handling import *


def evaluate(
data,
predictions,
prediction_mode,
aliases,
num_aliases=-1,
filter_mutable_with_one_ans=False,
):
    """Yield (relation, per-example results DataFrame, scores) per relation.

    F1 is computed as the max over any alias of any gold answer, for the
    most recent, most frequent, or specific-year answer (depending on
    `prediction_mode`). An aggregate "all" relation covers every query.
    """
qa_targets, qa_predictions = defaultdict(list), defaultdict(list)
num_empty = 0
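    # Build SQuAD-style target/prediction records, grouped per relation and
    # also under the aggregate "all" key.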
for query_id, query in data.items():
relation = query["relation"]
target = list()
if (
query["type"] == "mutable"
and filter_mutable_with_one_ans
and len(query["answer"]) == 1
):
continue
for answer in query["answer"]:
if answer["wikidata_id"] in aliases:
answer_aliases = aliases[answer["wikidata_id"]]
if num_aliases > -1:
answer_aliases = answer_aliases[:num_aliases]
target += answer_aliases
target.append(answer["name"])
        if not target:  # no gold answer names or aliases found
            continue
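        # get_prediction (imported via utils.data_handling) is assumed to pick
        # the candidate answer ranked best under `prediction_mode`.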
prediction = get_prediction(predictions, query_id, prediction_mode)
        if not prediction["answer"]:
num_empty += 1
# print("Warning: the prediction for query='{}' was empty.".format(query))
continue
qa_targets[relation].append(
{
"answers": {"answer_start": [0] * len(target), "text": target},
"id": query_id,
}
)
qa_targets["all"].append(
{
"answers": {"answer_start": [0] * len(target), "text": target},
"id": query_id,
}
)
qa_predictions[relation].append(
{"prediction_text": prediction["answer"], "id": query_id}
)
qa_predictions["all"].append(
{"prediction_text": prediction["answer"], "id": query_id}
)
print("Evaluating on {} datapoints".format(len(qa_targets["all"])))
print("Num empty", num_empty)
    for rel in qa_targets:
df, scores = compute_score(
predictions=qa_predictions[rel], references=qa_targets[rel]
)
yield rel, df, {"n_datapoints": len(qa_targets["all"]), **scores}


def load_queries(data_path):
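    """Return one query per base id, skipping queries with no answers.

    The first two "_"-separated fields of a query id are assumed to identify
    the base query.
    """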
unique_queries = dict()
queries = load_dataset(data_path, split="train")
for query in queries:
query_id = "_".join(query["id"].split("_")[:2])
        if query_id not in unique_queries and query["answer"]:
unique_queries[query_id] = query
return unique_queries


def load_aliases(data_path):
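    """Return a {wikidata_id: [aliases]} mapping.

    The alias dataset is assumed to store the full mapping in its first row.
    """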
all_aliases = dict()
aliases = load_dataset(data_path, split="train")
for qid, al in aliases[0].items():
all_aliases[qid] = al
return all_aliases


def main(args):
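    """Run the evaluation end to end, logging metrics to wandb and to disk."""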
experiment_dir = os.path.join(args.output_dir, args.exp_name)
    os.makedirs(experiment_dir, exist_ok=True)
project_name = "lm_mutability_preds_eval"
wandb.init(
project=project_name,
name="(eval) " + args.exp_name,
config=args,
)
data = load_queries(args.data_path)
aliases = load_aliases(args.aliases_path)
predictions = load_predictions(args.predictions_path)
    with open(os.path.join(experiment_dir, "metrics.jsonl"), "w") as fhandle:
for rel, df, scores in evaluate(
data,
predictions,
args.prediction_mode,
aliases,
num_aliases=args.num_aliases,
filter_mutable_with_one_ans=args.filter_mutable_with_one_ans,
):
df.to_json(os.path.join(experiment_dir, f"{rel}_results_per_example.json"))
wandb.log({k: v for k, v in scores.items() if not isinstance(v, list)})
print(f"{rel}: ", scores["ave_f1"])
            # Use a fresh name here: rebinding `data` would shadow the queries
            # dict loaded above.
            row = {rel: scores["ave_f1"]}
            fhandle.write("{}\n".format(json.dumps(row)))


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Evaluation")
parser.add_argument(
"--data_path",
type=str,
default="coastalcph/fm_queries",
help="Path to data",
)
parser.add_argument(
"--aliases_path",
type=str,
default="coastalcph/fm_aliases",
help="Path to data",
)
parser.add_argument("--predictions_path", type=str, help="Path to predictions")
parser.add_argument(
"--prediction_mode",
type=str,
default="first_token_probability",
choices=["perplexity", "first_token_probability"],
help="Which prediction do we evaluate",
)
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Dir where model outputs will be stored",
)
parser.add_argument(
"--num_aliases",
type=int,
default=-1,
help="Num aliases to use",
)
parser.add_argument("--filter_mutable_with_one_ans", action="store_true")
parser.add_argument("--exp_name", type=str, default="debug", help="Experiment name")
args = parser.parse_args()
main(args)
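
# Example invocation (paths and the experiment name are illustrative):
#   python evaluation.py \
#       --predictions_path output/my_run/predictions.json \
#       --prediction_mode first_token_probability \
#       --exp_name my_run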