-
Notifications
You must be signed in to change notification settings - Fork 0
/
backbone.py
81 lines (71 loc) · 3.71 KB
/
backbone.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Author: Zylo117
# coding=utf-8
import torch
from torch import nn
from efficientdet.model import BiFPN, Regressor, Classifier, EfficientNet
from efficientdet.utils import Anchors
class EfficientDetBackbone(nn.Module):
    """Detector assembled from an EfficientNet, a BiFPN stack, two heads and anchors.

    The constructor picks one entry from each per-coefficient table (indexed by
    ``compound_coef``) and wires up, in this order: ``bifpn``, ``regressor``,
    ``classifier``, ``anchors`` and ``backbone_net``.

    Args:
        num_classes: Number of classes handed to the ``Classifier`` head.
        compound_coef: Index into the configuration tables below. The tables
            have nine entries, so valid indices are 0 through 8.
        load_weights: Forwarded unchanged to ``EfficientNet``.
        **kwargs: Optional anchor configuration. ``ratios``, ``scales`` and
            ``rotations`` are read here to size the heads; the full ``kwargs``
            dict is also forwarded to ``Anchors``.
    """

    def __init__(self, num_classes=80, compound_coef=2, load_weights=False, **kwargs):
        super().__init__()
        self.compound_coef = compound_coef

        # Per-coefficient lookup tables; position i configures compound_coef == i.
        self.backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6, 7]
        self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384, 384]
        self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8, 8]
        self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
        self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5, 5]
        self.pyramid_levels = [5, 5, 4, 5, 5, 5, 5, 5, 6]
        self.anchor_scale = [4., 4., 2., 4., 4., 4., 4., 5., 4.]

        # Anchor layout. Only the *counts* of scales and rotations are kept here;
        # the actual values reach ``Anchors`` through ``**kwargs``.
        self.aspect_ratios = kwargs.get('ratios', [(1.0, 1.0), (2.0, 1.0), (1.0, 2.0)])
        self.num_scales = len(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 2.0)]))
        self.anchors_rotations = len(
            kwargs.get('rotations', [-90.0, -75.0, -60.0, -45.0, -30.0, -15.0]))

        conv_channel_coef = {
            # the channels of P3/P4/P5.
            0: [40, 112, 320],
            1: [40, 112, 320],
            2: [48, 120, 352],
            3: [48, 136, 384],
            4: [56, 160, 448],
            5: [64, 176, 512],
            6: [72, 200, 576],
            7: [72, 200, 576],
            8: [80, 224, 640],
        }

        # One anchor per (aspect ratio, rotation, scale) combination.
        num_anchors = len(self.aspect_ratios) * self.anchors_rotations * self.num_scales

        fpn_filters = self.fpn_num_filters[compound_coef]
        head_layers = self.box_class_repeats[compound_coef]
        level_count = self.pyramid_levels[compound_coef]

        # The third positional argument is True only for the first cell of the
        # stack; attention is enabled below coefficient 6 and use_p8 above 7.
        self.bifpn = nn.Sequential(
            *[BiFPN(fpn_filters,
                    conv_channel_coef[compound_coef],
                    cell_idx == 0,
                    attention=compound_coef < 6,
                    use_p8=compound_coef > 7)
              for cell_idx in range(self.fpn_cell_repeats[compound_coef])])

        self.num_classes = num_classes
        self.regressor = Regressor(in_channels=fpn_filters, num_anchors=num_anchors,
                                   num_layers=head_layers,
                                   pyramid_levels=level_count)
        self.classifier = Classifier(in_channels=fpn_filters, num_anchors=num_anchors,
                                     num_classes=num_classes,
                                     num_layers=head_layers,
                                     pyramid_levels=level_count)

        # Pyramid level ids start at 3: [3, 4, ..., 3 + level_count - 1].
        self.anchors = Anchors(anchor_scale=self.anchor_scale[compound_coef],
                               pyramid_levels=list(range(3, 3 + level_count)),
                               **kwargs)

        self.backbone_net = EfficientNet(self.backbone_compound_coef[compound_coef], load_weights)

    def freeze_bn(self):
        """Switch every ``nn.BatchNorm2d`` submodule to eval mode."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()

    def forward(self, inputs):
        """Run the backbone, BiFPN and both heads on ``inputs``.

        Returns:
            A 4-tuple ``(features, regression, classification, anchors)`` where
            ``features`` is the BiFPN output, ``regression`` and
            ``classification`` are the head outputs computed from it, and
            ``anchors`` is produced by ``self.anchors`` from ``inputs`` and its
            dtype.
        """
        # The first backbone output is discarded; only p3/p4/p5 feed the BiFPN.
        _, p3, p4, p5 = self.backbone_net(inputs)

        features = self.bifpn((p3, p4, p5))
        regression = self.regressor(features)
        classification = self.classifier(features)
        anchors = self.anchors(inputs, inputs.dtype)

        return features, regression, classification, anchors

    def init_backbone(self, path):
        """Load a checkpoint from ``path`` into this module, non-strictly.

        Missing or unexpected keys are tolerated (``strict=False``) and the
        result of ``load_state_dict`` is printed. A ``RuntimeError`` raised while
        loading is reported on stdout and otherwise ignored.

        Args:
            path: Location of a file written by ``torch.save``.
        """
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        # Deserialize onto the CPU so checkpoints saved from a GPU also load on
        # CPU-only hosts; load_state_dict copies values into the existing
        # parameters, so their device is unaffected.
        state_dict = torch.load(path, map_location='cpu')
        try:
            ret = self.load_state_dict(state_dict, strict=False)
            print(ret)
        except RuntimeError as e:
            print('Ignoring "' + str(e) + '"')