-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
306 lines (226 loc) · 11.8 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import os
import random
import time
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from config import Config
from create_batch import Corpus
from models import SpKBGATModified, SpKBGATConvOnly
from preprocess import init_embeddings, build_data
from utils import save_model, save_object
# Load the run configuration once at import time; every function below reads `args`.
args = Config()
args.load_config()

# Device tag ("cuda" or "cpu") picked from the config's CUDA flag.
device = "cpu"
if args.cuda:
    device = "cuda"
def load_data(args):
    """Build the training corpus and the initial entity/relation embeddings.

    Splits and id maps are read from ``args.data_folder`` via ``build_data``.
    When ``args.pretrained_emb`` is truthy the embeddings come from
    ``entity2vec.txt`` / ``relation2vec.txt`` in that folder (via
    ``init_embeddings``); otherwise they are sampled from a standard normal
    distribution with ``args.embedding_size`` columns and one row per id.

    Returns:
        ``(corpus, entity_embeddings, relation_embeddings)`` where the two
        embedding matrices are ``torch.FloatTensor`` instances.
    """
    # `is_unweigted` is the keyword as spelled in the call to build_data;
    # do not "fix" it without checking build_data's signature.
    (train_data, validation_data, test_data, entity2id, relation2id,
     headTailSelector, unique_entities_train) = build_data(
        args.data_folder, is_unweigted=False, directed=True)

    if not args.pretrained_emb:
        # Entity matrix is sampled before the relation matrix (RNG order matters for reproducibility).
        entity_embeddings = np.random.randn(len(entity2id), args.embedding_size)
        relation_embeddings = np.random.randn(len(relation2id), args.embedding_size)
        print("Initialised relations and entities randomly")
    else:
        entity_vec_path = os.path.join(args.data_folder, 'entity2vec.txt')
        relation_vec_path = os.path.join(args.data_folder, 'relation2vec.txt')
        entity_embeddings, relation_embeddings = init_embeddings(entity_vec_path, relation_vec_path)
        print("Initialised relations and entities from TransE")

    corpus = Corpus(args, train_data, validation_data, test_data,
                    entity2id, relation2id, headTailSelector,
                    args.batch_size_gat, args.valid_invalid_ratio_gat,
                    unique_entities_train, args.get_2hop)

    entity_tensor = torch.FloatTensor(entity_embeddings)
    relation_tensor = torch.FloatTensor(relation_embeddings)
    return corpus, entity_tensor, relation_tensor
# Build the corpus and initial embeddings once at import time; the training
# and evaluation functions below read these module-level names.
Corpus_, entity_embeddings, relation_embeddings = load_data(args)
node_neighbors_2hop = Corpus_.node_neighbors_2hop

print(f"Initial entity dimensions {entity_embeddings.size()} , relation dimensions {relation_embeddings.size()}")
print("1. Created Corpus Successfully !")
def batch_gat_loss(gat_loss_func, train_indices, entity_embed, relation_embed,
                   valid_invalid_ratio=None):
    """Compute the ranking loss of one GAT training batch.

    The batch is sliced so that the first ``len(train_indices) // (ratio + 1)``
    rows are positive triples and the remaining rows are negative triples.
    Columns 0 and 2 of each row index ``entity_embed``; column 1 indexes
    ``relation_embed``. Each triple is scored by the L1 norm of
    ``head + relation - tail``.

    The positives are tiled ``ratio`` times so each negative row has a positive
    partner at the same position. NOTE(review): this assumes the negatives are
    laid out as ``ratio`` consecutive blocks aligned with the positives;
    confirm against ``Corpus.get_iteration_batch``.

    Args:
        gat_loss_func: Called as ``gat_loss_func(pos_norm, neg_norm, y)`` with
            ``y`` filled with -1. ``train_gat`` passes ``nn.MarginRankingLoss``,
            for which -1 means the positive distance should be the smaller one.
        train_indices: 2-D integer tensor of triples, positives first.
        entity_embed: Entity embedding matrix, one row per entity id.
        relation_embed: Relation embedding matrix, one row per relation id.
        valid_invalid_ratio: Number of negatives per positive. Defaults to
            ``args.valid_invalid_ratio_gat`` from the module-level config.

    Returns:
        Whatever ``gat_loss_func`` returns (a scalar tensor for
        ``nn.MarginRankingLoss`` with its default reduction).
    """
    if valid_invalid_ratio is None:
        valid_invalid_ratio = args.valid_invalid_ratio_gat
    ratio = int(valid_invalid_ratio)

    # Integer floor division: the number of positive rows at the head of the batch.
    len_pos_triples = train_indices.shape[0] // (ratio + 1)
    pos_triples = train_indices[:len_pos_triples].repeat(ratio, 1)
    neg_triples = train_indices[len_pos_triples:]

    def l1_distance(triples):
        # Per-row ||h + r - t||_1.
        heads = entity_embed[triples[:, 0]]
        rels = relation_embed[triples[:, 1]]
        tails = entity_embed[triples[:, 2]]
        return torch.norm(heads + rels - tails, p=1, dim=1)

    pos_norm = l1_distance(pos_triples)
    neg_norm = l1_distance(neg_triples)

    # Target of -1 for every pair, created on pos_norm's device/dtype instead of
    # a hard-coded .cuda(), so CPU runs (args.cuda False) work too.
    y = -torch.ones(ratio * len_pos_triples, dtype=pos_norm.dtype, device=pos_norm.device)
    return gat_loss_func(pos_norm, neg_norm, y)
