Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test data is not correctly loaded #4

Open
t0024 opened this issue Jan 9, 2018 · 4 comments
Open

Test data is not correctly loaded #4

t0024 opened this issue Jan 9, 2018 · 4 comments

Comments

@t0024
Copy link

t0024 commented Jan 9, 2018

if data[0]+"\t"+data[1] not in self.bags_test:

The test data should be collected for each entity pair in the map self.bags_test (map<str, list>); however, the keys are built in a different way when a key is checked (L96). This bug affects test.py and the test scores.

CURRENT : if entity1+"\t"+entity2 not in self.bags_test:
SUGGESTED : if entity1+" "+entity2 not in self.bags_test:

@ljxalpha
Copy link

ljxalpha commented May 1, 2018

@t0024 I also found this problem, and it causes some problems in the file test.py. My email is [email protected]; would you please discuss these problems with me? Thank you!

@t0024
Copy link
Author

t0024 commented May 1, 2018

Indeed, please take a look at the following code.

def test_NA(testing_data, input_x, input_p1, input_p2, s, p, dropout_keep_prob, datamanager, sess, num_epoch, save_dir="tmp_log"):
    """Score every test bag, then print positive-class accuracy and a P/R table.

    Each key of ``testing_data`` maps to a bag (a sequence of sentences). A bag
    is evaluated only when all of its sentences share the same ``relation.id``.
    For every sentence the model scores are normalised with a softmax, class 0
    (NA) is masked out, and the most confident remaining class over the whole
    bag becomes the bag prediction.

    Args:
        testing_data: mapping from a bag key to a sequence of sentence objects
            exposing ``relation.id``.
        input_x, input_p1, input_p2, dropout_keep_prob: feed-dict keys handed
            to ``sess.run``.
        s, p: fetches handed to ``sess.run``; only the result of ``s`` (one
            score row per sentence) is used.
        datamanager: object providing ``generate_x(bag)`` and
            ``generate_p(bag)``.
        sess: object providing ``run(fetches, feed_dict)``.
        num_epoch: unused; kept so existing callers keep working.
        save_dir: unused; kept so existing callers keep working.

    Returns:
        None. All results are written to stdout. Accuracy and recall are
        reported as 0.0 when there is no bag with a positive gold relation.
    """
    report_ranks = (1, 5, 10, 20, 50, 100, 200)
    relation_instances = []  # [bag_key, pred_relation_type, score, gold_relation_type]

    for test in testing_data:
        bag = testing_data[test]
        if len(bag) == 0:
            # An empty bag has no gold label and nothing to score.
            continue

        gold_id = bag[0].relation.id
        consistent = True
        for snt in bag:
            if snt.relation.id != gold_id:
                # Single pre-formatted argument: prints the same text on
                # Python 2 and Python 3.
                print("invalid bag : %s, %s" % (gold_id, snt.relation.id))
                consistent = False
                break
        # Bags with a negative gold id are skipped as well, as before.
        if not consistent or gold_id < 0:
            continue

        x_test = datamanager.generate_x(bag)
        p1, p2 = datamanager.generate_p(bag)
        scores, _ = sess.run([s, p], {input_x: x_test, input_p1: p1, input_p2: p2, dropout_keep_prob: 1.0})

        max_pro = 0.0
        prediction = -1
        for score in scores:
            logits = np.asarray(score)
            # Subtracting the max leaves the softmax unchanged mathematically
            # but keeps np.exp from overflowing to inf (and then nan).
            probs = np.exp(logits - np.max(logits))
            probs = probs / probs.sum(axis=0)
            probs[0] = 0  # in NYT there is not NA prediction
            pred = np.argmax(probs)
            pro = probs[pred]
            if pro > max_pro:
                max_pro = pro
                prediction = pred

        relation_instances.append([test, prediction, max_pro, gold_id])

    # Ascending stable sort followed by reverse() (rather than reverse=True)
    # so that bags with equal scores keep the ordering they always had.
    relation_instances.sort(key=lambda inst: inst[2])
    relation_instances.reverse()

    acc_pos = 0
    pos_total = 0
    neg_size = 0
    for r in relation_instances:
        if r[3] != 0:
            pos_total += 1
            if r[3] == r[1]:
                acc_pos += 1
        else:
            neg_size += 1

    # accuracy
    accuracy = acc_pos / float(pos_total) if pos_total else 0.0
    print("neg_gold_relation :  %s" % neg_size)
    print("pos_gold_relation :  %s" % pos_total)
    print("correct_pos_pred :  %s" % acc_pos)
    print("Accuracy(over_pos_relation) :  %s" % accuracy)

    # P-R curve
    print("Top_N\tPrecision\tRecall")
    correct = 0
    for n, r in enumerate(relation_instances, 1):
        if r[3] == r[1]:
            correct += 1

        if n in report_ranks or n % 1000 == 0:
            recall = correct / float(pos_total) if pos_total else 0.0
            print("%d\t%f\t%f" % (n, correct / float(n), recall))


