Test data is not correctly loaded #4
Comments
@t0024 I also ran into this problem, and it causes further problems in test.py. My email is [email protected]; could you please discuss these problems with me? Thank you!
Indeed, please take a look at the following code:

def test_NA(testing_data, input_x, input_p1, input_p2, s, p, dropout_keep_prob, datamanager, sess, num_epoch, save_dir="tmp_log"):
    relation_instances = [] #list of [ent_pair, pred_relation_type, score, gold_relation_type]
    for test in testing_data:
        gold_id = testing_data[test][0].relation.id
        #check
        for snt in testing_data[test]:
            if snt.relation.id!=gold_id:
                print("invalid bag : ", gold_id, ", ", snt.relation.id)
                gold_id = -1
                break
        if gold_id<0: continue
        x_test = datamanager.generate_x(testing_data[test])
        p1, p2 = datamanager.generate_p(testing_data[test])
        y_test = datamanager.generate_y(testing_data[test])
        scores, predictions = sess.run([s, p], {input_x: x_test, input_p1:p1, input_p2:p2, dropout_keep_prob: 1.0})
        max_pro = 0.0
        prediction = -1
        for score in scores:
            score = np.exp(score)
            score = score/score.sum(axis=0)
            score[0] = 0 #in NYT there is not NA prediction
            pred = np.argmax(score)
            pro = score[pred]
            if pro > max_pro:
                max_pro = pro
                prediction =pred
        relation_instances+=[[test, prediction, max_pro, gold_id]]
    relation_instances = sorted(relation_instances, key=lambda x:x[2])
    relation_instances.reverse() #relation instances sorted by scores
    acc_pos= 0
    pos_total=0
    neg_size=0
    for r in relation_instances:
        if r[3]!=0:
            pos_total+=1
            if r[3]==r[1]: acc_pos+=1
        else:
            neg_size+=1
    #accuracy
    print "neg_gold_relation : ", neg_size
    print "pos_gold_relation : ", pos_total
    print "correct_pos_pred : ", acc_pos
    print "Accuracy(over_pos_relation) : ", acc_pos/float(pos_total)
    #R-C curver
    print("Top_N\tPrecision\tRecall")
    c = 0
    for n,r in enumerate(relation_instances):
        if r[3]==r[1]: c+=1
        if n+1 in [1, 5, 10, 20, 50, 100, 200] or (n+1)%1000==0:
            print "%d\t%f\t%f"%(n+1, c/float(n+1), c/float(pos_total))
Why do you set score[0] = 0? If the gold label is NA, then with this code the prediction can never be NA, so how do you calculate false positives (FP)?
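To make the question concrete, here is a minimal pure-Python sketch of the masking step from the quoted code. The function name `predict_with_na_masked` and the toy logits are made up for illustration; only the `score[0] = 0` step is taken from the code above. It shows that once class 0 (NA) is zeroed after the softmax, the argmax cannot return NA, even when NA has the highest score.

```python
import math

def predict_with_na_masked(logits):
    """Softmax over the logits, then zero class 0 (NA) before the argmax.

    Hypothetical helper; it mirrors the `score[0] = 0` step of the quoted
    test_NA code but is not taken from the repository.
    """
    m = max(logits)                               # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    probs[0] = 0.0                                # NA can no longer win the argmax
    pred = max(range(len(probs)), key=probs.__getitem__)
    return pred, probs[pred]

# Toy logits in which NA (index 0) is clearly the most likely class.
pred, prob = predict_with_na_masked([5.0, 1.0, 0.5])
print(pred, prob)                                 # pred is a non-NA class
```

If the quoted code is read this way, a bag whose gold label is NA still enters the ranked list with some non-NA prediction, so it can never count as correct in the Top_N precision column, although it is counted in the denominator `n+1`.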
ResCNN_RelationExtraction/ResidualCNN9/util/DataManager.py
Line 96 in 5449aba
The test data should be collected per entity pair in the map self.bags_test (map<str, list>); however, the key is built differently when its presence is checked (L96). This bug affects test.py and the test scores.
CURRENT : if entity1+"\t"+entity2 not in self.bags_test:
SUGGESTED : if entity1+" "+entity2 not in self.bags_test:
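The effect of such a separator mismatch can be sketched as follows. The surrounding lines of DataManager.py are not shown in this issue, so the sketch assumes that a bag is (re)initialised whenever the membership check fails and that keys are stored with a space; `collect_bags`, `check_sep`, `store_sep`, and the toy rows are hypothetical names used only for illustration.

```python
def collect_bags(rows, check_sep, store_sep=" "):
    """Group sentences into bags keyed by entity pair.

    Hypothetical stand-in for the bag-building loop: `check_sep` is the
    separator used in the membership test, `store_sep` the one used when
    the key is actually written.
    """
    bags = {}
    for entity1, entity2, sentence in rows:
        if entity1 + check_sep + entity2 not in bags:
            bags[entity1 + store_sep + entity2] = []   # assumed reset on a failed check
        bags[entity1 + store_sep + entity2].append(sentence)
    return bags

rows = [
    ("obama", "hawaii", "s1"),
    ("obama", "hawaii", "s2"),
    ("jobs", "apple", "s3"),
]

buggy = collect_bags(rows, check_sep="\t")   # check never matches the stored key
fixed = collect_bags(rows, check_sep=" ")    # check and store agree

print(buggy["obama hawaii"])   # ['s2']
print(fixed["obama hawaii"])   # ['s1', 's2']
```

Under these assumptions the mismatched check never finds an existing key, so each bag keeps only the last sentence of its entity pair, which would explain why the test scores are affected.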