-
Notifications
You must be signed in to change notification settings - Fork 2
/
deeplabv3_plus.py
153 lines (109 loc) · 4.54 KB
/
deeplabv3_plus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from typing import Any, Optional, Sequence

import torch
import torch.nn as nn
from torch.nn import functional as F

from cvm import models

from ..ops import blocks
from ..utils import export, get_out_channels, load_from_local_or_url
from .heads import FCNHead, ClsHead
from .segmentation_model import SegmentationModel
class DeepLabPlusHead(nn.Module):
    """Decoder head of DeepLabV3+.

    Applies ASPP to the high-level feature map, bilinearly resizes the result
    to the spatial size of the low-level feature map, concatenates both and
    maps the concatenation to ``num_classes`` channels with a 3x3 conv block.

    Args:
        aspp_in_channels: channel count of the high-level map fed to ASPP.
        feautes_channels: channel count of the low-level (skip) feature map.
            The misspelled name is kept for backward compatibility.
        out_channels: channel count produced by the ASPP module.
        num_classes: channel count of the head's output.
        atrous_rates: dilation rates of the ASPP branches. The default
            reproduces the previously hard-coded ``[12, 24, 36]``; pass other
            rates when the backbone's output stride differs.
    """

    def __init__(
        self,
        aspp_in_channels: int,
        feautes_channels: int,
        out_channels: int = 256,
        num_classes: int = 32,
        atrous_rates: Sequence[int] = (12, 24, 36),
    ):
        super().__init__()
        # Converted to a list so ASPP receives the same container type as before.
        self.aspp = blocks.ASPP(aspp_in_channels, out_channels, list(atrous_rates))
        self.cat = blocks.Combine('CONCAT')
        # The input width out_channels + feautes_channels implies the concat
        # stacks ASPP output and low-level features along the channel axis.
        # NOTE(review): Conv2d3x3 produces the final logits here; confirm in
        # `blocks` whether it appends normalization/activation.
        self.conv3x3 = blocks.Conv2d3x3(out_channels + feautes_channels, num_classes)

    def forward(self, x, low_level_feautes):
        """Decode ``x`` with help of ``low_level_feautes``.

        Args:
            x: high-level feature map fed to ASPP.
            low_level_feautes: low-level feature map; its last two dims give
                the output's spatial size (4-D input is required by bilinear
                interpolation).

        Returns:
            Class logits at the resolution of ``low_level_feautes``.
        """
        size = low_level_feautes.shape[-2:]
        aspp_features = self.aspp(x)
        # Bring the coarse ASPP output up to the skip connection's resolution.
        aspp_features = F.interpolate(aspp_features, size=size, mode="bilinear", align_corners=False)
        features = self.cat([aspp_features, low_level_feautes])
        return self.conv3x3(features)
@export
class DeepLabV3Plus(SegmentationModel):
    """DeepLabV3+ segmentation model.

    ``backbone``, ``decode_head``, ``aux_head``, ``cls_head``, ``out_stages``
    and ``interpolate`` come from :class:`SegmentationModel` (defined
    elsewhere). Stage usage in :meth:`forward`:

    * ``out_stages[-1]``: deepest map, fed to the decoder and to the
      classification head;
    * ``out_stages[0]``: low-level map for the decoder's skip connection;
    * ``out_stages[-2]``: map for the auxiliary head (assumes ``out_stages``
      includes that stage whenever an auxiliary head is attached).
    """

    def forward(self, x):
        """Run the model.

        Returns:
            dict with ``'out'`` (decoder logits resized to the input's spatial
            size via ``self.interpolate``) and, when an auxiliary head is
            present, ``'aux'`` (auxiliary logits resized the same way). With a
            classification head, ``'out'`` is gated by the sigmoid of its
            per-class scores.
        """
        size = x.shape[-2:]
        stages = self.backbone(x)
        high_level = stages[f'stage{self.out_stages[-1]}']
        low_level = stages[f'stage{self.out_stages[0]}']

        out = self.interpolate(self.decode_head(high_level, low_level), size=size)
        res = {'out': out}

        # Explicit None checks: heads are disabled by passing None, whereas
        # truth-testing an nn.Module falls back to __len__ for container modules.
        if self.aux_head is not None:
            aux = self.aux_head(stages[f'stage{self.out_stages[-2]}'])
            res['aux'] = self.interpolate(aux, size=size)

        if self.cls_head is not None:
            cls = self.cls_head(high_level)
            # Reshape to (N, C, 1, 1) so the per-class gate broadcasts over H and W.
            cls = cls.reshape(cls.shape[0], cls.shape[1], 1, 1)
            res['out'] = out * torch.sigmoid(cls)

        return res
