models.py
"""
The following code is borrowed from BYOL, SelfGNN
and slightly modified for BGRL
"""
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


class EMA:
    """Exponential moving average of model parameters with a cosine-annealed decay rate."""

    def __init__(self, beta, epochs):
        super().__init__()
        self.beta = beta            # base decay rate
        self.step = 0               # advances once per update_average() call
        self.total_steps = epochs

    def update_average(self, old, new):
        if old is None:
            return new
        # Cosine schedule: the effective decay starts at self.beta and is annealed
        # toward 1.0 as self.step approaches self.total_steps.
        beta = 1 - (1 - self.beta) * (np.cos(np.pi * self.step / self.total_steps) + 1) / 2.0
        self.step += 1
        return old * beta + (1 - beta) * new


def loss_fn(x, y):
    # BYOL-style loss: with x and y L2-normalized, this equals
    # 2 - 2 * cosine_similarity(x, y), so minimizing it maximizes cosine similarity.
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)


def update_moving_average(ema_updater, ma_model, current_model):
    # Update the teacher (ma_model) parameters as an EMA of the student (current_model).
    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)


def set_requires_grad(model, val):
    for p in model.parameters():
        p.requires_grad = val


class Encoder(nn.Module):
    """Two-layer GCN encoder: GCNConv -> BatchNorm -> PReLU, applied twice."""

    def __init__(self, layer_config, dropout=None, project=False, **kwargs):
        super().__init__()
        # layer_config = [input_dim, hidden_dim, output_dim]; dropout and project
        # are accepted for interface compatibility but are not used here.
        self.conv1 = GCNConv(layer_config[0], layer_config[1])
        self.bn1 = nn.BatchNorm1d(layer_config[1], momentum=0.01)
        self.prelu1 = nn.PReLU()
        self.conv2 = GCNConv(layer_config[1], layer_config[2])
        self.bn2 = nn.BatchNorm1d(layer_config[2], momentum=0.01)
        self.prelu2 = nn.PReLU()

    def forward(self, x, edge_index, edge_weight=None):
        x = self.conv1(x, edge_index, edge_weight=edge_weight)
        x = self.prelu1(self.bn1(x))
        x = self.conv2(x, edge_index, edge_weight=edge_weight)
        x = self.prelu2(self.bn2(x))
        return x

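
# Example (illustrative; the dimensions below are hypothetical): an encoder for
# 1433-dimensional node features with a 256-dimensional hidden layer and
# 128-dimensional output embeddings would be constructed as
#   encoder = Encoder(layer_config=[1433, 256, 128])
# and called as encoder(x, edge_index) to obtain a [num_nodes, 128] tensor.
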

def init_weights(m):
    # Xavier-initialize Linear layers; biases are filled with a small constant.
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


class BGRL(nn.Module):
    """Bootstrapped Graph Latents: student/teacher encoders with a student-side predictor."""

    def __init__(self, layer_config, pred_hid, dropout=0.0, moving_average_decay=0.99, epochs=1000, **kwargs):
        super().__init__()
        self.student_encoder = Encoder(layer_config=layer_config, dropout=dropout, **kwargs)
        # The teacher starts as a copy of the student and is updated only via EMA,
        # never by gradient descent.
        self.teacher_encoder = copy.deepcopy(self.student_encoder)
        set_requires_grad(self.teacher_encoder, False)
        self.teacher_ema_updater = EMA(moving_average_decay, epochs)
        rep_dim = layer_config[-1]
        self.student_predictor = nn.Sequential(nn.Linear(rep_dim, pred_hid), nn.PReLU(), nn.Linear(pred_hid, rep_dim))
        self.student_predictor.apply(init_weights)

    def reset_moving_average(self):
        del self.teacher_encoder
        self.teacher_encoder = None

    def update_moving_average(self):
        assert self.teacher_encoder is not None, 'teacher encoder has not been created yet'
        update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)

    def forward(self, x1, x2, edge_index_v1, edge_index_v2, edge_weight_v1=None, edge_weight_v2=None):
        # Encode both augmented views with the student encoder.
        v1_student = self.student_encoder(x=x1, edge_index=edge_index_v1, edge_weight=edge_weight_v1)
        v2_student = self.student_encoder(x=x2, edge_index=edge_index_v2, edge_weight=edge_weight_v2)

        v1_pred = self.student_predictor(v1_student)
        v2_pred = self.student_predictor(v2_student)

        # Teacher targets are computed without gradients.
        with torch.no_grad():
            v1_teacher = self.teacher_encoder(x=x1, edge_index=edge_index_v1, edge_weight=edge_weight_v1)
            v2_teacher = self.teacher_encoder(x=x2, edge_index=edge_index_v2, edge_weight=edge_weight_v2)

        # Symmetrized loss: each view's prediction targets the other view's teacher embedding.
        loss1 = loss_fn(v1_pred, v2_teacher.detach())
        loss2 = loss_fn(v2_pred, v1_teacher.detach())

        loss = loss1 + loss2
        return v1_student, v2_student, loss.mean()


class LogisticRegression(nn.Module):
    """Linear probe used to evaluate frozen embeddings."""

    def __init__(self, num_dim, num_class):
        super().__init__()
        self.linear = nn.Linear(num_dim, num_class)
        torch.nn.init.xavier_uniform_(self.linear.weight.data)
        self.linear.bias.data.fill_(0.0)
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, x, y):
        logits = self.linear(x)
        loss = self.cross_entropy(logits, y)
        return logits, loss
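

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original models.py). The graph,
# feature dimensions, and hyperparameters below are hypothetical placeholders;
# a real training script would feed two stochastically augmented views
# (e.g. feature masking / edge dropping) instead of the identical views used
# here to keep the sketch self-contained.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    num_nodes, in_dim, num_classes = 100, 32, 4
    x = torch.randn(num_nodes, in_dim)
    edge_index = torch.randint(0, num_nodes, (2, 500))
    y = torch.randint(0, num_classes, (num_nodes,))

    model = BGRL(layer_config=[in_dim, 64, 32], pred_hid=64, epochs=10)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(10):
        model.train()
        optimizer.zero_grad()
        # Both "views" are the same graph here purely for illustration.
        _, _, loss = model(x, x, edge_index, edge_index)
        loss.backward()
        optimizer.step()
        model.update_moving_average()  # EMA update of the teacher encoder

    # Linear evaluation on the frozen student embeddings.
    model.eval()
    with torch.no_grad():
        emb = model.student_encoder(x=x, edge_index=edge_index)
    clf = LogisticRegression(emb.size(1), num_classes)
    logits, ce_loss = clf(emb, y)
    print(loss.item(), ce_loss.item())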