# myGNN.py
import torch
import torch.nn as nn
import dgl
import numpy as np
from dgl.nn.pytorch import GraphConv
import dgl.function as fn
from dgl.nn.pytorch import edge_softmax, GATConv
from myconv import myGATConv


class DistMult(nn.Module):
    """Relation-specific bilinear decoder: score = left^T W[r] right."""
    def __init__(self, num_rel, dim):
        super(DistMult, self).__init__()
        # One (dim, dim) relation matrix per relation type, stored as a single tensor.
        self.W = nn.Parameter(torch.FloatTensor(size=(num_rel, dim, dim)))
        nn.init.xavier_normal_(self.W, gain=1.414)

    def forward(self, left_emb, right_emb, r_id):
        thW = self.W[r_id]                          # (batch, dim, dim)
        left_emb = torch.unsqueeze(left_emb, 1)     # (batch, 1, dim)
        right_emb = torch.unsqueeze(right_emb, 2)   # (batch, dim, 1)
        return torch.bmm(torch.bmm(left_emb, thW), right_emb).squeeze()


class myDistMult(nn.Module):
    """Variant of DistMult that keeps a separate (dim, dim) parameter per relation."""
    def __init__(self, num_rel, dim):
        super(myDistMult, self).__init__()
        self.W = nn.ParameterList([nn.Parameter(torch.FloatTensor(size=(dim, dim))) for _ in range(num_rel)])
        for i in range(num_rel):
            nn.init.xavier_normal_(self.W[i], gain=1.414)

    def forward(self, left_emb, right_emb, r_id):
        # r_id is a numpy array of relation ids; gather the matching relation matrices.
        thW = torch.stack([self.W[rid] for rid in r_id.astype(np.int64)])
        left_emb = torch.unsqueeze(left_emb, 1)     # (batch, 1, dim)
        right_emb = torch.unsqueeze(right_emb, 2)   # (batch, dim, 1)
        return torch.bmm(torch.bmm(left_emb, thW), right_emb).squeeze()


class Dot(nn.Module):
    """Relation-agnostic dot-product decoder; r_id is accepted but ignored."""
    def __init__(self):
        super(Dot, self).__init__()

    def forward(self, left_emb, right_emb, r_id):
        left_emb = torch.unsqueeze(left_emb, 1)     # (batch, 1, dim)
        right_emb = torch.unsqueeze(right_emb, 2)   # (batch, dim, 1)
        return torch.bmm(left_emb, right_emb).squeeze()


class sepGAT(nn.Module):
    """Encoder built from stacked myGATConv layers, followed by a decoder that
    scores (left, right) node pairs under relation ids."""
    def __init__(self,
                 g,
                 edge_dim,
                 num_etypes,
                 num_ntypes,
                 in_dims,
                 num_hidden,
                 num_classes,
                 num_layers,
                 heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 negative_slope,
                 residual,
                 alpha,
                 decode='mydistmult'):
        super(sepGAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        # One linear projection per node type, mapping each input feature space to num_hidden.
        self.fc_list = nn.ModuleList([nn.Linear(in_dim, num_hidden, bias=True) for in_dim in in_dims])
        for fc in self.fc_list:
            nn.init.xavier_normal_(fc.weight, gain=1.414)
        # input projection (no residual)
        self.gat_layers.append(myGATConv(edge_dim, num_etypes, num_ntypes,
                                         in_dims, num_hidden, num_hidden, heads[0],
                                         feat_drop, attn_drop, negative_slope, False, self.activation, alpha=alpha))
        # hidden layers
        for l in range(1, num_layers):
            # due to multi-head attention, the input dimension is num_hidden * num_heads
            self.gat_layers.append(myGATConv(edge_dim, num_etypes, num_ntypes,
                                             in_dims, num_hidden * heads[l - 1], num_hidden, heads[l],
                                             feat_drop, attn_drop, negative_slope, residual, self.activation, alpha=alpha))
        # output projection
        self.gat_layers.append(myGATConv(edge_dim, num_etypes, num_ntypes,
                                         in_dims, num_hidden * heads[-2], num_classes, heads[-1],
                                         feat_drop, attn_drop, negative_slope, residual, None, alpha=alpha))
        self.epsilon = torch.FloatTensor([1e-12]).cuda()
        # self.epsilon = torch.FloatTensor([1e-12])  # CPU alternative
        # The decoder scores the concatenation of per-layer embeddings, whose dimension is
        # (num_layers + 1) * num_hidden + num_classes; the num_classes * (num_layers + 2)
        # used below matches when num_hidden == num_classes.
        if decode == 'distmult':
            self.decoder = DistMult(num_etypes, num_classes * (num_layers + 2))
        elif decode == 'dot':
            self.decoder = Dot()
        elif decode == 'mydistmult':
            self.decoder = myDistMult(num_etypes, num_classes * (num_layers + 2))
    def l2_norm(self, x):
        # Equivalent to tf.l2_normalize; see
        # https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/math/l2_normalize
        return x / torch.max(torch.norm(x, dim=1, keepdim=True), self.epsilon)
    def forward(self, features_list, e_feat, e_count, left, right, mid):
        # Project each node type's features into the shared hidden space.
        h = []
        n_count = []
        for fc, feature in zip(self.fc_list, features_list):
            h.append(fc(feature))
            n_count.append(len(feature))
        h = torch.cat(h, 0)
        # Collect the L2-normalized embeddings produced by every layer; they are concatenated below.
        emb = [self.l2_norm(h)]
        res_attn = None
        for l in range(self.num_layers):
            h, res_attn = self.gat_layers[l](self.g, h, e_feat, e_count, n_count, res_attn=res_attn)
            emb.append(self.l2_norm(h.mean(1)))
            h = h.flatten(1)
        # output projection
        logits, _ = self.gat_layers[-1](self.g, h, e_feat, e_count, n_count, res_attn=res_attn)
        logits = logits.mean(1)
        logits = self.l2_norm(logits)
        emb.append(logits)
        logits = torch.cat(emb, 1)
        # Score the candidate edges (left, right) with relation ids mid.
        left_emb = logits[left]
        right_emb = logits[right]
        return self.decoder(left_emb, right_emb, mid)
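

# Illustrative smoke test for the decoder modules above. This is a sketch, not part of the
# original training pipeline: the toy sizes (3 relations, 8-dim embeddings, a batch of 4
# candidate edges) are assumptions chosen only to show the expected input/output shapes.
# sepGAT itself is not exercised here because it depends on myGATConv from myconv and a DGL graph.
if __name__ == "__main__":
    num_rel, dim, batch = 3, 8, 4
    left = torch.randn(batch, dim)    # embeddings of source nodes
    right = torch.randn(batch, dim)   # embeddings of destination nodes
    r_id = np.array([0, 2, 1, 0])     # relation id of each candidate edge

    # DistMult indexes a (num_rel, dim, dim) tensor with r_id; a numpy integer array works here.
    print(DistMult(num_rel, dim)(left, right, r_id).shape)    # torch.Size([4])
    # myDistMult stores one (dim, dim) parameter per relation and stacks the ones it needs.
    print(myDistMult(num_rel, dim)(left, right, r_id).shape)  # torch.Size([4])
    # Dot ignores the relation id and scores by plain inner product.
    print(Dot()(left, right, r_id).shape)                     # torch.Size([4])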