-
Notifications
You must be signed in to change notification settings - Fork 97
/
Copy pathunet_models.py
403 lines (313 loc) · 15 KB
/
unet_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
from torch import nn
from torch.nn import functional as F
import torch
from torchvision import models
import torchvision
"""
This script has been taken (and modified) from :
https://github.com/ternaus/TernausNet
@ARTICLE{arXiv:1801.05746,
author = {V. Iglovikov and A. Shvets},
title = {TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation},
journal = {ArXiv e-prints},
eprint = {1801.05746},
year = 2018
}
"""
def conv3x3(in_, out):
    """Build a 3x3 convolution with padding 1.

    With the default stride of 1 the spatial size of the input is preserved.

    :param in_: number of input channels
    :param out: number of output channels
    :return: an ``nn.Conv2d`` layer (bias enabled, stride 1)
    """
    layer = nn.Conv2d(in_channels=in_, out_channels=out, kernel_size=3, padding=1)
    return layer
class ConvRelu(nn.Module):
    """A 3x3 convolution (built by ``conv3x3``) followed by an in-place ReLU."""

    def __init__(self, in_, out):
        super().__init__()
        # The attribute names form the state_dict keys (conv.weight, conv.bias),
        # so they must stay stable for checkpoint compatibility.
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.activation(self.conv(x))
class NoOperation(nn.Module):
    """Identity module: ``forward`` returns the very same object it receives."""

    def forward(self, x):
        identity = x
        return identity