def train_gat(args):
    """Train the GAT encoder (``SpKBGATModified``) and save its final checkpoint.

    Reads module-level state built at import time: ``Corpus_``,
    ``entity_embeddings``, ``relation_embeddings`` and ``node_neighbors_2hop``.
    Each epoch shuffles ``Corpus_.train_triples``, walks the batches returned by
    ``Corpus_.get_iteration_batch`` and minimises ``batch_gat_loss`` with Adam.
    ``StepLR`` is stepped once per epoch, so the learning rate halves every 500
    epochs.

    Side effects:
        Mutates ``Corpus_.train_triples`` and ``Corpus_.train_indices``. After
        the last epoch it calls ``save_model(..., name="gat")`` and saves the
        list of per-epoch mean losses with ``save_object(..., name="loss_gat")``.
    """
    # Creating the gat model here.
    print("Defining model")
    print(
        "\nModel type -> GAT layer with {} heads used , Initital Embeddings training".format(args.nheads_GAT[0]))
    model_gat = SpKBGATModified(entity_embeddings, relation_embeddings, args.entity_out_dim, args.entity_out_dim,
                                args.drop_GAT, args.alpha, args.nheads_GAT)
    if args.cuda:
        model_gat.cuda()

    optimizer = torch.optim.Adam(
        model_gat.parameters(), lr=args.lr, weight_decay=args.weight_decay_gat)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=500, gamma=0.5, last_epoch=-1)
    gat_loss_func = nn.MarginRankingLoss(margin=args.margin)

    current_batch_2hop_indices = torch.tensor([])
    if args.use_2hop:
        current_batch_2hop_indices = Corpus_.get_batch_nhop_neighbors_all(args,
                                                                          Corpus_.unique_entities_train,
                                                                          node_neighbors_2hop)
    # as_tensor(..., dtype=long) converts lists, NumPy arrays and tensors alike,
    # including the float `torch.tensor([])` default used when 2-hop is disabled.
    current_batch_2hop_indices = torch.as_tensor(current_batch_2hop_indices, dtype=torch.long)
    if args.cuda:
        current_batch_2hop_indices = current_batch_2hop_indices.cuda()

    epoch_losses = []  # mean loss of each epoch
    print("Number of epochs {}".format(args.epochs_gat))

    for epoch in range(args.epochs_gat):
        if args.print_console:
            print("\nepoch-> ", epoch)
        random.shuffle(Corpus_.train_triples)
        Corpus_.train_indices = np.array(
            list(Corpus_.train_triples)).astype(np.int32)

        model_gat.train()  # getting in training mode
        start_time = time.time()
        epoch_loss = []

        # Ceiling division: a trailing partial batch gets its own iteration.
        num_iters_per_epoch = (len(Corpus_.train_indices) + args.batch_size_gat - 1) // args.batch_size_gat

        for iters in range(num_iters_per_epoch):
            start_time_iter = time.time()
            # The GAT loss only needs the triples; the batch labels are unused here.
            train_indices, _ = Corpus_.get_iteration_batch(iters)
            train_indices = torch.LongTensor(train_indices)
            if args.cuda:
                train_indices = train_indices.cuda()

            # forward pass
            entity_embed, relation_embed = model_gat(
                Corpus_, Corpus_.train_adj_matrix, train_indices, current_batch_2hop_indices)

            optimizer.zero_grad()
            loss = batch_gat_loss(
                gat_loss_func, train_indices, entity_embed, relation_embed)
            loss.backward()
            optimizer.step()

            iter_loss = loss.item()
            epoch_loss.append(iter_loss)
            end_time_iter = time.time()
            if args.print_console:
                print("Iteration-> {0} , Iteration_time-> {1:.4f} , Iteration_loss {2:.4f}".format(
                    iters, end_time_iter - start_time_iter, iter_loss))

        scheduler.step()
        avg_loss = sum(epoch_loss) / len(epoch_loss)
        if args.print_console:
            print("Epoch {} , average loss {} , epoch_time {}".format(
                epoch, avg_loss, time.time() - start_time))
        epoch_losses.append(avg_loss)

        # Only the final epoch's model and the full loss history are persisted.
        if epoch >= args.epochs_gat - 1:
            save_model(model_gat, name="gat", epoch=epoch)
            save_object(epoch_losses, output=args.output_folder, name="loss_gat")
    print("2. Train Encoder Successfully !")
