# PLUMCN-class.py
import torch
import torch.nn as nn
import torch.nn.functional as F


# Self-attention: 1x1 convs produce query/key/value maps over the spatial
# positions, and a learned scalar gamma blends the attended features back
# into the input as a residual.
class selfattention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1, stride=1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1, stride=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        batch_size, channels, height, width = input.shape
        # input: B, C, H, W -> q: B, H * W, C // 8
        q = self.query(input).view(batch_size, -1, height * width).permute(0, 2, 1)
        # input: B, C, H, W -> k: B, C // 8, H * W
        k = self.key(input).view(batch_size, -1, height * width)
        # input: B, C, H, W -> v: B, C, H * W
        v = self.value(input).view(batch_size, -1, height * width)
        # q: B, H * W, C // 8 x k: B, C // 8, H * W -> attn_matrix: B, H * W, H * W
        attn_matrix = torch.bmm(q, k)
        attn_matrix = self.softmax(attn_matrix)
        out = torch.bmm(v, attn_matrix.permute(0, 2, 1))
        out = out.view(*input.shape)
        return self.gamma * out + input
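

# A minimal smoke test for the attention block (the sizes here are
# illustrative assumptions, not taken from the original project):
if __name__ == "__main__":
    attn = selfattention(in_channels=16)
    x = torch.randn(2, 16, 8, 8)        # B, C, H, W
    assert attn(x).shape == x.shape     # attention is shape-preserving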


# MGCE: encoder block. A dilated (1, 2) conv trims the width, three graph
# convolutions (one per adjacency matrix) mix the node axis, and a
# residual dilated-conv path re-injects the block input.
class MGCE(nn.Module):
    def __init__(self, c_in, c_out, feature_num):
        super(MGCE, self).__init__()
        self.dcnn2d1_MGCE = nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=(1, 2), padding=(0, 0), dilation=(1, 2)),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True)
        )
        self.fcD11_MGCE = nn.Linear(feature_num, feature_num, bias=True)
        self.fcD12_MGCE = nn.Linear(feature_num, feature_num, bias=True)
        self.fcD13_MGCE = nn.Linear(feature_num, feature_num, bias=True)
        self.dcnn2d2_MGCE = nn.Sequential(
            nn.Conv2d(c_out, c_out, kernel_size=(1, 2), padding=(0, 1), dilation=(1, 2)),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True)
        )
        self.cnn2dRes_MGCE = nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=(1, 2), padding=(0, 0), dilation=(1, 2)),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True)
        )
        self.cnn2ddown_MGCE = nn.Sequential(
            nn.Conv2d(c_out, c_out, kernel_size=(1, 2), padding=(0, 1), dilation=(1, 2)),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True)
        )
    def forward(self, X0, X, AC0, AD0, AW0):
        B = X.shape[0]
        X = self.dcnn2d1_MGCE(X)
        X = X.view(B, -1, X.shape[2], X.shape[3])
        c = X.shape[1]
        # Broadcast each B x N x N adjacency matrix across the channel
        # dimension, then flatten to (B * C) x N x N for the batched matmul.
        AC0_MGCE1 = AC0.view(B, 1, AC0.shape[1], AC0.shape[2])
        AC0_MGCE1 = torch.broadcast_to(AC0_MGCE1, (B, X.shape[1], AC0_MGCE1.shape[2], AC0_MGCE1.shape[3]))
        AD0_MGCE1 = AD0.view(B, 1, AD0.shape[1], AD0.shape[2])
        AD0_MGCE1 = torch.broadcast_to(AD0_MGCE1, (B, X.shape[1], AD0_MGCE1.shape[2], AD0_MGCE1.shape[3]))
        AW0_MGCE1 = AW0.view(B, 1, AW0.shape[1], AW0.shape[2])
        AW0_MGCE1 = torch.broadcast_to(AW0_MGCE1, (B, X.shape[1], AW0_MGCE1.shape[2], AW0_MGCE1.shape[3]))
        AC0_MGCE1 = AC0_MGCE1.reshape(-1, AC0_MGCE1.shape[2], AC0_MGCE1.shape[3])
        AD0_MGCE1 = AD0_MGCE1.reshape(-1, AD0_MGCE1.shape[2], AD0_MGCE1.shape[3])
        AW0_MGCE1 = AW0_MGCE1.reshape(-1, AW0_MGCE1.shape[2], AW0_MGCE1.shape[3])
        X = X.view(-1, X.shape[2], X.shape[3])
        # One graph convolution per adjacency matrix, summed.
        X = (self.fcD11_MGCE(torch.bmm(AC0_MGCE1, X))
             + self.fcD12_MGCE(torch.bmm(AD0_MGCE1, X))
             + self.fcD13_MGCE(torch.bmm(AW0_MGCE1, X)))
        X = X.view(B, c, X.shape[1], X.shape[2])
        X = self.dcnn2d2_MGCE(X)
        # Residual connection from the raw block input X0.
        X = X + self.cnn2dRes_MGCE(X0)
        X = F.relu(X)
        X = self.cnn2ddown_MGCE(X)
        return X
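

# A minimal smoke test for MGCE (all sizes are assumptions chosen to fit
# the shape constraints: the first dilated conv trims the width from W to
# W - 2, so feature_num must equal W - 2, and the adjacency matrices act
# on the H (node) axis):
if __name__ == "__main__":
    B, c_in, c_out, H, W = 2, 4, 8, 10, 14
    enc = MGCE(c_in, c_out, feature_num=W - 2)
    A = torch.softmax(torch.randn(B, H, H), dim=-1)   # stand-in adjacency
    X = torch.randn(B, c_in, H, W)
    out = enc(X, X, A, A, A)            # X0 = X feeds the residual path
    assert out.shape == (B, c_out, H, W - 2)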


# MGCD: the matching decoder block. Self-attention gates the input with
# the encoder output, a stride-2 transposed conv expands the width, and
# the same three-branch graph step is applied before a residual add.
class MGCD(nn.Module):
    def __init__(self, c_in, c_out, feature_num):
        super(MGCD, self).__init__()
        self.selfattentiont1_MGCD = selfattention(c_in)
        self.selfattentiont2_MGCD = selfattention(c_in)
        self.tcnn2ddown_MGCD = nn.Sequential(
            nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out, kernel_size=(1, 2), stride=(1, 2), padding=(0, 0)),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True)
        )
        self.fcD11_MGCD = nn.Linear(feature_num, feature_num, bias=True)
        self.fcD12_MGCD = nn.Linear(feature_num, feature_num, bias=True)
        self.fcD13_MGCD = nn.Linear(feature_num, feature_num, bias=True)
        self.cnn2dup_MGCD = nn.Sequential(
            nn.ConvTranspose2d(c_out, c_out, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), dilation=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True)
        )
        self.cnn2dRes_MGCD = nn.Sequential(
            nn.ConvTranspose2d(c_in, c_out, kernel_size=(1, 2), padding=(0, 0), dilation=(1, 2)),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True)
        )
        self.cnn2dup1_MGCD = nn.Sequential(
            nn.ConvTranspose2d(c_out, c_out, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1), dilation=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(True)
        )
    def forward(self, MGCE_out, X, AC0, AD0, AW0):
        B = X.shape[0]
        # Gate the attended decoder input with a sigmoid of the attended
        # encoder output (assumed: the second attention module handles
        # MGCE_out, since it is otherwise unused).
        X = self.selfattentiont1_MGCD(X) * torch.sigmoid(self.selfattentiont2_MGCD(MGCE_out))
        X0 = X
        X = self.tcnn2ddown_MGCD(X)
        # Broadcast each B x N x N adjacency matrix across the channel
        # dimension, then flatten to (B * C) x N x N for the batched matmul.
        AC0_MGCD1 = AC0.view(B, 1, AC0.shape[1], AC0.shape[2])
        AC0_MGCD1 = torch.broadcast_to(AC0_MGCD1, (B, X.shape[1], AC0_MGCD1.shape[2], AC0_MGCD1.shape[3]))
        AD0_MGCD1 = AD0.view(B, 1, AD0.shape[1], AD0.shape[2])
        AD0_MGCD1 = torch.broadcast_to(AD0_MGCD1, (B, X.shape[1], AD0_MGCD1.shape[2], AD0_MGCD1.shape[3]))
        AW0_MGCD1 = AW0.view(B, 1, AW0.shape[1], AW0.shape[2])
        AW0_MGCD1 = torch.broadcast_to(AW0_MGCD1, (B, X.shape[1], AW0_MGCD1.shape[2], AW0_MGCD1.shape[3]))
        AC0_MGCD1 = AC0_MGCD1.reshape(-1, AC0_MGCD1.shape[2], AC0_MGCD1.shape[3])
        AD0_MGCD1 = AD0_MGCD1.reshape(-1, AD0_MGCD1.shape[2], AD0_MGCD1.shape[3])
        AW0_MGCD1 = AW0_MGCD1.reshape(-1, AW0_MGCD1.shape[2], AW0_MGCD1.shape[3])
        X = X.view(-1, X.shape[2], X.shape[3])
        # One graph convolution per adjacency matrix, summed.
        X = (self.fcD11_MGCD(torch.bmm(AC0_MGCD1, X))
             + self.fcD12_MGCD(torch.bmm(AD0_MGCD1, X))
             + self.fcD13_MGCD(torch.bmm(AW0_MGCD1, X)))
        X = X.view(B, -1, X.shape[1], X.shape[2])
        X = self.cnn2dup_MGCD(X)
        # Residual connection from the gated input X0.
        X = X + self.cnn2dRes_MGCD(X0)
        X = F.relu(X)
        X = self.cnn2dup1_MGCD(X)
        return X
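

# A minimal end-to-end smoke test chaining MGCE into MGCD (sizes are
# assumptions; MGCD's residual branch widens the input by 2 while its
# stride-2 transposed conv doubles the width, so the two paths only line
# up when the incoming width is 2, with feature_num equal to twice that):
if __name__ == "__main__":
    B, H = 2, 10
    A = torch.softmax(torch.randn(B, H, H), dim=-1)   # stand-in adjacency
    enc = MGCE(c_in=4, c_out=8, feature_num=2)        # width 4 -> 2
    dec = MGCD(c_in=8, c_out=4, feature_num=4)        # width 2 -> 4
    X = torch.randn(B, 4, H, 4)
    enc_out = enc(X, X, A, A, A)                      # (B, 8, H, 2)
    dec_out = dec(enc_out, enc_out, A, A, A)          # (B, 4, H, 4)
    assert dec_out.shape == X.shape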