Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MobileNetV3 architecture for Segmentation #3276

Merged
merged 12 commits into from
Jan 27, 2021
12 changes: 11 additions & 1 deletion docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ The models subpackage contains definitions for the following model
architectures for semantic segmentation:

- `FCN ResNet50, ResNet101 <https://arxiv.org/abs/1411.4038>`_
- `DeepLabV3 ResNet50, ResNet101 <https://arxiv.org/abs/1706.05587>`_
- `DeepLabV3 ResNet50, ResNet101, MobileNetV3-Large <https://arxiv.org/abs/1706.05587>`_
- `LR-ASPP MobileNetV3-Large <https://arxiv.org/abs/1905.02244>`_

As with image classification models, all pre-trained models expect input images normalized in the same way.
The images have to be loaded in to a range of ``[0, 1]`` and then normalized using
Expand All @@ -298,6 +299,8 @@ FCN ResNet50 60.5 91.4
FCN ResNet101 63.7 91.9
DeepLabV3 ResNet50 66.4 92.4
DeepLabV3 ResNet101 67.4 92.4
DeepLabV3 MobileNetV3-Large 60.3 91.2
LR-ASPP MobileNetV3-Large 57.9 91.2
================================ ============= ====================


Expand All @@ -313,6 +316,13 @@ DeepLabV3

.. autofunction:: torchvision.models.segmentation.deeplabv3_resnet50
.. autofunction:: torchvision.models.segmentation.deeplabv3_resnet101
.. autofunction:: torchvision.models.segmentation.deeplabv3_mobilenet_v3_large


LR-ASPP
-------

.. autofunction:: torchvision.models.segmentation.lraspp_mobilenet_v3_large


Object Detection, Instance Segmentation and Person Keypoint Detection
Expand Down
2 changes: 1 addition & 1 deletion hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@

# segmentation
from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \
deeplabv3_resnet50, deeplabv3_resnet101
deeplabv3_resnet50, deeplabv3_resnet101, deeplabv3_mobilenet_v3_large, lraspp_mobilenet_v3_large
10 changes: 10 additions & 0 deletions references/segmentation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,13 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --lr 0.02 --dataset coco -b 4 --model deeplabv3_resnet101 --aux-loss
```

## deeplabv3_mobilenet_v3_large
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --dataset coco -b 4 --model deeplabv3_mobilenet_v3_large --aux-loss --wd 0.000001
```

