# layers.py
import torch
import torch.nn as nn
import dgl.nn.pytorch as dglnn


class NodeNorm(nn.Module):
    """Node-wise normalization over each node's feature vector (dim=1).

    nn_type selects the variant: "n" (subtract mean, divide by std),
    "v" (divide by std), "m" (subtract mean), "srv" (divide by the square
    root of the std), "pr" (divide by std ** (1 / power_root)).
    """

    def __init__(self, nn_type="n", unbiased=False, eps=1e-5, power_root=2):
        super(NodeNorm, self).__init__()
        self.unbiased = unbiased
        self.eps = eps
        self.nn_type = nn_type
        self.power = 1 / power_root

    def forward(self, x):
        if self.nn_type == "n":
            mean = torch.mean(x, dim=1, keepdim=True)
            std = (
                torch.var(x, unbiased=self.unbiased, dim=1, keepdim=True) + self.eps
            ).sqrt()
            x = (x - mean) / std
        elif self.nn_type == "v":
            std = (
                torch.var(x, unbiased=self.unbiased, dim=1, keepdim=True) + self.eps
            ).sqrt()
            x = x / std
        elif self.nn_type == "m":
            mean = torch.mean(x, dim=1, keepdim=True)
            x = x - mean
        elif self.nn_type == "srv":  # divide by the square root of the std
            std = (
                torch.var(x, unbiased=self.unbiased, dim=1, keepdim=True) + self.eps
            ).sqrt()
            x = x / torch.sqrt(std)
        elif self.nn_type == "pr":
            std = (
                torch.var(x, unbiased=self.unbiased, dim=1, keepdim=True) + self.eps
            ).sqrt()
            x = x / torch.pow(std, self.power)
        return x

    def __repr__(self):
        # Splice "nn_type=..." just before the closing parenthesis of the
        # default module repr, e.g. "NodeNorm()" -> "NodeNorm(nn_type=n)".
        original_str = super().__repr__()
        components = list(original_str)
        nn_type_str = f"nn_type={self.nn_type}"
        components.insert(-1, nn_type_str)
        new_str = "".join(components)
        return new_str
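
# Usage sketch (illustrative, not part of the original file): NodeNorm
# normalizes each node's feature vector independently along dim=1, so it can
# be applied to any (num_nodes, num_features) tensor without graph structure:
#
#     >>> norm = NodeNorm(nn_type="n")
#     >>> h = torch.randn(4, 16)   # 4 nodes, 16 features each
#     >>> norm(h).shape            # per-node zero mean / unit std, shape kept
#     torch.Size([4, 16])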


def get_normalization(norm_type, num_channels=None):
    if norm_type is None:
        norm = None
    elif norm_type == "batch":
        norm = nn.BatchNorm1d(num_features=num_channels)
    elif norm_type == "node_n":
        norm = NodeNorm(nn_type="n")
    elif norm_type == "node_v":
        norm = NodeNorm(nn_type="v")
    elif norm_type == "node_m":
        norm = NodeNorm(nn_type="m")
    elif norm_type == "node_srv":
        norm = NodeNorm(nn_type="srv")
    elif norm_type.find("node_pr") != -1:
        power_root = norm_type.split("_")[-1]
        power_root = int(power_root)
        norm = NodeNorm(nn_type="pr", power_root=power_root)
    elif norm_type == "layer":
        norm = nn.LayerNorm(normalized_shape=num_channels)
    else:
        raise NotImplementedError
    return norm
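
# Usage sketch (illustrative, not part of the original file): the "node_pr"
# branch reads the power root from the suffix of the norm_type string:
#
#     >>> get_normalization("node_pr_3")               # NodeNorm with power_root=3
#     >>> get_normalization("batch", num_channels=64)  # nn.BatchNorm1d(64)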


class GNNBasicBlock(nn.Module):
    def __init__(self, layer_type, block_type, activation, normalization=None, **core_layer_hyperparms):
        super(GNNBasicBlock, self).__init__()
        self.layer_type = layer_type
        self.block_type = block_type
        if self.layer_type in ['gcn', 'gcn_res']:
            self.core_layer_type = 'gcn'
            self.core_layer = dglnn.GraphConv(in_feats=core_layer_hyperparms['in_channels'],
                                              out_feats=core_layer_hyperparms['out_channels'],
                                              bias=core_layer_hyperparms['bias']
                                              )
        elif self.layer_type in ['gat', 'gat_res']:
            self.core_layer_type = 'gat'
            # Each head outputs out_channels / num_heads features, so the
            # concatenated multi-head output has out_channels features.
            self.core_layer = dglnn.GATConv(in_feats=core_layer_hyperparms['in_channels'],
                                            out_feats=int(core_layer_hyperparms['out_channels'] / core_layer_hyperparms['num_heads']),
                                            num_heads=core_layer_hyperparms['num_heads'],
                                            feat_drop=core_layer_hyperparms['feat_drop'],
                                            attn_drop=core_layer_hyperparms['attn_drop']
                                            )
        elif self.layer_type in ['sage', 'sage_res']:
            self.core_layer_type = 'sage'
            self.core_layer = dglnn.SAGEConv(in_feats=core_layer_hyperparms['in_channels'],
                                             out_feats=core_layer_hyperparms['out_channels'],
                                             aggregator_type='mean',
                                             bias=core_layer_hyperparms['bias'])
        else:
            raise NotImplementedError

        acti_type, acti_hyperparam = activation
        if acti_type == 'relu':
            self.activation = nn.ReLU(inplace=acti_hyperparam)
        elif acti_type == 'lkrelu':
            self.activation = nn.LeakyReLU(negative_slope=acti_hyperparam)
        elif acti_type == 'elu':
            self.activation = nn.ELU(inplace=acti_hyperparam)
        elif acti_type == 'no':
            self.activation = None
        else:
            raise NotImplementedError

        # Block variants containing an 'n' step need a normalization module.
        if 'n' in block_type.split('_'):
            self.node_norm = get_normalization(
                norm_type=normalization, num_channels=core_layer_hyperparms['out_channels']
            )
        self.block_type_str = self.get_block_type_str()

    def forward(self, graph, x):
        if self.core_layer_type in ['gcn', 'sage']:
            x1 = self.core_layer(graph, x)
        elif self.core_layer_type in ['gat', ]:
            # Flatten the (num_nodes, num_heads, head_dim) GAT output back to
            # (num_nodes, out_channels).
            x1 = self.core_layer(graph, x).flatten(1)
        else:
            x1 = self.core_layer(x)
        if self.block_type == 'v':  # vanilla layer
            if self.activation is not None:
                x1 = self.activation(x1)
            x = x1
        elif self.block_type == 'a_r':  # activation, then residual connection
            x1 = self.activation(x1)
            x = x1 + x
        elif self.block_type == 'n_a':  # node norm, then activation
            x = self.node_norm(x1)
            x = self.activation(x)
        elif self.block_type == 'n_a_r':  # node norm, activation, then residual connection
            x1 = self.node_norm(x1)
            x1 = self.activation(x1)
            x = x1 + x
        return x

    def get_block_type_str(self):
        if self.block_type == 'v':
            block_type_str = 'vanilla'
        elif self.block_type == 'a_r':
            block_type_str = 'activation_residual'
        elif self.block_type == 'n_a_r':
            block_type_str = 'normalization_activation_residual'
        elif self.block_type == 'n_a':
            block_type_str = 'normalization_activation'
        else:
            raise NotImplementedError
        return block_type_str

    def __repr__(self):
        original_str = super().__repr__()
        components = original_str.split('\n')
        block_type_str = f' (block_type): {self.block_type_str}'
        components.insert(-1, block_type_str)
        new_str = '\n'.join(components)
        return new_str
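

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # module): build a tiny DGL graph and run one residual GCN block. The
    # graph, feature sizes, and hyperparameter values are assumptions chosen
    # only to exercise the code above.
    import dgl

    # 4-node cycle; self-loops guarantee every node has an in-edge, which
    # dglnn.GraphConv requires by default.
    src = torch.tensor([0, 1, 2, 3])
    dst = torch.tensor([1, 2, 3, 0])
    g = dgl.add_self_loop(dgl.graph((src, dst), num_nodes=4))

    block = GNNBasicBlock(
        layer_type='gcn_res',
        block_type='n_a_r',         # node norm -> activation -> residual
        activation=('relu', True),  # (activation type, inplace flag)
        normalization='node_n',
        in_channels=16,
        out_channels=16,            # must equal in_channels for the residual add
        bias=True,
    )
    h = torch.randn(4, 16)
    out = block(g, h)
    print(block)       # repr includes the injected block_type line
    print(out.shape)   # torch.Size([4, 16])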