forked from thunlp/FewRel
-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy path test_demo.py
22 lines (18 loc) · 966 Bytes
/
test_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import models
from fewshot_re_kit.data_loader import JSONFileDataLoader
from fewshot_re_kit.framework import FewShotREFramework
from fewshot_re_kit.sentence_encoder import CNNSentenceEncoder
from models.proto import Proto
from models.snowball import Snowball
# Evaluate a Prototypical Network ("Proto") checkpoint with FewShotREFramework
# and print the accuracy averaged over several evaluation runs.

# Length limit shared by every data loader and the CNN sentence encoder below;
# both must agree, so it is defined once.
max_length = 40
glove_path = './data/glove.6B.50d.json'

train_data_loader = JSONFileDataLoader('./data/train.json', glove_path, max_length=max_length)
val_data_loader = JSONFileDataLoader('./data/val.json', glove_path, max_length=max_length)
test_data_loader = JSONFileDataLoader('./data/test.json', glove_path, max_length=max_length)
framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)

# The encoder is built from the training loader's word-vector matrix.
# NOTE(review): presumably the checkpoint below was trained with this same
# vocabulary and max_length -- confirm before changing either.
sentence_encoder = CNNSentenceEncoder(train_data_loader.word_vec_mat, max_length)
model = Proto(sentence_encoder).cuda()  # .cuda() requires a CUDA-capable device

# Positional arguments forwarded to framework.eval, values unchanged.
# NOTE(review): assumed to be (B, N, K, Q, eval_iter) as in
# fewshot_re_kit.framework.FewShotREFramework.eval -- confirm before editing.
batch_size = 4
n_way = 5
k_shot = 5
q_query = 100
eval_iter = 3000
ckpt_path = 'checkpoint/proto.pth.tar'

# Number of repeated evaluations to average. It is used both as the loop bound
# and as the divisor, so the two can never drift apart.
num_runs = 5

acc = 0
for _ in range(num_runs):
    acc += framework.eval(model, batch_size, n_way, k_shot, q_query, eval_iter, ckpt=ckpt_path)
# float() keeps this a true mean even under Python 2 integer division.
acc /= float(num_runs)
print("ACC: {}".format(acc))