def train_conv(args):
    """Train the convolutional decoder (``SpKBGATConvOnly``) on the trained GAT output.

    Rebuilds a ``SpKBGATModified`` model, loads the final GAT checkpoint into
    it, and hands its ``final_entity_embeddings`` / ``final_relation_embeddings``
    to the conv model. The conv model is then trained with ``SoftMarginLoss``
    against the batch labels from ``Corpus_.get_iteration_batch``. ``StepLR`` is
    stepped once per epoch, so the learning rate halves every 25 epochs.

    Side effects:
        Sets ``Corpus_.batch_size`` / ``Corpus_.invalid_valid_ratio`` to the
        conv settings and reshuffles ``Corpus_.train_triples`` each epoch.
        After the last epoch it calls ``save_model(..., name="conv")`` and saves
        the per-epoch mean losses with ``save_object(..., name="loss_conv")``.
    """
    # Creating convolution model here.
    print("Defining model")
    model_gat = SpKBGATModified(entity_embeddings, relation_embeddings, args.entity_out_dim, args.entity_out_dim,
                                args.drop_GAT, args.alpha, args.nheads_GAT)
    print("Only Conv model trained")
    model_conv = SpKBGATConvOnly(entity_embeddings, relation_embeddings, args.entity_out_dim, args.entity_out_dim,
                                 args.drop_GAT, args.drop_conv, args.alpha, args.alpha_conv,
                                 args.nheads_GAT, args.out_channels)
    if args.cuda:
        model_conv.cuda()
        model_gat.cuda()

    print("Loading GAT encoder")
    folder = "{output}/{dataset}".format(output=args.output_folder, dataset=args.dataset)
    if args.save_gdrive:
        folder = args.drive_folder
    # NOTE(review): this name uses `args.device`, whereas evaluate_conv uses the
    # module-level `device`; confirm both match what save_model writes.
    model_name = "{folder}/{dataset}_{device}_{name}_{epoch}.pt".format(folder=folder, dataset=args.dataset,
                                                                       device=args.device, name="gat",
                                                                       epoch=args.epochs_gat - 1)
    # Deserialize on CPU so a checkpoint written on another device still loads;
    # load_state_dict copies the tensors into model_gat's own device.
    model_gat.load_state_dict(torch.load(model_name, map_location="cpu"), strict=False)
    model_conv.final_entity_embeddings = model_gat.final_entity_embeddings
    model_conv.final_relation_embeddings = model_gat.final_relation_embeddings

    Corpus_.batch_size = args.batch_size_conv
    Corpus_.invalid_valid_ratio = int(args.valid_invalid_ratio_conv)

    optimizer = torch.optim.Adam(
        model_conv.parameters(), lr=args.lr, weight_decay=args.weight_decay_conv)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=25, gamma=0.5, last_epoch=-1)
    margin_loss = torch.nn.SoftMarginLoss()

    epoch_losses = []  # mean loss of each epoch
    print("Number of epochs {}".format(args.epochs_conv))

    for epoch in range(args.epochs_conv):
        if args.print_console:
            print("\nepoch-> ", epoch)
        random.shuffle(Corpus_.train_triples)
        Corpus_.train_indices = np.array(
            list(Corpus_.train_triples)).astype(np.int32)

        model_conv.train()  # getting in training mode
        start_time = time.time()
        epoch_loss = []

        # Ceiling division: a trailing partial batch gets its own iteration.
        num_iters_per_epoch = (len(Corpus_.train_indices) + args.batch_size_conv - 1) // args.batch_size_conv

        for iters in range(num_iters_per_epoch):
            start_time_iter = time.time()
            train_indices, train_values = Corpus_.get_iteration_batch(iters)
            train_indices = torch.LongTensor(train_indices)
            train_values = torch.FloatTensor(train_values)
            if args.cuda:
                train_indices = train_indices.cuda()
                train_values = train_values.cuda()

            preds = model_conv(
                Corpus_, Corpus_.train_adj_matrix, train_indices)

            optimizer.zero_grad()
            loss = margin_loss(preds.view(-1), train_values.view(-1))
            loss.backward()
            optimizer.step()

            iter_loss = loss.item()
            epoch_loss.append(iter_loss)
            end_time_iter = time.time()
            if args.print_console:
                print("Iteration-> {0} , Iteration_time-> {1:.4f} , Iteration_loss {2:.4f}".format(
                    iters, end_time_iter - start_time_iter, iter_loss))

        scheduler.step()
        avg_loss = sum(epoch_loss) / len(epoch_loss)
        if args.print_console:
            print("Epoch {} , average loss {} , epoch_time {}".format(
                epoch, avg_loss, time.time() - start_time))
        epoch_losses.append(avg_loss)

        # Only the final epoch's model and the full loss history are persisted.
        if epoch >= args.epochs_conv - 1:
            save_model(model_conv, name="conv", epoch=epoch)
            save_object(epoch_losses, output=args.output_folder, name="loss_conv")
    print("3. Train Decoder Successfully !")
