Add MobileNetV3 Architecture in TorchVision #3182

Merged: 26 commits, Jan 5, 2021
Changes from 17 commits
Commits (26)
22b9b12
Partial implementation of the network architecture.
datumbox Dec 15, 2020
0cff6fb
Simplify implementation and adding blocks.
datumbox Dec 16, 2020
fd54fdf
Refactoring the code to make it more readable.
datumbox Dec 16, 2020
834b185
Adding first conv layers.
datumbox Dec 16, 2020
1edd16b
Moving mobilenet.py to mobilenetv2.py
datumbox Dec 16, 2020
2f52f0d
Adding mobilenet.py for BC.
datumbox Dec 16, 2020
bb2ec9e
Extending ConvBNReLU for reuse.
datumbox Dec 16, 2020
e95ee5c
Moving mobilenet.py to mobilenetv2.py
datumbox Dec 16, 2020
2ebe8ba
Adding mobilenet.py for BC.
datumbox Dec 16, 2020
0c31a33
Extending ConvBNReLU for reuse.
datumbox Dec 16, 2020
db7522b
Reduce import scope on mobilenet to only the public and versioned cla…
datumbox Dec 16, 2020
fdbcec7
Merge branch 'refactoring/mobilenetv2_bc_move' into models/mobilenetv3
datumbox Dec 16, 2020
16f55f5
Further simplify by reusing MobileNetv2 methods.
datumbox Dec 16, 2020
8162fa4
Adding the remaining implementation of mobilenetv3.
datumbox Dec 16, 2020
8615585
Adding tests, docs and init methods.
datumbox Dec 16, 2020
8664fde
Refactoring and fixing formatter.
datumbox Dec 16, 2020
cfa20b7
Fixing type issues.
datumbox Dec 16, 2020
c189ae1
Using built-in Hardsigmoid and Hardswish.
datumbox Dec 16, 2020
5030435
Merge branch 'master' into models/mobilenetv3
datumbox Dec 17, 2020
9a758a8
Code review nits.
datumbox Dec 17, 2020
25f8b26
Putting inplace on Dropout.
datumbox Dec 17, 2020
5198385
Adding RMSprop support to train.py.
datumbox Jan 1, 2021
e4d130f
Adding auto-augment and random-erase in the training scripts.
datumbox Jan 1, 2021
5d0a664
Merge branch 'master' into models/mobilenetv3
datumbox Jan 3, 2021
c0a13a2
Adding support for reduced tail on MobileNetV3.
datumbox Jan 5, 2021
2414d2d
Tagging blocks with comments.
datumbox Jan 5, 2021
hubconf.py: 3 changes (2 additions, 1 deletion)
@@ -11,7 +11,8 @@
 from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
 from torchvision.models.googlenet import googlenet
 from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
-from torchvision.models.mobilenet import mobilenet_v2
+from torchvision.models.mobilenetv2 import mobilenet_v2
+from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small
 from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
     mnasnet1_3

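With mobilenet_v2 now imported from its new module and the two MobileNetV3 constructors registered, the models are also loadable through torch.hub. A minimal sketch (illustrative: it assumes a torchvision build that includes this PR, and the weights are randomly initialized since no pretrained checkpoint is requested):

import torch

# Entry points are resolved via hubconf.py; no pretrained flag is passed,
# so the weights are randomly initialized.
v2 = torch.hub.load('pytorch/vision', 'mobilenet_v2')
v3 = torch.hub.load('pytorch/vision', 'mobilenet_v3_large')
print(type(v3).__name__)  # MobileNetV3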
Binary files not shown.
test/test_models.py: 17 changes (9 additions, 8 deletions)
@@ -275,16 +275,17 @@ def test_mobilenetv2_residual_setting(self):
         out = model(x)
         self.assertEqual(out.shape[-1], 1000)

-    def test_mobilenetv2_norm_layer(self):
-        model = models.__dict__["mobilenet_v2"]()
-        self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
+    def test_mobilenet_norm_layer(self):
+        for name in ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"]:
+            model = models.__dict__[name]()
+            self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))

-        def get_gn(num_channels):
-            return nn.GroupNorm(32, num_channels)
+            def get_gn(num_channels):
+                return nn.GroupNorm(32, num_channels)

-        model = models.__dict__["mobilenet_v2"](norm_layer=get_gn)
-        self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
-        self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules()))
+            model = models.__dict__[name](norm_layer=get_gn)
+            self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
+            self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules()))

     def test_inceptionv3_eval(self):
         # replacement for models.inception_v3(pretrained=True) that does not download weights
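To run just the updated test locally, something like the following should work (a sketch; the exact invocation depends on the repository's test setup):

python -m pytest test/test_models.py -k test_mobilenet_norm_layer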
torchvision/models/mobilenet.py: 209 changes (2 additions, 207 deletions)
@@ -1,207 +1,2 @@
from torch import nn
from torch import Tensor
from .utils import load_state_dict_from_url
from typing import Callable, Any, Optional, List


__all__ = ['MobileNetV2', 'mobilenet_v2']


model_urls = {
'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}


def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
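# Illustrative examples of the rounding rule above (editorial sketch):
#   _make_divisible(35, 8) == 32  # rounds to the nearest multiple of 8
#   _make_divisible(37, 8) == 40
#   _make_divisible(10, 8) == 16  # 8 is more than 10% below 10, so round up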


class ConvBNReLU(nn.Sequential):
def __init__(
self,
in_planes: int,
out_planes: int,
kernel_size: int = 3,
stride: int = 1,
groups: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
padding = (kernel_size - 1) // 2
if norm_layer is None:
norm_layer = nn.BatchNorm2d
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
norm_layer(out_planes),
nn.ReLU6(inplace=True)
)
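# Shape sketch (illustrative): the padding of (kernel_size - 1) // 2 preserves
# spatial size at stride 1 and halves it at stride 2; for example,
# ConvBNReLU(3, 32, stride=2) maps an input of (1, 3, 224, 224) to (1, 32, 112, 112).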


class InvertedResidual(nn.Module):
def __init__(
self,
inp: int,
oup: int,
stride: int,
expand_ratio: int,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]

if norm_layer is None:
norm_layer = nn.BatchNorm2d

hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup

layers: List[nn.Module] = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
norm_layer(oup),
])
self.conv = nn.Sequential(*layers)

def forward(self, x: Tensor) -> Tensor:
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
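# Residual rule (illustrative): the skip connection is used only when the block
# keeps both resolution and width, e.g.
#   InvertedResidual(32, 32, stride=1, expand_ratio=6).use_res_connect  # True
#   InvertedResidual(32, 64, stride=2, expand_ratio=6).use_res_connect  # False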


class MobileNetV2(nn.Module):
def __init__(
self,
num_classes: int = 1000,
width_mult: float = 1.0,
inverted_residual_setting: Optional[List[List[int]]] = None,
round_nearest: int = 8,
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
"""
MobileNet V2 main class

Args:
num_classes (int): Number of classes
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
inverted_residual_setting: Network structure
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
Set to 1 to turn off rounding
block: Module specifying inverted residual building block for mobilenet
norm_layer: Module specifying the normalization layer to use

"""
super(MobileNetV2, self).__init__()

if block is None:
block = InvertedResidual

if norm_layer is None:
norm_layer = nn.BatchNorm2d

input_channel = 32
last_channel = 1280

if inverted_residual_setting is None:
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]

# only check the first element, assuming user knows t,c,n,s are required
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
raise ValueError("inverted_residual_setting should be non-empty "
"or a 4-element list, got {}".format(inverted_residual_setting))

# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
# make it nn.Sequential
self.features = nn.Sequential(*features)

# building classifier
self.classifier = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.last_channel, num_classes),
)

# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)

def _forward_impl(self, x: Tensor) -> Tensor:
# This exists since TorchScript doesn't support inheritance, so the superclass method
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
x = self.features(x)
# Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1)
x = self.classifier(x)
return x

def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
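# Construction sketch (illustrative): every channel count c in the table above
# is rounded with _make_divisible(c * width_mult, round_nearest), so a slimmer
# variant is simply:
#   tiny = MobileNetV2(width_mult=0.5, num_classes=10)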


def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2:
"""
Constructs a MobileNetV2 architecture from
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = MobileNetV2(**kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
progress=progress)
model.load_state_dict(state_dict)
return model
+from .mobilenetv2 import MobileNetV2, mobilenet_v2
+from .mobilenetv3 import MobileNetV3, mobilenet_v3_large, mobilenet_v3_small
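After this change, torchvision/models/mobilenet.py is reduced to a backwards-compatibility shim, so both generations remain importable from the historical module path. A minimal usage sketch (assuming the mobilenetv3.py module added elsewhere in this PR):

import torch
from torchvision.models import mobilenet_v2, mobilenet_v3_large, mobilenet_v3_small

# Weights are randomly initialized here; no pretrained checkpoint is loaded.
model = mobilenet_v3_small(num_classes=10)
x = torch.randn(2, 3, 224, 224)
print(model(x).shape)  # torch.Size([2, 10])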