-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_lightning_unsupervised.py
207 lines (183 loc) · 8.21 KB
/
train_lightning_unsupervised.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import dgl.nn.pytorch as dglnn
import dgl.function as fn
import time
import argparse
import tqdm
import glob
import os
from negative_sampler import NegativeSampler
from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from model import SAGE, compute_acc_unsupervised as compute_acc
from load_graph import load_reddit, inductive_split, load_ogb
class CrossEntropyLoss(nn.Module):
    """Binary cross-entropy loss over positive and negative edge scores.

    Every edge is scored with the dot product of the embeddings of its two
    endpoints. Edges of the positive graph get target 1, edges of the
    negative graph get target 0, and the scores are treated as logits.
    """

    @staticmethod
    def _edge_scores(graph, node_emb):
        """Return the ``'score'`` edge data: dot product of ``node_emb`` per edge."""
        # The 'h' and 'score' fields are written inside a local scope, as in
        # the original code, so they are only used as scratch space here.
        with graph.local_scope():
            graph.ndata['h'] = node_emb
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            return graph.edata['score']

    def forward(self, block_outputs, pos_graph, neg_graph):
        """Compute the loss for one mini-batch.

        Args:
            block_outputs: Node embeddings assigned as node data ``'h'`` on
                both graphs.
            pos_graph: Graph whose edges are the positive examples.
            neg_graph: Graph whose edges are the negative examples.

        Returns:
            The value of ``F.binary_cross_entropy_with_logits`` over the
            concatenated positive and negative edge scores.
        """
        pos_score = self._edge_scores(pos_graph, block_outputs)
        neg_score = self._edge_scores(neg_graph, block_outputs)

        logits = th.cat([pos_score, neg_score])
        targets = th.cat([
            th.ones_like(pos_score),
            th.zeros_like(neg_score),
        ]).long()
        return F.binary_cross_entropy_with_logits(logits, targets.float())