class DecoderBlock(nn.Module):
    """Decoder stage: ConvRelu, then a transposed convolution and a ReLU.

    The transposed convolution (kernel 3, stride 2, padding 1, output_padding 1)
    produces an output whose height and width are twice those of its input.

    :param in_channels: channels of the incoming feature map
    :param middle_channels: channels produced by the ConvRelu
    :param out_channels: channels produced by the block
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        # Layers are created in this order (ConvRelu first) so that parameter
        # initialisation and the state_dict keys block.0.* / block.1.* are fixed.
        layers = [
            ConvRelu(in_channels, middle_channels),
            nn.ConvTranspose2d(
                middle_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
            nn.ReLU(inplace=True),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        upsampled = self.block(x)
        return upsampled
class UNet11(nn.Module):
    """U-Net with a VGG11 encoder (TernausNet).

    The encoder reuses the convolution layers of ``models.vgg11().features``.
    The decoder upsamples with ``DecoderBlock`` and, at every scale,
    concatenates the matching encoder feature map along the channel axis.
    """

    def __init__(self, num_classes=1, num_filters=32, pretrained=False):
        """
        :param num_classes: number of output channels of the final 1x1 convolution
        :param num_filters: base width of the decoder
        :param pretrained:
            False - no pre-trained network is used
            True - encoder is pre-trained with VGG11
        """
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.encoder = models.vgg11(pretrained=pretrained).features
        # The ReLU at index 1 of the VGG11 feature stack is reused after every
        # encoder convolution in forward().
        self.relu = self.encoder[1]
        self.conv1 = self.encoder[0]
        self.conv2 = self.encoder[3]
        self.conv3s = self.encoder[6]
        self.conv3 = self.encoder[8]
        self.conv4s = self.encoder[11]
        self.conv4 = self.encoder[13]
        self.conv5s = self.encoder[16]
        self.conv5 = self.encoder[18]

        # Skip-connection widths are fixed by VGG11 (conv1=64, conv2=128,
        # conv3=256, conv4=512, conv5=512) and do not scale with num_filters.
        # Expressing them as multiples of num_filters only worked for
        # num_filters == 32; for that value the layer shapes below are
        # unchanged, so existing checkpoints still load.
        self.center = DecoderBlock(512, num_filters * 8 * 2, num_filters * 8)
        self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8)
        self.dec4 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 4)
        self.dec3 = DecoderBlock(256 + num_filters * 4, num_filters * 4 * 2, num_filters * 2)
        self.dec2 = DecoderBlock(128 + num_filters * 2, num_filters * 2 * 2, num_filters)
        self.dec1 = ConvRelu(64 + num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        # The input is pooled five times (four in the encoder, one before the
        # center), so its height and width should be divisible by 32 for the
        # torch.cat skip connections to line up.
        conv1 = self.relu(self.conv1(x))
        conv2 = self.relu(self.conv2(self.pool(conv1)))
        conv3s = self.relu(self.conv3s(self.pool(conv2)))
        conv3 = self.relu(self.conv3(conv3s))
        conv4s = self.relu(self.conv4s(self.pool(conv3)))
        conv4 = self.relu(self.conv4(conv4s))
        conv5s = self.relu(self.conv5s(self.pool(conv4)))
        conv5 = self.relu(self.conv5(conv5s))
        center = self.center(self.pool(conv5))
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(torch.cat([dec2, conv1], 1))
        return self.final(dec1)
def unet11(pretrained=False, **kwargs):
    """Build a ``UNet11`` model.

    pretrained:
        False - no pre-trained network is used
        True - encoder is pre-trained with VGG11
        carvana - all weights are pre-trained on
            Kaggle: Carvana dataset https://www.kaggle.com/c/carvana-image-masking-challenge
            (read from the local file 'TernausNet.pt', key 'model')
    Remaining keyword arguments are forwarded to ``UNet11``.
    """
    is_carvana = pretrained == 'carvana'
    # The Carvana checkpoint overwrites every weight (load_state_dict is strict),
    # so fetching the ImageNet VGG11 weights first would be wasted work.
    model = UNet11(pretrained=False if is_carvana else pretrained, **kwargs)
    if is_carvana:
        # map_location='cpu' lets a checkpoint saved from GPU load on CPU-only
        # machines; load_state_dict copies the values into the model's tensors.
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        state = torch.load('TernausNet.pt', map_location='cpu')
        model.load_state_dict(state['model'])
    return model
class DecoderBlockV2(nn.Module):
    """Decoder stage whose output has twice the height and width of its input.

    :param in_channels: channels of the incoming feature map
    :param middle_channels: channels after the first ConvRelu
    :param out_channels: channels of the block output
    :param is_deconv:
        True: ConvRelu -> ConvTranspose2d(kernel 4, stride 2, padding 1) -> ReLU
        False: bilinear Upsample(x2) -> ConvRelu -> ConvRelu
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        super().__init__()
        self.in_channels = in_channels
        if is_deconv:
            # Parameters for the deconvolution (kernel size 4 divisible by stride 2)
            # were chosen to avoid checkerboard artifacts, following
            # https://distill.pub/2016/deconv-checkerboard/
            self.block = nn.Sequential(
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
                                   padding=1),
                nn.ReLU(inplace=True),
            )
        else:
            # align_corners=False is PyTorch's default for bilinear mode; setting it
            # explicitly documents the choice and avoids the warning some
            # versions emit when it is left unset.
            self.block = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            )

    def forward(self, x):
        return self.block(x)
