model.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch_geometric.nn import GCNConv, SAGEConv, SGConv


class MLPLayer(nn.Module):
    """A single linear layer with the classic GCN-style uniform initialisation."""

    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__()
        self.weight = Parameter(torch.FloatTensor(in_channels, out_channels))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init in [-1/sqrt(out_channels), 1/sqrt(out_channels)].
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, x):
        output = torch.mm(x, self.weight)
        if self.bias is not None:
            return output + self.bias
        return output


class MLP(nn.Module):
    """Two-layer perceptron; an optional proxy embedding is added to the hidden features."""

    def __init__(self, in_channel, out_channel, hidden):
        super().__init__()
        self.l1 = nn.Linear(in_channel, hidden)
        self.l2 = nn.Linear(hidden, out_channel)

    def forward(self, x, proxy=None):
        x = self.l1(x)
        x = F.relu(x)
        if proxy is not None:
            x = x + proxy
        x = self.l2(x)
        # Returned twice to match the (output, embedding) tuple of the GNN models below.
        return x, x


class GCN(nn.Module):
    """Two-layer graph convolutional network."""

    def __init__(self, in_channel, out_channel, hidden):
        super().__init__()
        self.l1 = GCNConv(in_channel, hidden)
        self.l2 = GCNConv(hidden, out_channel)
        self.reset_parameters()

    def reset_parameters(self):
        # Re-initialise every parameter uniformly in [0, 1), overriding
        # the default initialisation of GCNConv.
        for _, para in self.named_parameters():
            para.data.uniform_()

    def forward(self, x, edge_index):
        x = self.l1(x, edge_index)
        x = F.relu(x)
        x = self.l2(x, edge_index)
        return x, x


class SAGE(nn.Module):
    """Two-layer GraphSAGE network."""

    def __init__(self, in_channel, out_channel, hidden):
        super().__init__()
        self.l1 = SAGEConv(in_channel, hidden)
        self.l2 = SAGEConv(hidden, out_channel)
        self.reset_parameters()

    def reset_parameters(self):
        # Same uniform re-initialisation as GCN.
        for _, para in self.named_parameters():
            para.data.uniform_()

    def forward(self, x, edge_index):
        x = self.l1(x, edge_index)
        x = F.relu(x)
        x = self.l2(x, edge_index)
        return x, x


class SGC(nn.Module):
    """Simplified graph convolution with two propagation steps (K=2)."""

    def __init__(self, in_channel, out_channel, hidden):
        # `hidden` is unused; it is kept so SGC shares the constructor
        # signature of the other models.
        super().__init__()
        self.l1 = SGConv(in_channel, out_channel, K=2)
        self.reset_parameters()

    def reset_parameters(self):
        for _, para in self.named_parameters():
            para.data.uniform_()

    def forward(self, x, edge_index):
        x = self.l1(x, edge_index)
        return x, x


class Encoder(nn.Module):
    """Single-layer encoder producing ReLU-activated embeddings."""

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.l1 = MLPLayer(in_channel, out_channel)

    def forward(self, x, proxy=None):
        # `proxy` is accepted for interface compatibility but not used here.
        x = self.l1(x)
        x = F.relu(x)
        return x


class Classifier(nn.Module):
    """Single-layer classifier; an optional proxy embedding is added before the linear layer."""

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.l2 = MLPLayer(in_channel, out_channel)

    def forward(self, x, proxy=None):
        if proxy is not None:
            x = x + proxy
            # Alternative fusion by concatenation, left disabled:
            # x = torch.concat([x, proxy], dim=1)
        x = self.l2(x)
        return x
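

if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the original training
    # pipeline): run each model on a tiny random graph. The sizes and the toy
    # edge_index below are assumptions chosen just for this sketch.
    num_nodes, in_channel, hidden, out_channel = 4, 8, 16, 3
    x = torch.randn(num_nodes, in_channel)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]])  # a directed 4-node cycle

    # Each GNN returns an (output, embedding) tuple of identical tensors.
    for Model in (GCN, SAGE, SGC):
        out, emb = Model(in_channel, out_channel, hidden)(x, edge_index)
        print(Model.__name__, out.shape, emb.shape)

    # Encoder/Classifier pipeline with a zero proxy of matching shape.
    enc = Encoder(in_channel, hidden)
    clf = Classifier(hidden, out_channel)
    z = enc(x)
    print("Encoder/Classifier", clf(z, proxy=torch.zeros_like(z)).shape)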