@export
def create_deeplabv3_plus(
    backbone: str = 'resnet50_v1',
    num_classes: int = 21,
    aux_loss: bool = False,
    cls_loss: bool = False,
    dropout_rate: float = 0.1,
    pretrained_backbone: bool = False,
    pretrained: bool = False,
    pth: Optional[str] = None,
    progress: bool = True,
    **kwargs: Any
):
    """Build a :class:`DeepLabV3Plus` model on a backbone from ``cvm.models``.

    Args:
        backbone: name of a backbone factory in ``cvm.models``; its
            ``.features`` must expose ``stage2``..``stage4``.
        num_classes: number of output classes.
        aux_loss: attach an ``FCNHead`` auxiliary head on ``stage3``.
        cls_loss: attach a ``ClsHead`` on ``stage4`` whose sigmoid scores gate
            the main output.
        dropout_rate: dropout rate passed to the auxiliary head.
        pretrained_backbone: request pretrained backbone weights; ignored when
            ``pretrained`` is True.
        pretrained: load weights for the whole model with
            ``load_from_local_or_url``.
        pth: optional local checkpoint path for the whole model.
        progress: forwarded to ``load_from_local_or_url``.
        **kwargs: forwarded to the backbone factory. When ``pretrained`` is
            True, a ``url`` entry is taken as the whole-model checkpoint URL
            and is not forwarded to the backbone.

    Returns:
        The assembled ``DeepLabV3Plus`` model.

    Raises:
        KeyError: if ``backbone`` is not a name in ``cvm.models``.
    """
    url = None
    if pretrained:
        # Whole-model weights supersede the backbone's own pretrained weights.
        pretrained_backbone = False
        # The URL refers to the full checkpoint; the backbone will not load
        # weights, so it has no use for it.
        url = kwargs.pop('url', None)

    # NOTE(review): dilations presumably dilate the later stages to keep the
    # deep features at higher resolution; semantics live in the factory.
    features = models.__dict__[backbone](
        pretrained=pretrained_backbone,
        dilations=[1, 1, 2, 4],
        **kwargs
    ).features

    aux_head = FCNHead(get_out_channels(features.stage3), None, num_classes, dropout_rate) if aux_loss else None
    cls_head = ClsHead(get_out_channels(features.stage4), num_classes) if cls_loss else None
    # ASPP runs on stage4; stage2 provides the low-level skip features.
    decode_head = DeepLabPlusHead(
        get_out_channels(features.stage4),
        get_out_channels(features.stage2),
        num_classes=num_classes,
    )

    # out_stages: [low-level, (aux,) high-level]; stage3 is only needed for the aux head.
    out_stages = [2, 3, 4] if aux_loss else [2, 4]
    model = DeepLabV3Plus(features, out_stages, decode_head, aux_head, cls_head)

    if pretrained:
        load_from_local_or_url(model, pth, url, progress)
    return model
@export
def deeplabv3_plus_resnet50_v1(*args, **kwargs: Any):
    """DeepLabV3+ on a ``resnet50_v1`` backbone.

    Extra arguments are passed to ``create_deeplabv3_plus`` after the backbone name.
    """
    backbone_name = 'resnet50_v1'
    return create_deeplabv3_plus(backbone_name, *args, **kwargs)
@export
def deeplabv3_plus_mobilenet_v3_small(*args, **kwargs: Any):
    """DeepLabV3+ on a ``mobilenet_v3_small`` backbone.

    Extra arguments are passed to ``create_deeplabv3_plus`` after the backbone name.
    """
    backbone_name = 'mobilenet_v3_small'
    return create_deeplabv3_plus(backbone_name, *args, **kwargs)
