Skip to content

Commit

Permalink
Make YOLOv3 neck more flexible (#5218)
Browse files Browse the repository at this point in the history
* Make YOLOv3 neck more flexible

* Fix linter warning

* update docstring

* add unit test of yolov3 neck

* update

* fix docstring

Co-authored-by: RangiLyu <[email protected]>
  • Loading branch information
hokmund and RangiLyu authored Jun 16, 2021
1 parent 48cfeb0 commit 87ad480
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 8 deletions.
16 changes: 9 additions & 7 deletions mmdet/models/necks/yolo_neck.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,13 @@ class YOLOV3Neck(BaseModule):
Args:
num_scales (int): The number of scales / stages.
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
conv_cfg (dict): Config dict for convolution layer. Default: None.
norm_cfg (dict): Dictionary to construct and config norm layer.
Default: dict(type='BN', requires_grad=True)
act_cfg (dict): Config dict for activation layer.
in_channels (List[int]): The number of input channels per scale.
out_channels (List[int]): The number of output channels per scale.
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None.
norm_cfg (dict, optional): Dictionary to construct and config norm
layer. Default: dict(type='BN', requires_grad=True)
act_cfg (dict, optional): Config dict for activation layer.
Default: dict(type='LeakyReLU', negative_slope=0.1).
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
Expand Down Expand Up @@ -109,7 +110,8 @@ def __init__(self,
self.detect1 = DetectionBlock(in_channels[0], out_channels[0], **cfg)
for i in range(1, self.num_scales):
in_c, out_c = self.in_channels[i], self.out_channels[i]
self.add_module(f'conv{i}', ConvModule(in_c, out_c, 1, **cfg))
inter_c = out_channels[i - 1]
self.add_module(f'conv{i}', ConvModule(inter_c, out_c, 1, **cfg))
# in_c + out_c : High-lvl feats will be cat with low-lvl feats
self.add_module(f'detect{i+1}',
DetectionBlock(in_c + out_c, out_c, **cfg))
Expand Down
52 changes: 51 additions & 1 deletion tests/test_models/test_necks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import torch
from torch.nn.modules.batchnorm import _BatchNorm

from mmdet.models.necks import FPN, ChannelMapper, CTResNetNeck, DilatedEncoder
from mmdet.models.necks import (FPN, ChannelMapper, CTResNetNeck,
DilatedEncoder, YOLOV3Neck)


def test_fpn():
Expand Down Expand Up @@ -288,3 +289,52 @@ def test_ct_resnet_neck():
feat = feat.cuda()
out_feat = ct_resnet_neck([feat])[0]
assert out_feat.shape == (1, num_filters[-1], 16, 16)


def test_yolov3_neck():
# num_scales, in_channels, out_channels must be same length
with pytest.raises(AssertionError):
YOLOV3Neck(num_scales=3, in_channels=[16, 8, 4], out_channels=[8, 4])

# len(feats) must equal to num_scales
with pytest.raises(AssertionError):
neck = YOLOV3Neck(
num_scales=3, in_channels=[16, 8, 4], out_channels=[8, 4, 2])
feats = (torch.rand(1, 4, 16, 16), torch.rand(1, 8, 16, 16))
neck(feats)

# test normal channels
s = 32
in_channels = [16, 8, 4]
out_channels = [8, 4, 2]
feat_sizes = [s // 2**i for i in range(len(in_channels) - 1, -1, -1)]
feats = [
torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
for i in range(len(in_channels) - 1, -1, -1)
]
neck = YOLOV3Neck(
num_scales=3, in_channels=in_channels, out_channels=out_channels)
outs = neck(feats)

assert len(outs) == len(feats)
for i in range(len(outs)):
assert outs[i].shape == \
(1, out_channels[i], feat_sizes[i], feat_sizes[i])

# test more flexible setting
s = 32
in_channels = [32, 8, 16]
out_channels = [19, 21, 5]
feat_sizes = [s // 2**i for i in range(len(in_channels) - 1, -1, -1)]
feats = [
torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
for i in range(len(in_channels) - 1, -1, -1)
]
neck = YOLOV3Neck(
num_scales=3, in_channels=in_channels, out_channels=out_channels)
outs = neck(feats)

assert len(outs) == len(feats)
for i in range(len(outs)):
assert outs[i].shape == \
(1, out_channels[i], feat_sizes[i], feat_sizes[i])

0 comments on commit 87ad480

Please sign in to comment.