## lraspp_mobilenet_v3_large
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py --dataset coco -b 4 --model lraspp_mobilenet_v3_large --wd 0.000001
```
Binary file not shown.
Binary file not shown.
2 changes: 2 additions & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ def get_available_video_models():
"wide_resnet101_2",
"deeplabv3_resnet50",
"deeplabv3_resnet101",
"deeplabv3_mobilenet_v3_large",
"fcn_resnet50",
"fcn_resnet101",
"lraspp_mobilenet_v3_large",
)


Expand Down
4 changes: 2 additions & 2 deletions torchvision/models/detection/backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ def mobilenet_backbone(
):
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features

# Gather the indeces of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1]
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
num_stages = len(stage_indices)

# find the index of the layer from which we wont freeze
Expand Down
8 changes: 5 additions & 3 deletions torchvision/models/mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ def __init__(
groups: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None,
dilation: int = 1,
) -> None:
padding = (kernel_size - 1) // 2
padding = (kernel_size - 1) // 2 * dilation
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if activation_layer is None:
activation_layer = nn.ReLU6
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups,
bias=False),
norm_layer(out_planes),
activation_layer(inplace=True)
)
Expand Down Expand Up @@ -88,7 +90,7 @@ def __init__(
])
self.conv = nn.Sequential(*layers)
self.out_channels = oup
self.is_strided = stride > 1
self._is_cn = stride > 1

def forward(self, x: Tensor) -> Tensor:
if self.use_res_connect:
Expand Down
86 changes: 42 additions & 44 deletions torchvision/models/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ def forward(self, input: Tensor) -> Tensor:
class InvertedResidualConfig:

def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool,
activation: str, stride: int, width_mult: float):
activation: str, stride: int, dilation: int, width_mult: float):
self.input_channels = self.adjust_channels(input_channels, width_mult)
self.kernel = kernel
self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
self.out_channels = self.adjust_channels(out_channels, width_mult)
self.use_se = use_se
self.use_hs = activation == "HS"
self.stride = stride
self.dilation = dilation

@staticmethod
def adjust_channels(channels: int, width_mult: float):
Expand All @@ -70,9 +71,10 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod
norm_layer=norm_layer, activation_layer=activation_layer))

# depthwise
stride = 1 if cnf.dilation > 1 else cnf.stride
layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
stride=cnf.stride, groups=cnf.expanded_channels, norm_layer=norm_layer,
activation_layer=activation_layer))
stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
norm_layer=norm_layer, activation_layer=activation_layer))
if cnf.use_se:
layers.append(SqueezeExcitation(cnf.expanded_channels))

Expand All @@ -82,7 +84,7 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod

self.block = nn.Sequential(*layers)
self.out_channels = cnf.out_channels
self.is_strided = cnf.stride > 1
self._is_cn = cnf.stride > 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out of curiosity, what does cn mean in here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's from the C0,C1...C5,Cn names used in Object Detection. I use this feature internally to find out where the downsampling was supposed to happen but it's not always done with strides so I had to rename it. If you have any better name for it, happy to change it. I could not think of any...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the explanation. Given that this is private I'm fine with this name


def forward(self, input: Tensor) -> Tensor:
result = self.block(input)
Expand Down Expand Up @@ -194,78 +196,74 @@ def _mobilenet_v3(
return model


def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
**kwargs: Any) -> MobileNetV3:
def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
"""
Constructs a large MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
reduced_tail (bool): If True, reduces the channel counts of all feature layers
between C4 and C5 by 2. It is used to reduce the channel redundancy in the
backbone for Detection and Segmentation.
"""
# non-public config parameters
reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this feature used in any of the models? Otherwise we can just remove it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a unique implementation detail from the paper on MobileNetV3 models and it's supposed to produce a further speed optimization on object detection and segmentation. In our training scripts we don't use it because we do transfer learning from ImageNet but if someone really wants to train it from scratch and go smaller I provide a way to do it.

On current master this is public (see reduced_tail param) but here I decide to hide before the release and make it an internal implementation detail for future models. Not quite convinced we will use it but want to provide an implementation very close to the paper.

Personally I would prefer to keep it hidden for now and decide later whether we want this gone. Let me know.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I'm ok keeping this private for now and maybe removing it from the future.

dilation = 2 if kwargs.pop('_dilated', False) else 1
width_mult = 1.0

bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)

reduce_divider = 2 if reduced_tail else 1

inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, False, "RE", 1),
bneck_conf(16, 3, 64, 24, False, "RE", 2), # C1
bneck_conf(24, 3, 72, 24, False, "RE", 1),
bneck_conf(24, 5, 72, 40, True, "RE", 2), # C2
bneck_conf(40, 5, 120, 40, True, "RE", 1),
bneck_conf(40, 5, 120, 40, True, "RE", 1),
bneck_conf(40, 3, 240, 80, False, "HS", 2), # C3
bneck_conf(80, 3, 200, 80, False, "HS", 1),
bneck_conf(80, 3, 184, 80, False, "HS", 1),
bneck_conf(80, 3, 184, 80, False, "HS", 1),
bneck_conf(80, 3, 480, 112, True, "HS", 1),
bneck_conf(112, 3, 672, 112, True, "HS", 1),
bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2), # C4
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
bneck_conf(16, 3, 64, 24, False, "RE", 2, 1), # C1
bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
bneck_conf(24, 5, 72, 40, True, "RE", 2, 1), # C2
bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
bneck_conf(40, 3, 240, 80, False, "HS", 2, 1), # C3
bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation), # C4
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
]
last_channel = adjust_channels(1280 // reduce_divider) # C5

return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)


def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
**kwargs: Any) -> MobileNetV3:
def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
"""
Constructs a small MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
reduced_tail (bool): If True, reduces the channel counts of all feature layers
between C4 and C5 by 2. It is used to reduce the channel redundancy in the
backbone for Detection and Segmentation.
"""
# non-public config parameters
reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1
dilation = 2 if kwargs.pop('_dilated', False) else 1
width_mult = 1.0

bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)

reduce_divider = 2 if reduced_tail else 1

inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, True, "RE", 2), # C1
bneck_conf(16, 3, 72, 24, False, "RE", 2), # C2
bneck_conf(24, 3, 88, 24, False, "RE", 1),
bneck_conf(24, 5, 96, 40, True, "HS", 2), # C3
bneck_conf(40, 5, 240, 40, True, "HS", 1),
bneck_conf(40, 5, 240, 40, True, "HS", 1),
bneck_conf(40, 5, 120, 48, True, "HS", 1),
bneck_conf(48, 5, 144, 48, True, "HS", 1),
bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2), # C4
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1),
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1),
bneck_conf(16, 3, 16, 16, True, "RE", 2, 1), # C1
bneck_conf(16, 3, 72, 24, False, "RE", 2, 1), # C2
bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
bneck_conf(24, 5, 96, 40, True, "HS", 2, 1), # C3
bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation), # C4
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
]
last_channel = adjust_channels(1024 // reduce_divider) # C5

Expand Down
1 change: 1 addition & 0 deletions torchvision/models/segmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .segmentation import *
from .fcn import *
from .deeplabv3 import *
from .lraspp import *
1 change: 0 additions & 1 deletion torchvision/models/segmentation/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections import OrderedDict

import torch
from torch import nn
from torch.nn import functional as F

Expand Down
69 changes: 69 additions & 0 deletions torchvision/models/segmentation/lraspp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from collections import OrderedDict

from torch import nn, Tensor
from torch.nn import functional as F
from typing import Dict


__all__ = ["LRASPP"]


class LRASPP(nn.Module):
"""
Implements a Lite R-ASPP Network for semantic segmentation from
`"Searching for MobileNetV3"
<https://arxiv.org/abs/1905.02244>`_.

