-
Notifications
You must be signed in to change notification settings - Fork 140
/
model.py
198 lines (159 loc) · 6.84 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import os
import math
import sys
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as Func
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import torch.optim as optim
class ConvTemporalGraphical(nn.Module):
    # Source : https://github.com/yysijie/st-gcn/blob/master/net/st_gcn.py
    r"""Graph convolution: a temporal convolution per node, then aggregation
    over the graph with a time-indexed adjacency stack.

    Args:
        in_channels (int): Number of channels in the input sequence data.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int): Expected leading size ``K`` of the adjacency
            tensor; checked against ``A.size(0)`` in :meth:`forward`.
        t_kernel_size (int): Size of the temporal convolving kernel. Default: 1
        t_stride (int, optional): Stride of the temporal convolution. Default: 1
        t_padding (int, optional): Temporal zero-padding added to both sides
            of the input. Default: 0
        t_dilation (int, optional): Spacing between temporal kernel elements.
            Default: 1
        bias (bool, optional): If ``True``, the convolution has a learnable
            bias. Default: ``True``

    Shape:
        - Input[0]: graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: adjacency stack in :math:`(K, V, V)` format
        - Output[0]: graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: the adjacency stack passed in, returned unchanged

        where :math:`N` is the batch size, :math:`T_{in}/T_{out}` are the
        input/output sequence lengths and :math:`V` is the number of nodes.
        The einsum ``'nctv,tvw->nctw'`` pairs ``A[t]`` with time step ``t`` of
        the convolved input, so :math:`K` must equal :math:`T_{out}`.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 t_kernel_size=1, t_stride=1, t_padding=0, t_dilation=1,
                 bias=True):
        super().__init__()
        self.kernel_size = kernel_size
        # Every spatial extent is 1, so the convolution runs along time only
        # and never mixes different nodes.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(t_kernel_size, 1),
            padding=(t_padding, 0),
            stride=(t_stride, 1),
            dilation=(t_dilation, 1),
            bias=bias,
        )

    def forward(self, x, A):
        assert A.size(0) == self.kernel_size
        features = self.conv(x)
        # For each time step t, multiply node features by adjacency A[t].
        aggregated = torch.einsum('nctv,tvw->nctw', features, A)
        return aggregated.contiguous(), A
class _ZeroResidual(nn.Module):
    """Residual branch that contributes nothing; used when ``residual=False``.

    It returns the scalar ``0``, so ``tcn(x) + res`` stays valid and adds
    nothing. Using a module instead of ``lambda x: 0`` keeps the owning module
    picklable, e.g. for ``torch.save(model)``.
    """

    def forward(self, x):
        return 0


class st_gcn(nn.Module):
    r"""Applies a spatial temporal graph convolution over an input graph sequence.

    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (tuple): ``(temporal, graph)`` kernel sizes. The temporal
            size must be odd. The graph size must equal ``A.size(0)``.
        use_mdn (bool, optional): If ``True``, the final PReLU is skipped and
            the raw ``tcn(gcn(x)) + residual(x)`` sum is returned.
            Default: ``False``
        stride (int, optional): Stride of the temporal convolution. Default: 1
        dropout (float, optional): Dropout rate of the final output. Default: 0
        residual (bool, optional): If ``True``, applies a residual mechanism.
            Default: ``True``

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format

        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 use_mdn=False,
                 stride=1,
                 dropout=0,
                 residual=True):
        super(st_gcn, self).__init__()
        assert len(kernel_size) == 2, "kernel_size must be (temporal, graph)"
        assert kernel_size[0] % 2 == 1, "temporal kernel size must be odd"
        # With an odd temporal kernel this padding keeps T unchanged at stride 1.
        padding = ((kernel_size[0] - 1) // 2, 0)
        self.use_mdn = use_mdn

        # Submodules are created in the original order, so seeded parameter
        # initialisation is reproduced exactly.
        self.gcn = ConvTemporalGraphical(in_channels, out_channels,
                                         kernel_size[1])
        self.tcn = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.PReLU(),
            nn.Conv2d(
                out_channels,
                out_channels,
                (kernel_size[0], 1),
                (stride, 1),
                padding,
            ),
            nn.BatchNorm2d(out_channels),
            # In-place is safe here: its input is the fresh BatchNorm output.
            nn.Dropout(dropout, inplace=True),
        )

        # Residual branches are modules, not lambdas, so the block stays
        # picklable. None of the parameter-free options adds state_dict keys.
        if not residual:
            self.residual = _ZeroResidual()
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = nn.Identity()
        else:
            # A 1x1 projection matches channel count and temporal stride.
            self.residual = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=(stride, 1)),
                nn.BatchNorm2d(out_channels),
            )
        self.prelu = nn.PReLU()

    def forward(self, x, A):
        # Compute the residual from the block input before the graph conv.
        res = self.residual(x)
        x, A = self.gcn(x, A)
        x = self.tcn(x) + res
        if not self.use_mdn:
            x = self.prelu(x)
        return x, A