class AlbuNet(nn.Module):
    """U-Net with a ResNet34 encoder.

    UNet: https://arxiv.org/abs/1505.04597
    ResNet34: https://arxiv.org/abs/1512.03385
    Proposed by Alexander Buslaev: https://www.linkedin.com/in/al-buslaev/
    """

    def __init__(self, num_classes=1, num_filters=32, pretrained=False, is_deconv=False):
        """
        :param num_classes: number of output channels of the final 1x1 convolution
        :param num_filters: base width of the decoder
        :param pretrained:
            False - no pre-trained network is used
            True - encoder is pre-trained with resnet34
        :is_deconv:
            False: bilinear interpolation is used in decoder
            True: deconvolution is used in decoder
        """
        super().__init__()
        self.num_classes = num_classes
        self.pool = nn.MaxPool2d(2, 2)
        self.encoder = torchvision.models.resnet34(pretrained=pretrained)
        # Registered as a submodule but not used by forward().
        self.relu = nn.ReLU(inplace=True)

        # Stem: the encoder's conv/bn/relu followed by the 2x2 self.pool
        # (not the encoder's own maxpool layer).
        stem = (self.encoder.conv1, self.encoder.bn1, self.encoder.relu, self.pool)
        self.conv1 = nn.Sequential(*stem)
        self.conv2, self.conv3, self.conv4, self.conv5 = (
            self.encoder.layer1,
            self.encoder.layer2,
            self.encoder.layer3,
            self.encoder.layer4,
        )

        # Decoder; the constants 512/256/128/64 are the widths of the skip
        # connections (conv5 .. conv2) concatenated into each stage.
        nf = num_filters
        self.center = DecoderBlockV2(512, nf * 8 * 2, nf * 8, is_deconv)
        self.dec5 = DecoderBlockV2(512 + nf * 8, nf * 8 * 2, nf * 8, is_deconv)
        self.dec4 = DecoderBlockV2(256 + nf * 8, nf * 8 * 2, nf * 8, is_deconv)
        self.dec3 = DecoderBlockV2(128 + nf * 8, nf * 4 * 2, nf * 2, is_deconv)
        self.dec2 = DecoderBlockV2(64 + nf * 2, nf * 2 * 2, nf * 2 * 2, is_deconv)
        self.dec1 = DecoderBlockV2(nf * 2 * 2, nf * 2 * 2, nf, is_deconv)
        self.dec0 = ConvRelu(nf, nf)
        self.final = nn.Conv2d(nf, num_classes, kernel_size=1)

    def forward(self, x):
        e1 = self.conv1(x)
        e2 = self.conv2(e1)
        e3 = self.conv3(e2)
        e4 = self.conv4(e3)
        e5 = self.conv5(e4)

        d = self.center(self.pool(e5))
        # Each skip stage concatenates (decoder output, encoder map) on dim 1.
        for decoder, skip in ((self.dec5, e5), (self.dec4, e4), (self.dec3, e3), (self.dec2, e2)):
            d = decoder(torch.cat([d, skip], 1))
        d = self.dec1(d)
        d = self.dec0(d)
        return self.final(d)
