-
Notifications
You must be signed in to change notification settings - Fork 59
/
Copy pathload_data.py
30 lines (26 loc) · 1.42 KB
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Data:
    """Load a knowledge-graph dataset made of train/valid/test triple files.

    Each split is read from ``<data_dir>/<split>.txt``. Every non-blank line
    is split on whitespace and stored as a list whose first three items are
    used as ``[head, relation, tail]``.

    Attributes:
        train_data, valid_data, test_data: triples of each split.
        data: concatenation of the three splits (train, valid, test order).
        entities: sorted list of every head/tail entity across all splits.
        train_relations, valid_relations, test_relations: sorted list of the
            relations that occur in each split.
        relations: the train relations, followed by relations first seen in
            the valid split, then relations seen only in the test split.
            Every relation appears exactly once.
    """

    def __init__(self, data_dir="data/FB15k-237/", reverse=False):
        """Read the three splits from ``data_dir`` and build vocabularies.

        Args:
            data_dir: directory containing ``train.txt``, ``valid.txt`` and
                ``test.txt``; a trailing path separator is optional.
            reverse: if true, each split is augmented with inverse triples
                ``[tail, relation + "_reverse", head]``.
        """
        self.train_data = self.load_data(data_dir, "train", reverse=reverse)
        self.valid_data = self.load_data(data_dir, "valid", reverse=reverse)
        self.test_data = self.load_data(data_dir, "test", reverse=reverse)
        self.data = self.train_data + self.valid_data + self.test_data
        self.entities = self.get_entities(self.data)
        self.train_relations = self.get_relations(self.train_data)
        self.valid_relations = self.get_relations(self.valid_data)
        self.test_relations = self.get_relations(self.test_data)
        # Train relations come first (in their sorted order), then any new
        # valid/test relations. Track everything already added so a relation
        # present in both valid and test (but not train) is listed only once.
        seen = set(self.train_relations)
        extra = []
        for rel in self.valid_relations + self.test_relations:
            if rel not in seen:
                seen.add(rel)
                extra.append(rel)
        self.relations = self.train_relations + extra

    def load_data(self, data_dir, data_type="train", reverse=False):
        """Read ``<data_dir>/<data_type>.txt`` into a list of split lines.

        Args:
            data_dir: directory holding the split file.
            data_type: split name, used as the file stem (e.g. ``"train"``).
            reverse: if true, append ``[t[2], t[1] + "_reverse", t[0]]`` for
                every loaded triple ``t``.

        Returns:
            List of whitespace-split lines; blank lines are skipped.

        Raises:
            OSError: if the file cannot be opened.
        """
        import os

        path = os.path.join(data_dir, "%s.txt" % data_type)
        with open(path, "r", encoding="utf-8") as f:
            # Skip blank lines: they would otherwise become empty triples and
            # break the index-based accesses below and in the vocab builders.
            data = [line.split() for line in f if line.strip()]
        if reverse:
            # The comprehension is evaluated before ``+=`` extends ``data``,
            # so only the original triples are inverted.
            data += [[t[2], t[1] + "_reverse", t[0]] for t in data]
        return data

    def get_relations(self, data):
        """Return the sorted unique relations (item 1) of ``data``'s triples."""
        return sorted({triple[1] for triple in data})

    def get_entities(self, data):
        """Return the sorted unique heads and tails (items 0 and 2) of ``data``."""
        return sorted({ent for triple in data for ent in (triple[0], triple[2])})