class social_stgcnn(nn.Module):
    r"""Stacked ST-GCN blocks followed by a temporal extrapolator CNN
    (TXP-CNN) that maps ``seq_len`` observed steps to ``pred_seq_len`` steps.

    Args:
        n_stgcnn (int): Number of ST-GCN blocks applied in :meth:`forward`.
            Default: 1
        n_txpcnn (int): Number of TXP-CNN layers constructed; must be >= 1.
            Default: 1
        input_feat (int): Channels of the input graph sequence. Default: 2
        output_feat (int): Channels produced by the ST-GCN blocks. Default: 5
        seq_len (int): Observed sequence length. It is also the graph kernel
            size of every ST-GCN block, so ``a.size(0)`` must equal it.
            Default: 8
        pred_seq_len (int): Predicted sequence length. Default: 12
        kernel_size (int): Temporal kernel size of the ST-GCN blocks; must be
            odd. Default: 3

    Shape:
        - Input ``v``: :math:`(N, input\_feat, seq\_len, V)`
        - Input ``a``: :math:`(seq\_len, V, V)`
        - Output ``v``: :math:`(N, C, pred\_seq\_len, V)`, with
          :math:`C = output\_feat` when ``n_stgcnn >= 1``
        - Output ``a``: the adjacency returned by the last ST-GCN block

    Raises:
        ValueError: if ``n_txpcnn < 1``. :meth:`forward` always applies the
            first TXP-CNN layer and its PReLU, so such a model could never run.
    """

    def __init__(self, n_stgcnn=1, n_txpcnn=1, input_feat=2, output_feat=5,
                 seq_len=8, pred_seq_len=12, kernel_size=3):
        super(social_stgcnn, self).__init__()
        if n_txpcnn < 1:
            # Fail at construction instead of with an IndexError on every
            # forward() call (self.prelus would be empty).
            raise ValueError(
                "n_txpcnn must be >= 1, got {}".format(n_txpcnn))
        self.n_stgcnn = n_stgcnn
        self.n_txpcnn = n_txpcnn

        # Creation order matches the original, so seeded initialisation and
        # state_dict keys are unchanged.
        self.st_gcns = nn.ModuleList()
        self.st_gcns.append(st_gcn(input_feat, output_feat, (kernel_size, seq_len)))
        for _ in range(1, self.n_stgcnn):
            self.st_gcns.append(st_gcn(output_feat, output_feat, (kernel_size, seq_len)))

        # The TXP-CNN treats time steps as convolution channels:
        # seq_len -> pred_seq_len, then pred_seq_len -> pred_seq_len.
        self.tpcnns = nn.ModuleList()
        self.tpcnns.append(nn.Conv2d(seq_len, pred_seq_len, 3, padding=1))
        for _ in range(1, self.n_txpcnn):
            self.tpcnns.append(nn.Conv2d(pred_seq_len, pred_seq_len, 3, padding=1))
        # Attribute name (including the "ouput" spelling) is kept for
        # state_dict compatibility.
        self.tpcnn_ouput = nn.Conv2d(pred_seq_len, pred_seq_len, 3, padding=1)

        self.prelus = nn.ModuleList()
        for _ in range(self.n_txpcnn):
            self.prelus.append(nn.PReLU())

    def forward(self, v, a):
        for k in range(self.n_stgcnn):
            v, a = self.st_gcns[k](v, a)

        # NOTE(review): `view` reinterprets the contiguous (N, C, T, V) buffer
        # as (N, T, C, V). It does NOT transpose channels and time the way
        # `permute(0, 2, 1, 3)` would. Left unchanged because changing it
        # would change the outputs of weights trained with this code.
        v = v.view(v.shape[0], v.shape[2], v.shape[1], v.shape[3])

        v = self.prelus[0](self.tpcnns[0](v))
        # NOTE(review): this range stops at n_txpcnn - 2. When n_txpcnn >= 2,
        # tpcnns[-1] and prelus[-1] are built but never applied (and so never
        # trained). Left unchanged to preserve outputs of existing weights.
        for k in range(1, self.n_txpcnn - 1):
            v = self.prelus[k](self.tpcnns[k](v)) + v

        v = self.tpcnn_ouput(v)
        # The same raw reinterpretation in the opposite direction gives
        # shape (N, C, pred_seq_len, V).
        v = v.view(v.shape[0], v.shape[2], v.shape[1], v.shape[3])
        return v, a