-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathkge_model.py
executable file
·115 lines (91 loc) · 3.59 KB
/
kge_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import torch.nn as nn
class KGEModel(nn.Module):
    """Knowledge-graph embedding scorer supporting TransE, DistMult,
    ComplEx, and RotatE.

    The module holds no embedding tables itself: the entity and relation
    embedding matrices are supplied by the caller on every ``forward``
    call, and this class only gathers the relevant rows and applies the
    chosen scoring function.
    """

    def __init__(self, args, model_name):
        """
        Args:
            args: namespace with attributes ``gamma`` (margin),
                ``epsilon``, and ``hidden_dim``.
            model_name: one of ``'TransE'``, ``'DistMult'``,
                ``'ComplEx'``, ``'RotatE'``.

        Raises:
            ValueError: if ``model_name`` is not a supported scorer
                (fail fast here instead of a late KeyError in forward).
        """
        super().__init__()
        if model_name not in ('TransE', 'DistMult', 'ComplEx', 'RotatE'):
            raise ValueError('model %s not supported' % model_name)
        self.model_name = model_name
        # Scale factor mapping raw relation embeddings to phases in
        # [-pi, pi] for RotatE (see RotatE scoring below).
        self.embedding_range = torch.Tensor(
            [(args.gamma + args.epsilon) / args.hidden_dim]
        )
        # Margin registered as a frozen nn.Parameter so it follows the
        # module through .to(device) / state_dict without being trained.
        self.gamma = nn.Parameter(
            torch.Tensor([args.gamma]),
            requires_grad=False,
        )

    def forward(self, sample, relation_embedding, entity_embedding, neg=True):
        """Score (head, relation, tail) triples.

        Args:
            sample: when ``neg`` is False, a LongTensor of shape
                (batch, 3) of (head, relation, tail) index triples.
                When ``neg`` is True, a pair ``(head_part, tail_part)``:
                ``head_part`` is (batch, >=2) with head indices in
                column 0 and relation indices in column 1;
                ``tail_part`` is either None — score each triple against
                every entity — or a (batch, num_neg) LongTensor of
                negative tail indices.
            relation_embedding: (num_relations, dim) embedding matrix.
            entity_embedding: (num_entities, dim) embedding matrix.
            neg: selects positive-triple vs. negative-sampling scoring.

        Returns:
            Score tensor: (batch, 1) for positive triples,
            (batch, num_neg) for sampled negatives, or
            (batch, num_entities) when ``tail_part`` is None.
        """
        if not neg:
            # Positive triples: one (h, r, t) per row; keep a singleton
            # middle dim so the scorers broadcast uniformly.
            head = torch.index_select(
                entity_embedding, dim=0, index=sample[:, 0]
            ).unsqueeze(1)
            relation = torch.index_select(
                relation_embedding, dim=0, index=sample[:, 1]
            ).unsqueeze(1)
            tail = torch.index_select(
                entity_embedding, dim=0, index=sample[:, 2]
            ).unsqueeze(1)
        else:
            head_part, tail_part = sample
            batch_size = head_part.shape[0]
            head = torch.index_select(
                entity_embedding, dim=0, index=head_part[:, 0]
            ).unsqueeze(1)
            relation = torch.index_select(
                relation_embedding, dim=0, index=head_part[:, 1]
            ).unsqueeze(1)
            # NOTE: must be an identity check -- `tail_part == None` on a
            # Tensor compares elementwise and raises in boolean context.
            if tail_part is None:
                # Rank against the full entity set via broadcasting:
                # tail has shape (1, num_entities, dim).
                tail = entity_embedding.unsqueeze(0)
            else:
                negative_sample_size = tail_part.size(1)
                tail = torch.index_select(
                    entity_embedding, dim=0, index=tail_part.view(-1)
                ).view(batch_size, negative_sample_size, -1)

        model_func = {
            'TransE': self.TransE,
            'DistMult': self.DistMult,
            'ComplEx': self.ComplEx,
            'RotatE': self.RotatE,
        }
        return model_func[self.model_name](head, relation, tail)

    def TransE(self, head, relation, tail):
        """Margin minus L1 distance of (h + r) from t."""
        score = (head + relation) - tail
        score = self.gamma.item() - torch.norm(score, p=1, dim=2)
        return score

    def DistMult(self, head, relation, tail):
        """Trilinear dot product <h, r, t> summed over the embedding dim."""
        score = (head * relation) * tail
        score = score.sum(dim=2)
        return score

    def ComplEx(self, head, relation, tail):
        """Real part of the trilinear product in complex space.

        Embeddings are interpreted as complex vectors with the real and
        imaginary halves concatenated along the last dimension.
        """
        re_head, im_head = torch.chunk(head, 2, dim=2)
        re_relation, im_relation = torch.chunk(relation, 2, dim=2)
        re_tail, im_tail = torch.chunk(tail, 2, dim=2)
        re_score = re_head * re_relation - im_head * im_relation
        im_score = re_head * im_relation + im_head * re_relation
        score = re_score * re_tail + im_score * im_tail
        score = score.sum(dim=2)
        return score

    def RotatE(self, head, relation, tail):
        """Margin minus distance after rotating h by r in complex space."""
        pi = 3.14159265358979323846
        re_head, im_head = torch.chunk(head, 2, dim=2)
        re_tail, im_tail = torch.chunk(tail, 2, dim=2)
        # Make phases of relations uniformly distributed in [-pi, pi].
        phase_relation = relation / (self.embedding_range.item() / pi)
        re_relation = torch.cos(phase_relation)
        im_relation = torch.sin(phase_relation)
        # Complex multiplication h * r, then residual against t.
        re_score = re_head * re_relation - im_head * im_relation
        im_score = re_head * im_relation + im_head * re_relation
        re_score = re_score - re_tail
        im_score = im_score - im_tail
        # Complex modulus of the residual, summed over the embedding dim.
        score = torch.stack([re_score, im_score], dim=0)
        score = score.norm(dim=0)
        score = self.gamma.item() - score.sum(dim=2)
        return score