def evaluate_conv(args, unique_entities):
    """Load the last saved ``SpKBGATConvOnly`` checkpoint and run validation.

    The checkpoint path is ``{folder}/{dataset}_{device}_conv_{epochs_conv - 1}.pt``.
    ``folder`` is ``args.drive_folder`` when ``args.save_gdrive`` is set and
    ``{output_folder}/{dataset}`` otherwise. Evaluation itself is delegated to
    ``Corpus_.get_validation_pred`` under ``torch.no_grad()`` with the model in
    eval mode.

    Args:
        args: Run configuration.
        unique_entities: Passed through unchanged to ``Corpus_.get_validation_pred``.
    """
    model_conv = SpKBGATConvOnly(entity_embeddings, relation_embeddings, args.entity_out_dim, args.entity_out_dim,
                                 args.drop_GAT, args.drop_conv, args.alpha, args.alpha_conv,
                                 args.nheads_GAT, args.out_channels)
    folder = "{output}/{dataset}".format(output=args.output_folder, dataset=args.dataset)
    if args.save_gdrive:
        folder = args.drive_folder
    # NOTE(review): this name uses the module-level `device`, whereas train_conv
    # uses `args.device`; confirm both match what save_model writes.
    model_name = "{folder}/{dataset}_{device}_{name}_{epoch}.pt".format(folder=folder, dataset=args.dataset,
                                                                       device=device, name="conv",
                                                                       epoch=args.epochs_conv - 1)
    # The model is still on CPU here, so deserialize on CPU too; a checkpoint
    # written on a GPU then loads on CPU-only hosts as well.
    model_conv.load_state_dict(torch.load(model_name, map_location="cpu"), strict=False)
    if args.cuda:
        model_conv.cuda()
    model_conv.eval()
    with torch.no_grad():
        Corpus_.get_validation_pred(args, model_conv, unique_entities)
    print("4. Evaluation Successfully !")
# Pipeline: train the GAT encoder, then the conv decoder on its embeddings,
# then evaluate the decoder using the training-set entities.
train_gat(args)
train_conv(args)
evaluate_conv(args, unique_entities=Corpus_.unique_entities_train)