class SAGELightning(LightningModule):
    """LightningModule that trains a ``SAGE`` network with an edge-score loss.

    Training batches are ``(input_nodes, pos_graph, neg_graph, mfgs)`` and
    validation batches are ``(input_nodes, output_nodes, mfgs)``. In both
    cases the input features are read from ``mfgs[0].srcdata['features']``.
    """

    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 lr):
        """Build the wrapped ``SAGE`` model and the loss.

        Args:
            in_feats, n_hidden, n_classes, n_layers, activation, dropout:
                Forwarded unchanged to ``SAGE``.
            lr: Learning rate for the Adam optimizer.
        """
        super().__init__()
        self.save_hyperparameters()
        self.module = SAGE(in_feats, n_hidden, n_classes, n_layers, activation, dropout)
        self.lr = lr
        self.loss_fcn = CrossEntropyLoss()

    def _forward_mfgs(self, mfgs):
        """Move the message-flow graphs to this module's device and run the model."""
        # Use the device Lightning placed this module on instead of a
        # script-level global, so the class also works when imported.
        mfgs = [mfg.int().to(self.device) for mfg in mfgs]
        batch_inputs = mfgs[0].srcdata['features']
        return self.module(mfgs, batch_inputs)

    def training_step(self, batch, batch_idx):
        """Return the edge-score loss for one batch and log it as ``train_loss``."""
        _, pos_graph, neg_graph, mfgs = batch
        batch_pred = self._forward_mfgs(mfgs)
        pos_graph = pos_graph.to(self.device)
        neg_graph = neg_graph.to(self.device)
        loss = self.loss_fcn(batch_pred, pos_graph, neg_graph)
        self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the model output for the batch; no loss is computed here."""
        _, _, mfgs = batch
        return self._forward_mfgs(mfgs)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        return th.optim.Adam(self.parameters(), lr=self.lr)
class DataModule(LightningDataModule):
    """Loads a graph dataset and builds the DGL data loaders for training.

    Training iterates over all edges with negative sampling; validation
    iterates over all nodes so that an embedding is produced for each node.
    """

    def __init__(self, dataset_name, data_cpu=False, fan_out=(10, 25),
                 device=th.device('cpu'), batch_size=1000, num_workers=4,
                 num_negs=None, neg_share=None):
        """Load the dataset and prepare node splits and the neighbor sampler.

        Args:
            dataset_name: ``'reddit'`` or ``'ogbn-products'``.
            data_cpu: If false, the node-id tensors and the graph are moved
                to ``device`` and the data loaders are created on it;
                otherwise everything stays where it was loaded and the
                loaders use the CPU.
            fan_out: Per-layer fan-outs for ``MultiLayerNeighborSampler``.
            device: Target device used when ``data_cpu`` is false.
            batch_size: Batch size for both data loaders.
            num_workers: Number of workers for both data loaders.
            num_negs: First ``NegativeSampler`` argument after the graph.
                When ``None``, the module-level ``args.num_negs`` is used.
            neg_share: Second ``NegativeSampler`` argument after the graph.
                When ``None``, the module-level ``args.neg_share`` is used.

        Raises:
            ValueError: If ``dataset_name`` is not a supported dataset.
        """
        super().__init__()
        if dataset_name == 'reddit':
            g, n_classes = load_reddit()
            n_edges = g.num_edges()
            # Edge i is paired with edge i + n_edges // 2 and vice versa.
            reverse_eids = th.cat([
                th.arange(n_edges // 2, n_edges),
                th.arange(0, n_edges // 2)])
        elif dataset_name == 'ogbn-products':
            g, n_classes = load_ogb('ogbn-products')
            n_edges = g.num_edges()
            # The reverse edge of edge 0 in the OGB products dataset is 1,
            # the reverse edge of edge 2 is 3, and so on.
            reverse_eids = th.arange(n_edges) ^ 1
        else:
            raise ValueError('unknown dataset')

        train_mask = g.ndata['train_mask']
        val_mask = g.ndata['val_mask']
        train_nid = th.nonzero(train_mask, as_tuple=True)[0]
        val_nid = th.nonzero(val_mask, as_tuple=True)[0]
        # Test nodes are all nodes that are neither training nor validation.
        test_nid = th.nonzero(~(train_mask | val_mask), as_tuple=True)[0]

        sampler = dgl.dataloading.MultiLayerNeighborSampler([int(_) for _ in fan_out])

        dataloader_device = th.device('cpu')
        if not data_cpu:
            train_nid = train_nid.to(device)
            val_nid = val_nid.to(device)
            test_nid = test_nid.to(device)
            g = g.formats(['csc'])
            g = g.to(device)
            dataloader_device = device

        self.g = g
        self.train_nid, self.val_nid, self.test_nid = train_nid, val_nid, test_nid
        self.sampler = sampler
        self.device = dataloader_device
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.in_feats = g.ndata['features'].shape[1]
        self.n_classes = n_classes
        self.reverse_eids = reverse_eids
        self.num_negs = num_negs
        self.neg_share = neg_share

    def _negative_sampler(self):
        """Build the ``NegativeSampler`` used by the training data loader."""
        # Backward compatibility: callers that do not pass num_negs or
        # neg_share keep getting the values from the script-level ``args``.
        num_negs = args.num_negs if self.num_negs is None else self.num_negs
        neg_share = args.neg_share if self.neg_share is None else self.neg_share
        return NegativeSampler(self.g, num_negs, neg_share)

    def train_dataloader(self):
        """Return a shuffled ``EdgeDataLoader`` over every edge of the graph."""
        return dgl.dataloading.EdgeDataLoader(
            self.g,
            np.arange(self.g.num_edges()),
            self.sampler,
            exclude='reverse_id',
            reverse_eids=self.reverse_eids,
            negative_sampler=self._negative_sampler(),
            device=self.device,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=self.num_workers)

    def val_dataloader(self):
        """Return an unshuffled ``NodeDataLoader`` over every node of the graph."""
        # Note that the validation data loader is a NodeDataLoader
        # as we want to evaluate all the node embeddings.
        return dgl.dataloading.NodeDataLoader(
            self.g,
            np.arange(self.g.num_nodes()),
            self.sampler,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self.num_workers)
class UnsupervisedClassification(Callback):
    """Callback that scores the validation outputs with ``compute_acc``.

    The outputs of all validation batches are collected, concatenated along
    dimension 0, and passed to ``compute_acc`` together with the graph labels
    and the train/val/test node ids of the trainer's datamodule. The first
    returned value is logged as ``val_f1_micro``.
    """

    def on_validation_epoch_start(self, trainer, pl_module):
        """Start a validation epoch with an empty output buffer."""
        self.val_outputs = []

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Store the output of the validation batch that just finished."""
        self.val_outputs.append(outputs)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Evaluate the collected outputs and log ``val_f1_micro``."""
        embeddings = th.cat(self.val_outputs, 0)
        datamodule = trainer.datamodule
        node_labels = datamodule.g.ndata['labels']
        # compute_acc returns (f1_micro, f1_macro); only the first is logged.
        micro, _ = compute_acc(
            embeddings,
            node_labels,
            datamodule.train_nid,
            datamodule.val_nid,
            datamodule.test_nid,
        )
        pl_module.log('val_f1_micro', micro)
if __name__ == '__main__':
    # Script entry point: parse the command line, build the data module and
    # the model, then train with PyTorch Lightning.
    #
    # NOTE: ``args`` and ``device`` are intentionally left as module-level
    # names because other definitions in this file may read them.
    argparser = argparse.ArgumentParser("multi-gpu training")
    argparser.add_argument("--gpu", type=int, default=0)
    argparser.add_argument('--dataset', type=str, default='reddit')
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--num-negs', type=int, default=1)
    argparser.add_argument('--neg-share', default=False, action='store_true',
                           help="sharing neg nodes for positive nodes")
    argparser.add_argument('--fan-out', type=str, default='10,25')
    argparser.add_argument('--batch-size', type=int, default=10000)
    # --log-every is accepted but not used anywhere in this script.
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=1000)
    argparser.add_argument('--lr', type=float, default=0.003)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num-workers', type=int, default=0,
                           help="Number of sampling processes. Use 0 for no extra process.")
    args = argparser.parse_args()

    # One test decides both the tensor device and the Trainer's GPU list, so
    # any negative --gpu value consistently means "run on CPU".
    use_gpu = args.gpu >= 0
    if use_gpu:
        device = th.device('cuda:%d' % args.gpu)
    else:
        device = th.device('cpu')

    # data_cpu=True: the graph and node ids are not moved to ``device`` by
    # the data module.
    datamodule = DataModule(
        args.dataset, True, [int(_) for _ in args.fan_out.split(',')],
        device, args.batch_size, args.num_workers)
    model = SAGELightning(
        datamodule.in_feats, args.num_hidden, datamodule.n_classes, args.num_layers,
        F.relu, args.dropout, args.lr)

    # Train
    unsupervised_callback = UnsupervisedClassification()
    checkpoint_callback = ModelCheckpoint(monitor='val_f1_micro', save_top_k=1)
    trainer = Trainer(gpus=[args.gpu] if use_gpu else None,
                      max_epochs=args.num_epochs,
                      # Honor --eval-every (default 1000) instead of a
                      # hard-coded validation interval.
                      val_check_interval=args.eval_every,
                      callbacks=[checkpoint_callback, unsupervised_callback],
                      num_sanity_val_steps=0)
    trainer.fit(model, datamodule=datamodule)