Args:
backbone (nn.Module): the network used to compute the features for the model.
The backbone should return an OrderedDict[Tensor], with the key being
"high" for the high level feature map and "low" for the low level feature map.
low_channels (int): the number of channels of the low level features.
high_channels (int): the number of channels of the high level features.
num_classes (int): number of output classes of the model (including the background).
inter_channels (int, optional): the number of channels for intermediate computations.
"""

def __init__(self, backbone, low_channels, high_channels, num_classes, inter_channels=128):
super().__init__()
self.backbone = backbone
self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)

def forward(self, input):
features = self.backbone(input)
out = self.classifier(features)
out = F.interpolate(out, size=input.shape[-2:], mode='bilinear', align_corners=False)

result = OrderedDict()
result["out"] = out

return result


class LRASPPHead(nn.Module):

def __init__(self, low_channels, high_channels, num_classes, inter_channels):
super().__init__()
self.cbr = nn.Sequential(
nn.Conv2d(high_channels, inter_channels, 1, bias=False),
nn.BatchNorm2d(inter_channels),
nn.ReLU(inplace=True)
)
self.scale = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(high_channels, inter_channels, 1, bias=False),
nn.Sigmoid(),
)
self.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
self.high_classifier = nn.Conv2d(inter_channels, num_classes, 1)

def forward(self, input: Dict[str, Tensor]) -> Tensor:
low = input["low"]
high = input["high"]

x = self.cbr(high)
s = self.scale(high)
x = x * s
x = F.interpolate(x, size=low.shape[-2:], mode='bilinear', align_corners=False)

return self.low_classifier(low) + self.high_classifier(x)
Loading