-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgrid.py
43 lines (39 loc) · 1.12 KB
/
grid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from grpc_client import main
import numpy as np
import pickle
import argparse
def search(tsv, pkl, threshold=None):
    """Grid-search the decision threshold that maximizes accuracy.

    Each line of ``tsv`` is stripped and split on tabs into a source field
    followed by zero or more target fields; the line's gold label is 1 when
    the source is identical to at least one target and 0 otherwise.  ``pkl``
    must unpickle to a sequence of mappings, one per TSV line, each holding
    a score under the ``"tf_logit"`` key.  An example is predicted positive
    when its score is strictly greater than the threshold.

    Thresholds ``0 * 0.001, 1 * 0.001, ..., 1000 * 0.001`` are tried in
    increasing order and the first one that reaches the highest accuracy is
    kept.

    Args:
        tsv: Path to the tab-separated file.
        pkl: Path to the pickle file holding the scores.
        threshold: Accepted for interface compatibility; currently unused.

    Returns:
        A ``(best_threshold, best_accuracy)`` tuple.  Both are 0 when no
        threshold on the grid achieves an accuracy above zero.

    Raises:
        ValueError: If the TSV file has no lines, or the scores do not form
            a flat sequence with exactly one entry per TSV line.
    """
    with open(tsv) as f:
        labels = []
        for line in f:
            source, *targets = line.strip().split("\t")
            labels.append(1 if source in targets else 0)
    labels = np.array(labels)

    # NOTE: pickle.load can execute arbitrary code while deserializing;
    # only point this at score files from a trusted source.
    with open(pkl, "rb") as f:
        entries = pickle.load(f)
    scores = np.array([entry["tf_logit"] for entry in entries])

    # Explicit checks instead of `assert`, which is stripped under `-O`.
    if labels.size == 0:
        raise ValueError(f"no examples found in {tsv!r}")
    if scores.shape != labels.shape:
        raise ValueError(
            f"expected {labels.size} scalar scores (one per line of {tsv!r}), "
            f"got scores of shape {scores.shape} from {pkl!r}"
        )

    total = labels.size
    acc = 0
    best = 0
    for i in range(1001):
        thres = i * 0.001
        predictions = scores > thres
        # Count matches in native code rather than with Python's `sum`.
        correct = int(np.count_nonzero(predictions == labels))
        a = correct / total
        # Strict comparison keeps the lowest threshold among ties.
        if a > acc:
            acc = a
            best = thres
    return best, acc
if __name__ == "__main__":
    # Command-line entry point: run the threshold grid search on one
    # validation set and print the file name, the best threshold and the
    # accuracy reached at that threshold.
    parser = argparse.ArgumentParser(
        description="Grid-search the accuracy-maximizing decision threshold."
    )
    parser.add_argument(
        "--valid_file",
        type=str,
        required=True,
        help="tab-separated validation file (source, then targets, per line)",
    )
    parser.add_argument(
        "--valid_pkl",
        type=str,
        required=True,
        help='pickle file with one {"tf_logit": score} entry per line',
    )
    args = parser.parse_args()
    thres, acc = search(args.valid_file, args.valid_pkl)
    print(args.valid_file)
    print("Threshold:", thres)
    # Label was previously misspelled as "Accuarcy:".
    print("Accuracy:", acc)