Skip to content

Commit

Permalink
Adding support for reduced tail on MobileNetV3.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Jan 4, 2021
1 parent 5d0a664 commit 403396c
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions torchvision/models/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,19 +201,25 @@ def _mobilenet_v3(
return model


def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
**kwargs: Any) -> MobileNetV3:
"""
Constructs a large MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
reduced_tail (bool): If True, reduces the channel counts of all feature layers
between C4 and C5 by 2. It is used to reduce the channel redundancy in the
backbone for Detection and Segmentation.
"""
width_mult = 1.0
bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)

reduce_divider = 2 if reduced_tail else 1

inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, False, "RE", 1),
bneck_conf(16, 3, 64, 24, False, "RE", 2),
Expand All @@ -227,28 +233,34 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs
bneck_conf(80, 3, 184, 80, False, "HS", 1),
bneck_conf(80, 3, 480, 112, True, "HS", 1),
bneck_conf(112, 3, 672, 112, True, "HS", 1),
bneck_conf(112, 5, 672, 160, True, "HS", 2),
bneck_conf(160, 5, 960, 160, True, "HS", 1),
bneck_conf(160, 5, 960, 160, True, "HS", 1),
bneck_conf(112, 5, 672, 160, True, "HS", 2), # C4
bneck_conf(160 // reduce_divider, 5, 960, 160, True, "HS", 1),
bneck_conf(160 // reduce_divider, 5, 960, 160, True, "HS", 1),
]
last_channel = adjust_channels(1280)
last_channel = adjust_channels(1280 // reduce_divider) # C5

return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)


def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
**kwargs: Any) -> MobileNetV3:
"""
Constructs a small MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
reduced_tail (bool): If True, reduces the channel counts of all feature layers
between C4 and C5 by 2. It is used to reduce the channel redundancy in the
backbone for Detection and Segmentation.
"""
width_mult = 1.0
bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)

reduce_divider = 2 if reduced_tail else 1

inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, True, "RE", 2),
bneck_conf(16, 3, 72, 24, False, "RE", 2),
Expand All @@ -258,10 +270,10 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs
bneck_conf(40, 5, 240, 40, True, "HS", 1),
bneck_conf(40, 5, 120, 48, True, "HS", 1),
bneck_conf(48, 5, 144, 48, True, "HS", 1),
bneck_conf(48, 5, 288, 96, True, "HS", 2),
bneck_conf(96, 5, 576, 96, True, "HS", 1),
bneck_conf(96, 5, 576, 96, True, "HS", 1),
bneck_conf(48, 5, 288, 96, True, "HS", 2), # C4
bneck_conf(96 // reduce_divider, 5, 576, 96, True, "HS", 1),
bneck_conf(96 // reduce_divider, 5, 576, 96, True, "HS", 1),
]
last_channel = adjust_channels(1024)
last_channel = adjust_channels(1024 // reduce_divider) # C5

return _mobilenet_v3("mobilenet_v3_small", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)

0 comments on commit 403396c

Please sign in to comment.