class UNetVGG16(nn.Module):
    """PyTorch U-Net model using VGG16 encoder.

    UNet: https://arxiv.org/abs/1505.04597
    VGG: https://arxiv.org/abs/1409.1556
    Proposed by Vladimir Iglovikov and Alexey Shvets: https://github.com/ternaus/TernausNet

    Args:
        num_classes (int): Number of output classes.
        num_filters (int, optional): Number of filters in the last layer of decoder. Defaults to 32.
        dropout_2d (float, optional): Probability factor of dropout layer before output layer.
            Only applied in training mode. Defaults to 0.2.
        pretrained (bool, optional):
            False - no pre-trained weights are being used.
            True - VGG encoder is pre-trained on ImageNet.
            Defaults to False.
        is_deconv (bool, optional):
            False: bilinear interpolation is used in decoder.
            True: deconvolution is used in decoder.
            Defaults to False.
    """

    def __init__(self, num_classes=1, num_filters=32, dropout_2d=0.2, pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d
        self.pool = nn.MaxPool2d(2, 2)
        self.encoder = torchvision.models.vgg16(pretrained=pretrained).features
        # One in-place ReLU module is shared by every encoder stage below.
        self.relu = nn.ReLU(inplace=True)
        # Each stage takes the conv layers of one VGG16 block from the feature
        # stack; pooling between stages is done explicitly in forward().
        self.conv1 = nn.Sequential(self.encoder[0],
                                   self.relu,
                                   self.encoder[2],
                                   self.relu)
        self.conv2 = nn.Sequential(self.encoder[5],
                                   self.relu,
                                   self.encoder[7],
                                   self.relu)
        self.conv3 = nn.Sequential(self.encoder[10],
                                   self.relu,
                                   self.encoder[12],
                                   self.relu,
                                   self.encoder[14],
                                   self.relu)
        self.conv4 = nn.Sequential(self.encoder[17],
                                   self.relu,
                                   self.encoder[19],
                                   self.relu,
                                   self.encoder[21],
                                   self.relu)
        self.conv5 = nn.Sequential(self.encoder[24],
                                   self.relu,
                                   self.encoder[26],
                                   self.relu,
                                   self.encoder[28],
                                   self.relu)
        # The constants 512/512/256/128/64 are the skip-connection widths
        # (conv5 .. conv1) concatenated into each decoder stage.
        self.center = DecoderBlockV2(512, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec5 = DecoderBlockV2(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec4 = DecoderBlockV2(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3 = DecoderBlockV2(256 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlockV2(128 + num_filters * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec1 = ConvRelu(64 + num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(self.pool(conv1))
        conv3 = self.conv3(self.pool(conv2))
        conv4 = self.conv4(self.pool(conv3))
        conv5 = self.conv5(self.pool(conv4))
        center = self.center(self.pool(conv5))
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(torch.cat([dec2, conv1], 1))
        # F.dropout2d defaults to training=True, which would also zero channels
        # at inference; tie it to the module's train()/eval() mode instead.
        return self.final(F.dropout2d(dec1, p=self.dropout_2d, training=self.training))
class UNetResNet(nn.Module):
    """PyTorch U-Net model using ResNet(34, 101 or 152) encoder.

    UNet: https://arxiv.org/abs/1505.04597
    ResNet: https://arxiv.org/abs/1512.03385
    Proposed by Alexander Buslaev: https://www.linkedin.com/in/al-buslaev/

    Args:
        encoder_depth (int): Depth of a ResNet encoder (34, 101 or 152).
        num_classes (int): Number of output classes.
        num_filters (int, optional): Number of filters in the last layer of decoder. Defaults to 32.
        dropout_2d (float, optional): Probability factor of dropout layer before output layer.
            Only applied in training mode. Defaults to 0.2.
        pretrained (bool, optional):
            False - no pre-trained weights are being used.
            True - ResNet encoder is pre-trained on ImageNet.
            Defaults to False.
        is_deconv (bool, optional):
            False: bilinear interpolation is used in decoder.
            True: deconvolution is used in decoder.
            Defaults to False.

    Raises:
        NotImplementedError: if encoder_depth is not 34, 101 or 152.
    """

    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d

        # bottom_channel_nr is the width of the deepest encoder stage (layer4);
        # the shallower skip connections are assumed to halve it at each level.
        if encoder_depth == 34:
            self.encoder = torchvision.models.resnet34(pretrained=pretrained)
            bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.encoder = torchvision.models.resnet101(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 152:
            self.encoder = torchvision.models.resnet152(pretrained=pretrained)
            bottom_channel_nr = 2048
        else:
            raise NotImplementedError('only 34, 101, 152 version of Resnet are implemented')

        self.pool = nn.MaxPool2d(2, 2)
        # Registered as a submodule but not used by forward().
        self.relu = nn.ReLU(inplace=True)
        # Stem: encoder conv/bn/relu followed by the 2x2 self.pool.
        self.conv1 = nn.Sequential(self.encoder.conv1,
                                   self.encoder.bn1,
                                   self.encoder.relu,
                                   self.pool)
        self.conv2 = self.encoder.layer1
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4

        self.center = DecoderBlockV2(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec5 = DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec4 = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8,
                                   is_deconv)
        self.dec3 = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2,
                                   is_deconv)
        self.dec2 = DecoderBlockV2(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                   is_deconv)
        self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)
        pool = self.pool(conv5)
        center = self.center(pool)
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)
        # F.dropout2d defaults to training=True, which would also zero channels
        # at inference; tie it to the module's train()/eval() mode instead.
        return self.final(F.dropout2d(dec0, p=self.dropout_2d, training=self.training))