# Despite the ``test_`` prefix this is an evaluation routine, not a unit test;
# opt out of pytest collection wherever the function gets imported.
test_NA.__test__ = False

@13Lingzi
Copy link

13Lingzi commented Nov 6, 2018

Indeed, please take a look at the following code.

def test_NA(testing_data, input_x, input_p1, input_p2, s, p, dropout_keep_prob, datamanager, sess, num_epoch, save_dir="tmp_log"):
    """Score every test bag, then print positive-class accuracy and a P/R table.

    Each key of ``testing_data`` maps to a bag (a sequence of sentences). A bag
    is evaluated only when all of its sentences share the same ``relation.id``.
    For every sentence the model scores are normalised with a softmax, class 0
    (NA) is masked out, and the most confident remaining class over the whole
    bag becomes the bag prediction.

    Args:
        testing_data: mapping from a bag key to a sequence of sentence objects
            exposing ``relation.id``.
        input_x, input_p1, input_p2, dropout_keep_prob: feed-dict keys handed
            to ``sess.run``.
        s, p: fetches handed to ``sess.run``; only the result of ``s`` (one
            score row per sentence) is used.
        datamanager: object providing ``generate_x(bag)`` and
            ``generate_p(bag)``.
        sess: object providing ``run(fetches, feed_dict)``.
        num_epoch: unused; kept so existing callers keep working.
        save_dir: unused; kept so existing callers keep working.

    Returns:
        None. All results are written to stdout. Accuracy and recall are
        reported as 0.0 when there is no bag with a positive gold relation.
    """
    report_ranks = (1, 5, 10, 20, 50, 100, 200)
    relation_instances = []  # [bag_key, pred_relation_type, score, gold_relation_type]

    for test in testing_data:
        bag = testing_data[test]
        if len(bag) == 0:
            # An empty bag has no gold label and nothing to score.
            continue

        gold_id = bag[0].relation.id
        consistent = True
        for snt in bag:
            if snt.relation.id != gold_id:
                # Single pre-formatted argument: prints the same text on
                # Python 2 and Python 3.
                print("invalid bag : %s, %s" % (gold_id, snt.relation.id))
                consistent = False
                break
        # Bags with a negative gold id are skipped as well, as before.
        if not consistent or gold_id < 0:
            continue

        x_test = datamanager.generate_x(bag)
        p1, p2 = datamanager.generate_p(bag)
        scores, _ = sess.run([s, p], {input_x: x_test, input_p1: p1, input_p2: p2, dropout_keep_prob: 1.0})

        max_pro = 0.0
        prediction = -1
        for score in scores:
            logits = np.asarray(score)
            # Subtracting the max leaves the softmax unchanged mathematically
            # but keeps np.exp from overflowing to inf (and then nan).
            probs = np.exp(logits - np.max(logits))
            probs = probs / probs.sum(axis=0)
            probs[0] = 0  # in NYT there is not NA prediction
            pred = np.argmax(probs)
            pro = probs[pred]
            if pro > max_pro:
                max_pro = pro
                prediction = pred

        relation_instances.append([test, prediction, max_pro, gold_id])

    # Ascending stable sort followed by reverse() (rather than reverse=True)
    # so that bags with equal scores keep the ordering they always had.
    relation_instances.sort(key=lambda inst: inst[2])
    relation_instances.reverse()

    acc_pos = 0
    pos_total = 0
    neg_size = 0
    for r in relation_instances:
        if r[3] != 0:
            pos_total += 1
            if r[3] == r[1]:
                acc_pos += 1
        else:
            neg_size += 1

    # accuracy
    accuracy = acc_pos / float(pos_total) if pos_total else 0.0
    print("neg_gold_relation :  %s" % neg_size)
    print("pos_gold_relation :  %s" % pos_total)
    print("correct_pos_pred :  %s" % acc_pos)
    print("Accuracy(over_pos_relation) :  %s" % accuracy)

    # P-R curve
    print("Top_N\tPrecision\tRecall")
    correct = 0
    for n, r in enumerate(relation_instances, 1):
        if r[3] == r[1]:
            correct += 1

        if n in report_ranks or n % 1000 == 0:
            recall = correct / float(pos_total) if pos_total else 0.0
            print("%d\t%f\t%f" % (n, correct / float(n), recall))


