net.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from layer import *


class CoGNN(nn.Module):
    def __init__(self, gcn_true, buildA_true, gcn_depth, num_nodes, l_matrix, pre_adj_list, device,
                 predefined_A=None, static_feat=None, dropout=0.3, subgraph_size=20, node_dim=40,
                 dilation_exponential=1, conv_channels=32, residual_channels=32, skip_channels=64,
                 end_channels=128, seq_length=12, in_dim=2, out_dim=12, layers=3, propalpha=0.05,
                 tanhalpha=3, layer_norm_affline=True):
        super(CoGNN, self).__init__()
        self.gcn_true = gcn_true
        self.buildA_true = buildA_true
        self.num_nodes = num_nodes
        self.dropout = dropout
        self.predefined_A = predefined_A

        # Per-layer temporal, residual, skip and graph-convolution modules.
        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.gc = nn.ModuleList()
        self.gconv1 = nn.ModuleList()
        self.gconv2 = nn.ModuleList()
        self.norm = nn.ModuleList()

        # Lift the raw input features to the residual channel width.
        self.start_conv = nn.Conv2d(in_channels=in_dim,
                                    out_channels=residual_channels,
                                    kernel_size=(1, 1))
        # graph_directed_sep_init(l_matrix, subgraph_size, node_dim, device, predefined_A, alpha=tanhalpha, static_feat=static_feat)

        self.seq_length = seq_length
        kernel_size = 7
        # Receptive field of the stack of dilated inception layers.
        if dilation_exponential > 1:
            self.receptive_field = int(1 + (kernel_size - 1) * (dilation_exponential ** layers - 1) / (dilation_exponential - 1))
        else:
            self.receptive_field = layers * (kernel_size - 1) + 1
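        # Worked example with the defaults above: kernel_size = 7, layers = 3 and
        # dilation_exponential = 1 give receptive_field = 3 * (7 - 1) + 1 = 19,
        # which exceeds the default seq_length of 12, so forward() left-pads the
        # input along the time axis up to the receptive field before convolving.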
        for i in range(1):
            if dilation_exponential > 1:
                rf_size_i = int(1 + i * (kernel_size - 1) * (dilation_exponential ** layers - 1) / (dilation_exponential - 1))
            else:
                rf_size_i = i * layers * (kernel_size - 1) + 1
            new_dilation = 1
            for j in range(1, layers + 1):
                if dilation_exponential > 1:
                    rf_size_j = int(rf_size_i + (kernel_size - 1) * (dilation_exponential ** j - 1) / (dilation_exponential - 1))
                else:
                    rf_size_j = rf_size_i + j * (kernel_size - 1)

                # Gated dilated inception convolutions for the temporal module.
                self.filter_convs.append(dilated_inception(residual_channels, conv_channels, dilation_factor=new_dilation))
                self.gate_convs.append(dilated_inception(residual_channels, conv_channels, dilation_factor=new_dilation))
                self.residual_convs.append(nn.Conv2d(in_channels=conv_channels,
                                                     out_channels=residual_channels,
                                                     kernel_size=(1, 1)))
                # Skip convolutions collapse the remaining temporal dimension to length 1.
                if self.seq_length > self.receptive_field:
                    self.skip_convs.append(nn.Conv2d(in_channels=conv_channels,
                                                     out_channels=skip_channels,
                                                     kernel_size=(1, self.seq_length - rf_size_j + 1)))
                else:
                    self.skip_convs.append(nn.Conv2d(in_channels=conv_channels,
                                                     out_channels=skip_channels,
                                                     kernel_size=(1, self.receptive_field - rf_size_j + 1)))

                if self.gcn_true:
                    if self.buildA_true:
                        # Learned directed adjacency, initialised separately per subgraph.
                        self.gc.append(graph_directed_sep_init(l_matrix, subgraph_size, node_dim, device, predefined_A, alpha=tanhalpha, static_feat=static_feat))
                    # Mix-hop propagation over the adjacency and its transpose.
                    self.gconv1.append(mixprop(conv_channels, residual_channels, gcn_depth, dropout, propalpha))
                    self.gconv2.append(mixprop(conv_channels, residual_channels, gcn_depth, dropout, propalpha))

                if self.seq_length > self.receptive_field:
                    self.norm.append(LayerNorm((residual_channels, num_nodes, self.seq_length - rf_size_j + 1), elementwise_affine=layer_norm_affline))
                else:
                    self.norm.append(LayerNorm((residual_channels, num_nodes, self.receptive_field - rf_size_j + 1), elementwise_affine=layer_norm_affline))

                new_dilation *= dilation_exponential

        self.layers = layers
        # Output head: two 1x1 convolutions mapping skip channels to the forecast horizon.
        self.end_conv_1 = nn.Conv2d(in_channels=skip_channels,
                                    out_channels=end_channels,
                                    kernel_size=(1, 1),
                                    bias=True)
        self.end_conv_2 = nn.Conv2d(in_channels=end_channels,
                                    out_channels=out_dim,
                                    kernel_size=(1, 1),
                                    bias=True)

        # skip0 acts on the raw input, skipE on the output of the last layer.
        if self.seq_length > self.receptive_field:
            self.skip0 = nn.Conv2d(in_channels=in_dim, out_channels=skip_channels, kernel_size=(1, self.seq_length), bias=True)
            self.skipE = nn.Conv2d(in_channels=residual_channels, out_channels=skip_channels, kernel_size=(1, self.seq_length - self.receptive_field + 1), bias=True)
        else:
            self.skip0 = nn.Conv2d(in_channels=in_dim, out_channels=skip_channels, kernel_size=(1, self.receptive_field), bias=True)
            self.skipE = nn.Conv2d(in_channels=residual_channels, out_channels=skip_channels, kernel_size=(1, 1), bias=True)

        self.idx = torch.arange(self.num_nodes).to(device)
    def forward(self, input, idx=None):
        seq_len = input.size(3)
        assert seq_len == self.seq_length, 'input sequence length not equal to preset sequence length'

        # Left-pad the time axis so that every layer sees the full receptive field.
        if self.seq_length < self.receptive_field:
            input = nn.functional.pad(input, (self.receptive_field - self.seq_length, 0, 0, 0))

        x = self.start_conv(input)
        skip = self.skip0(F.dropout(input, self.dropout, training=self.training))
        for i in range(self.layers):
            if self.gcn_true:
                if self.buildA_true:
                    # Build the adjacency from the learned node embeddings.
                    if idx is None:
                        adp = self.gc[i](self.idx)
                    else:
                        adp = self.gc[i](idx)
                else:
                    adp = self.predefined_A

            residual = x
            # Graph convolution over the adjacency and its transpose, or a 1x1 fallback.
            if self.gcn_true:
                x = self.gconv1[i](x, adp) + self.gconv2[i](x, adp.transpose(1, 0))
            else:
                x = self.residual_convs[i](x)

            # Gated temporal convolution: tanh filter modulated by a sigmoid gate.
            filt = self.filter_convs[i](x)
            filt = torch.tanh(filt)
            gate = self.gate_convs[i](x)
            gate = torch.sigmoid(gate)
            x = filt * gate
            x = F.dropout(x, self.dropout, training=self.training)

            # Accumulate skip connections at a common temporal length.
            s = self.skip_convs[i](x)
            skip = s + skip

            # Residual connection, cropped to the shrunken temporal dimension.
            x = x + residual[:, :, :, -x.size(3):]
            if idx is None:
                x = self.norm[i](x, self.idx)
            else:
                x = self.norm[i](x, idx)

        skip = self.skipE(x) + skip
        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        return x
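

if __name__ == '__main__':
    # Minimal smoke test: a sketch, assuming layer.py provides dilated_inception and
    # LayerNorm with the signatures used above. gcn_true/buildA_true are disabled so
    # graph_directed_sep_init and mixprop are never constructed; l_matrix and
    # pre_adj_list are placeholders whose real format is defined in layer.py.
    device = torch.device('cpu')
    model = CoGNN(gcn_true=False, buildA_true=False, gcn_depth=2, num_nodes=20,
                  l_matrix=None, pre_adj_list=None, device=device)
    # Input shape: (batch, in_dim, num_nodes, seq_length) with the default
    # in_dim=2 and seq_length=12; the output should be (batch, out_dim, num_nodes, 1).
    x = torch.randn(4, 2, 20, 12)
    y = model(x)
    print(y.shape)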