@export
def deeplabv3_plus_mobilenet_v3_large(*args, **kwargs: Any):
    """DeepLabV3+ on a ``mobilenet_v3_large`` backbone.

    Extra arguments are passed to ``create_deeplabv3_plus`` after the backbone name.
    """
    backbone_name = 'mobilenet_v3_large'
    return create_deeplabv3_plus(backbone_name, *args, **kwargs)
@export
def deeplabv3_plus_regnet_x_400mf(*args, **kwargs: Any):
    """DeepLabV3+ on a ``regnet_x_400mf`` backbone.

    Extra arguments are passed to ``create_deeplabv3_plus`` after the backbone name.
    """
    backbone_name = 'regnet_x_400mf'
    return create_deeplabv3_plus(backbone_name, *args, **kwargs)
@export
def deeplabv3_plus_mobilenet_v1_x1_0(*args, **kwargs: Any):
    """DeepLabV3+ on a ``mobilenet_v1_x1_0`` backbone.

    Extra arguments are passed to ``create_deeplabv3_plus`` after the backbone name.
    """
    backbone_name = 'mobilenet_v1_x1_0'
    return create_deeplabv3_plus(backbone_name, *args, **kwargs)
@export
def deeplabv3_plus_sd_mobilenet_v1_x1_0(*args, **kwargs: Any):
    """DeepLabV3+ on a ``sd_mobilenet_v1_x1_0`` backbone.

    Extra arguments are passed to ``create_deeplabv3_plus`` after the backbone name.
    """
    backbone_name = 'sd_mobilenet_v1_x1_0'
    return create_deeplabv3_plus(backbone_name, *args, **kwargs)
@export
def deeplabv3_plus_mobilenet_v2_x1_0(*args, **kwargs: Any):
    """DeepLabV3+ on a ``mobilenet_v2_x1_0`` backbone.

    Extra arguments are passed to ``create_deeplabv3_plus`` after the backbone name.
    """
    backbone_name = 'mobilenet_v2_x1_0'
    return create_deeplabv3_plus(backbone_name, *args, **kwargs)
@export
def deeplabv3_plus_sd_mobilenet_v2_x1_0(*args, **kwargs: Any):
    """DeepLabV3+ on a ``sd_mobilenet_v2_x1_0`` backbone.

    Extra arguments are passed to ``create_deeplabv3_plus`` after the backbone name.
    """
    backbone_name = 'sd_mobilenet_v2_x1_0'
    return create_deeplabv3_plus(backbone_name, *args, **kwargs)
@export
def deeplabv3_plus_shufflenet_v2_x2_0(*args, **kwargs: Any):
    """DeepLabV3+ on a ``shufflenet_v2_x2_0`` backbone.

    Extra arguments are passed to ``create_deeplabv3_plus`` after the backbone name.
    """
    backbone_name = 'shufflenet_v2_x2_0'
    return create_deeplabv3_plus(backbone_name, *args, **kwargs)
@export
def deeplabv3_plus_sd_shufflenet_v2_x2_0(*args, **kwargs: Any):
    """DeepLabV3+ on a ``sd_shufflenet_v2_x2_0`` backbone.

    Extra arguments are passed to ``create_deeplabv3_plus`` after the backbone name.
    """
    backbone_name = 'sd_shufflenet_v2_x2_0'
    return create_deeplabv3_plus(backbone_name, *args, **kwargs)
@export
def deeplabv3_plus_efficientnet_b0(*args, **kwargs: Any):
    """DeepLabV3+ on an ``efficientnet_b0`` backbone.

    Extra arguments are passed to ``create_deeplabv3_plus`` after the backbone name.
    """
    backbone_name = 'efficientnet_b0'
    return create_deeplabv3_plus(backbone_name, *args, **kwargs)
@export
def deeplabv3_plus_sd_efficientnet_b0(*args, **kwargs: Any):
    """DeepLabV3+ on a ``sd_efficientnet_b0`` backbone.

    Extra arguments are passed to ``create_deeplabv3_plus`` after the backbone name.
    """
    backbone_name = 'sd_efficientnet_b0'
    return create_deeplabv3_plus(backbone_name, *args, **kwargs)