# Despite the ``test_`` prefix this is an evaluation routine, not a unit test;
# opt out of pytest collection wherever the function gets imported.
test_NA.__test__ = False

Why do you set score[0] = 0? If the label is NA, then based on your demo the prediction will never be NA, so how do you calculate false positives (FP)?

@13Lingzi
Copy link

13Lingzi commented Nov 6, 2018

Indeed, please take a look at the following code.

def test_NA(testing_data, input_x, input_p1, input_p2, s, p, dropout_keep_prob, datamanager, sess, num_epoch, save_dir="tmp_log"):
    """Score every test bag, then print positive-class accuracy and a P/R table.

    Each key of ``testing_data`` maps to a bag (a sequence of sentences). A bag
    is evaluated only when all of its sentences share the same ``relation.id``.
    For every sentence the model scores are normalised with a softmax, class 0
    (NA) is masked out, and the most confident remaining class over the whole
    bag becomes the bag prediction.

    Args:
        testing_data: mapping from a bag key to a sequence of sentence objects
            exposing ``relation.id``.
        input_x, input_p1, input_p2, dropout_keep_prob: feed-dict keys handed
            to ``sess.run``.
        s, p: fetches handed to ``sess.run``; only the result of ``s`` (one
            score row per sentence) is used.
        datamanager: object providing ``generate_x(bag)`` and
            ``generate_p(bag)``.
        sess: object providing ``run(fetches, feed_dict)``.
        num_epoch: unused; kept so existing callers keep working.
        save_dir: unused; kept so existing callers keep working.

    Returns:
        None. All results are written to stdout. Accuracy and recall are
        reported as 0.0 when there is no bag with a positive gold relation.
    """
    report_ranks = (1, 5, 10, 20, 50, 100, 200)
    relation_instances = []  # [bag_key, pred_relation_type, score, gold_relation_type]

    for test in testing_data:
        bag = testing_data[test]
        if len(bag) == 0:
            # An empty bag has no gold label and nothing to score.
            continue

        gold_id = bag[0].relation.id
        consistent = True
        for snt in bag:
            if snt.relation.id != gold_id:
                # Single pre-formatted argument: prints the same text on
                # Python 2 and Python 3.
                print("invalid bag : %s, %s" % (gold_id, snt.relation.id))
                consistent = False
                break
        # Bags with a negative gold id are skipped as well, as before.
        if not consistent or gold_id < 0:
            continue

        x_test = datamanager.generate_x(bag)
        p1, p2 = datamanager.generate_p(bag)
        scores, _ = sess.run([s, p], {input_x: x_test, input_p1: p1, input_p2: p2, dropout_keep_prob: 1.0})

        max_pro = 0.0
        prediction = -1
        for score in scores:
            logits = np.asarray(score)
            # Subtracting the max leaves the softmax unchanged mathematically
            # but keeps np.exp from overflowing to inf (and then nan).
            probs = np.exp(logits - np.max(logits))
            probs = probs / probs.sum(axis=0)
            probs[0] = 0  # in NYT there is not NA prediction
            pred = np.argmax(probs)
            pro = probs[pred]
            if pro > max_pro:
                max_pro = pro
                prediction = pred

        relation_instances.append([test, prediction, max_pro, gold_id])

    # Ascending stable sort followed by reverse() (rather than reverse=True)
    # so that bags with equal scores keep the ordering they always had.
    relation_instances.sort(key=lambda inst: inst[2])
    relation_instances.reverse()

    acc_pos = 0
    pos_total = 0
    neg_size = 0
    for r in relation_instances:
        if r[3] != 0:
            pos_total += 1
            if r[3] == r[1]:
                acc_pos += 1
        else:
            neg_size += 1

    # accuracy
    accuracy = acc_pos / float(pos_total) if pos_total else 0.0
    print("neg_gold_relation :  %s" % neg_size)
    print("pos_gold_relation :  %s" % pos_total)
    print("correct_pos_pred :  %s" % acc_pos)
    print("Accuracy(over_pos_relation) :  %s" % accuracy)

    # P-R curve
    print("Top_N\tPrecision\tRecall")
    correct = 0
    for n, r in enumerate(relation_instances, 1):
        if r[3] == r[1]:
            correct += 1

        if n in report_ranks or n % 1000 == 0:
            recall = correct / float(pos_total) if pos_total else 0.0
            print("%d\t%f\t%f" % (n, correct / float(n), recall))


# Despite the ``test_`` prefix this is an evaluation routine, not a unit test;
# opt out of pytest collection wherever the function gets imported.
test_NA.__test__ = False

Why do you set score[0] = 0? If the label is NA, then based on your demo the prediction will never be NA, so how do you calculate false positives (FP)?

@t0024

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants