layers.py
import torch
import torch.nn as nn


class Line_Transform(nn.Module):
    """Per-node linear transform: maps (b, c_in, n) -> (b, c_out, n)."""
    def __init__(self, c_in, c_out):
        super(Line_Transform, self).__init__()
        # Register weight and bias as parameters directly; calling .cuda() on an
        # nn.Parameter returns a plain tensor and breaks parameter registration.
        # Move the whole module to GPU with .cuda()/.to(device) instead.
        self.w = nn.Parameter(torch.rand(c_out, c_in))
        # The bias must match the output channels (c_out) and broadcast over nodes.
        self.b = nn.Parameter(torch.rand(c_out, 1))

    def forward(self, x):
        # x: (b, c_in, n) -> (b, c_out, n)
        x = torch.einsum('vf,bfn->bvn', (self.w, x)) + self.b
        return x.contiguous()
class RNN(nn.Module):
    """Elman RNN applied to each node independently along the time axis."""
    def __init__(self, c_in, c_out):
        super(RNN, self).__init__()
        self.w1 = Line_Transform(c_in, c_out)
        self.w2 = Line_Transform(c_out, c_out)
        self.c_out = c_out

    def forward(self, x):
        shape = x.shape  # (b, f, n, t)
        # torch.autograd.Variable is deprecated; create the state on the input's device.
        h = torch.zeros(shape[0], self.c_out, shape[2], device=x.device, dtype=x.dtype)
        out = []
        for t in range(shape[3]):
            x_t = x[:, :, :, t]  # (b, f, n)
            h = torch.tanh(self.w1(x_t) + self.w2(h))  # (b, c_out, n)
            out.append(h)
        return torch.stack(out, -1)  # (b, c_out, n, t)
class LSTM(nn.Module):
    """LSTM applied to each node independently along the time axis."""
    def __init__(self, c_in, c_out):
        super(LSTM, self).__init__()
        # Input/hidden transforms for the input, forget and output gates and the cell candidate.
        self.w1 = Line_Transform(c_in, c_out)
        self.w2 = Line_Transform(c_out, c_out)
        self.w3 = Line_Transform(c_in, c_out)
        self.w4 = Line_Transform(c_out, c_out)
        self.w5 = Line_Transform(c_in, c_out)
        self.w6 = Line_Transform(c_out, c_out)
        self.w7 = Line_Transform(c_in, c_out)
        self.w8 = Line_Transform(c_out, c_out)
        self.c_out = c_out

    def forward(self, x):
        shape = x.shape  # (b, f, n, t)
        h = torch.zeros(shape[0], self.c_out, shape[2], device=x.device, dtype=x.dtype)
        c = torch.zeros(shape[0], self.c_out, shape[2], device=x.device, dtype=x.dtype)
        out = []
        for t in range(shape[3]):
            x_t = x[:, :, :, t]  # (b, f, n)
            i = torch.sigmoid(self.w1(x_t) + self.w2(h))  # input gate
            f = torch.sigmoid(self.w3(x_t) + self.w4(h))  # forget gate
            o = torch.sigmoid(self.w5(x_t) + self.w6(h))  # output gate
            g = torch.tanh(self.w7(x_t) + self.w8(h))     # cell candidate
            c = f * c + i * g
            h = o * torch.tanh(c)  # (b, c_out, n)
            out.append(h)
        return torch.stack(out, -1)  # (b, c_out, n, t)
class GRU(nn.Module):
    """GRU applied to each node independently along the time axis."""
    def __init__(self, c_in, c_out):
        super(GRU, self).__init__()
        # Input/hidden transforms for the update gate, reset gate and candidate state.
        self.w1 = Line_Transform(c_in, c_out)
        self.w2 = Line_Transform(c_out, c_out)
        self.w3 = Line_Transform(c_in, c_out)
        self.w4 = Line_Transform(c_out, c_out)
        self.w5 = Line_Transform(c_in, c_out)
        self.w6 = Line_Transform(c_out, c_out)
        self.c_out = c_out

    def forward(self, x):
        shape = x.shape  # (b, f, n, t)
        h = torch.zeros(shape[0], self.c_out, shape[2], device=x.device, dtype=x.dtype)
        out = []
        for t in range(shape[3]):
            x_t = x[:, :, :, t]  # (b, f, n)
            z = torch.sigmoid(self.w1(x_t) + self.w2(h))    # update gate
            r = torch.sigmoid(self.w3(x_t) + self.w4(h))    # reset gate
            h_ = torch.tanh(self.w5(x_t) + self.w6(r * h))  # candidate state
            h = z * h + (1 - z) * h_  # (b, c_out, n)
            out.append(h)
        return torch.stack(out, -1)  # (b, c_out, n, t)
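

if __name__ == "__main__":
    # Minimal shape sanity-check sketch, not part of the original module.
    # The input layout (batch, features, nodes, time) is assumed from the
    # (b, f, n, t) comments above; the sizes below are arbitrary examples.
    x = torch.randn(2, 8, 16, 12)  # b=2, f=8, n=16, t=12
    for Layer in (RNN, LSTM, GRU):
        y = Layer(c_in=8, c_out=32)(x)
        print(Layer.__name__, tuple(y.shape))  # expected: (2, 32, 16, 12)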