diff --git a/docs/source/models.rst b/docs/source/models.rst
index b9bff7a36e8..64ca69f47ae 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -27,6 +27,7 @@ architectures for image classification:
 -  `ResNeXt`_
 -  `Wide ResNet`_
 -  `MNASNet`_
+-  `EfficientNet`_
 
 You can construct a model with random weights by calling its constructor:
 
@@ -47,6 +48,14 @@ You can construct a model with random weights by calling its constructor:
     resnext50_32x4d = models.resnext50_32x4d()
     wide_resnet50_2 = models.wide_resnet50_2()
     mnasnet = models.mnasnet1_0()
+    efficientnet_b0 = models.efficientnet_b0()
+    efficientnet_b1 = models.efficientnet_b1()
+    efficientnet_b2 = models.efficientnet_b2()
+    efficientnet_b3 = models.efficientnet_b3()
+    efficientnet_b4 = models.efficientnet_b4()
+    efficientnet_b5 = models.efficientnet_b5()
+    efficientnet_b6 = models.efficientnet_b6()
+    efficientnet_b7 = models.efficientnet_b7()
 
 We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
 These can be constructed by passing ``pretrained=True``:
 
@@ -68,6 +77,14 @@ These can be constructed by passing ``pretrained=True``:
     resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
     wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
     mnasnet = models.mnasnet1_0(pretrained=True)
+    efficientnet_b0 = models.efficientnet_b0(pretrained=True)
+    efficientnet_b1 = models.efficientnet_b1(pretrained=True)
+    efficientnet_b2 = models.efficientnet_b2(pretrained=True)
+    efficientnet_b3 = models.efficientnet_b3(pretrained=True)
+    efficientnet_b4 = models.efficientnet_b4(pretrained=True)
+    efficientnet_b5 = models.efficientnet_b5(pretrained=True)
+    efficientnet_b6 = models.efficientnet_b6(pretrained=True)
+    efficientnet_b7 = models.efficientnet_b7(pretrained=True)
 
 Instancing a pre-trained model will download its weights to a cache directory.
 This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
@@ -113,7 +130,10 @@ Unfortunately, the concrete `subset` that was used is lost. For more
 information see `this discussion `_
 or `these experiments `_.
 
-ImageNet 1-crop error rates (224x224)
+The input sizes of the EfficientNet models depend on the variant. For the exact
+input sizes, `check here `_.
+
+ImageNet 1-crop error rates
 
 ================================ ============= =============
 Model                            Acc@1         Acc@5
@@ -151,6 +171,14 @@ Wide ResNet-50-2                 78.468        94.086
 Wide ResNet-101-2                78.848        94.284
 MNASNet 1.0                      73.456        91.510
 MNASNet 0.5                      67.734        87.490
+EfficientNet-B0                  77.692        93.532
+EfficientNet-B1                  78.642        94.186
+EfficientNet-B2                  80.608        95.310
+EfficientNet-B3                  82.008        96.054
+EfficientNet-B4                  83.384        96.594
+EfficientNet-B5                  83.444        96.628
+EfficientNet-B6                  84.008        96.916
+EfficientNet-B7                  84.122        96.908
 ================================ ============= =============
 
@@ -166,6 +194,7 @@ MNASNet 0.5                      67.734        87.490
 .. _MobileNetV3: https://arxiv.org/abs/1905.02244
 .. _ResNeXt: https://arxiv.org/abs/1611.05431
 .. _MNASNet: https://arxiv.org/abs/1807.11626
+.. _EfficientNet: https://arxiv.org/abs/1905.11946
 
 .. currentmodule:: torchvision.models
 
@@ -267,6 +296,18 @@ MNASNet
 .. autofunction:: mnasnet1_0
 .. autofunction:: mnasnet1_3
 
+EfficientNet
+------------
+
+.. autofunction:: efficientnet_b0
+.. autofunction:: efficientnet_b1
+.. autofunction:: efficientnet_b2
+.. autofunction:: efficientnet_b3
+.. autofunction:: efficientnet_b4
+.. autofunction:: efficientnet_b5
+.. autofunction:: efficientnet_b6
+.. autofunction:: efficientnet_b7
+
 Quantized Models
 ----------------
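As a usage sketch of the documented API (not part of the patch): a pretrained variant can be evaluated with the same preprocessing the reference script selects for B0 further below — bicubic resize to 256, center crop to 224, ImageNet normalization. The image filename is a placeholder.

    import torch
    from PIL import Image
    from torchvision import models, transforms
    from torchvision.transforms.functional import InterpolationMode

    model = models.efficientnet_b0(pretrained=True).eval()
    preprocess = transforms.Compose([
        transforms.Resize(256, interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    img = Image.open("dog.jpg")  # any RGB image
    with torch.no_grad():
        logits = model(preprocess(img).unsqueeze(0))
    print(logits.argmax(dim=1))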
diff --git a/hubconf.py b/hubconf.py
index 097759bdd89..2bff6850525 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -15,6 +15,8 @@
 from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small
 from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
     mnasnet1_3
+from torchvision.models.efficientnet import efficientnet_b0, efficientnet_b1, efficientnet_b2, \
+    efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7
 
 # segmentation
 from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \
diff --git a/references/classification/README.md b/references/classification/README.md
index e0b7f210175..210a63c0bca 100644
--- a/references/classification/README.md
+++ b/references/classification/README.md
@@ -68,6 +68,12 @@ Then we averaged the parameters of the last 3 checkpoints that improved the Acc@
 and [#3354](https://github.com/pytorch/vision/pull/3354) for details.
 
+### EfficientNet
+
+The weights of the B0-B4 variants are ported from Ross Wightman's [timm repo](https://github.com/rwightman/pytorch-image-models/blob/01cb46a9a50e3ba4be167965b5764e9702f09b30/timm/models/efficientnet.py#L95-L108).
+
+The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTorch repo](https://github.com/lukemelas/EfficientNet-PyTorch/blob/1039e009545d9329ea026c9f7541341439712b96/efficientnet_pytorch/utils.py#L562-L564).
+
 ## Mixed precision training
 Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [NVIDIA Apex extension](https://github.com/NVIDIA/apex).
diff --git a/references/classification/presets.py b/references/classification/presets.py
index 6bb389ba8db..ce5a6fe414f 100644
--- a/references/classification/presets.py
+++ b/references/classification/presets.py
@@ -1,4 +1,5 @@
 from torchvision.transforms import autoaugment, transforms
+from torchvision.transforms.functional import InterpolationMode
 
 
 class ClassificationPresetTrain:
@@ -24,10 +25,11 @@ def __call__(self, img):
 
 
 class ClassificationPresetEval:
-    def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+    def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
+                 interpolation=InterpolationMode.BILINEAR):
 
         self.transforms = transforms.Compose([
-            transforms.Resize(resize_size),
+            transforms.Resize(resize_size, interpolation=interpolation),
             transforms.CenterCrop(crop_size),
             transforms.ToTensor(),
             transforms.Normalize(mean=mean, std=std),
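The extended eval preset can then be parameterized per variant. A sketch (not part of the patch) using the B4 sizes from the train.py hunk that follows — resize 384, crop 380, bicubic:

    from torchvision.transforms.functional import InterpolationMode
    import presets  # the references/classification module patched above

    eval_tf = presets.ClassificationPresetEval(
        crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC)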
diff --git a/references/classification/train.py b/references/classification/train.py
index b4e9d274662..9ba99b3dc54 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -6,6 +6,7 @@
 import torch.utils.data
 from torch import nn
 import torchvision
+from torchvision.transforms.functional import InterpolationMode
 
 import presets
 import utils
@@ -82,7 +83,18 @@ def _get_cache_path(filepath):
 def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
-    resize_size, crop_size = (342, 299) if args.model == 'inception_v3' else (256, 224)
+    resize_size, crop_size = 256, 224
+    interpolation = InterpolationMode.BILINEAR
+    if args.model == 'inception_v3':
+        resize_size, crop_size = 342, 299
+    elif args.model.startswith('efficientnet_'):
+        sizes = {
+            'b0': (256, 224), 'b1': (256, 240), 'b2': (288, 288), 'b3': (320, 300),
+            'b4': (384, 380), 'b5': (456, 456), 'b6': (528, 528), 'b7': (600, 600),
+        }
+        e_type = args.model.replace('efficientnet_', '')
+        resize_size, crop_size = sizes[e_type]
+        interpolation = InterpolationMode.BICUBIC
 
     print("Loading training data")
     st = time.time()
@@ -113,7 +125,8 @@ def load_data(traindir, valdir, args):
     else:
         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size))
+            presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size,
+                                             interpolation=interpolation))
         if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
diff --git a/test/expect/ModelTester.test_efficientnet_b0_expect.pkl b/test/expect/ModelTester.test_efficientnet_b0_expect.pkl
new file mode 100644
index 00000000000..1de871ce0fb
Binary files /dev/null and b/test/expect/ModelTester.test_efficientnet_b0_expect.pkl differ
diff --git a/test/expect/ModelTester.test_efficientnet_b1_expect.pkl b/test/expect/ModelTester.test_efficientnet_b1_expect.pkl
new file mode 100644
index 00000000000..1499a97028e
Binary files /dev/null and b/test/expect/ModelTester.test_efficientnet_b1_expect.pkl differ
diff --git a/test/expect/ModelTester.test_efficientnet_b2_expect.pkl b/test/expect/ModelTester.test_efficientnet_b2_expect.pkl
new file mode 100644
index 00000000000..f0aeb8ec122
Binary files /dev/null and b/test/expect/ModelTester.test_efficientnet_b2_expect.pkl differ
diff --git a/test/expect/ModelTester.test_efficientnet_b3_expect.pkl b/test/expect/ModelTester.test_efficientnet_b3_expect.pkl
new file mode 100644
index 00000000000..989d6782fe7
Binary files /dev/null and b/test/expect/ModelTester.test_efficientnet_b3_expect.pkl differ
diff --git a/test/expect/ModelTester.test_efficientnet_b4_expect.pkl b/test/expect/ModelTester.test_efficientnet_b4_expect.pkl
new file mode 100644
index 00000000000..f4a0cc04bf0
Binary files /dev/null and b/test/expect/ModelTester.test_efficientnet_b4_expect.pkl differ
diff --git a/test/expect/ModelTester.test_efficientnet_b5_expect.pkl b/test/expect/ModelTester.test_efficientnet_b5_expect.pkl
new file mode 100644
index 00000000000..7c674259cd9
Binary files /dev/null and b/test/expect/ModelTester.test_efficientnet_b5_expect.pkl differ
diff --git a/test/expect/ModelTester.test_efficientnet_b6_expect.pkl b/test/expect/ModelTester.test_efficientnet_b6_expect.pkl
new file mode 100644
index 00000000000..dfad29717e4
Binary files /dev/null and b/test/expect/ModelTester.test_efficientnet_b6_expect.pkl differ
diff --git a/test/expect/ModelTester.test_efficientnet_b7_expect.pkl b/test/expect/ModelTester.test_efficientnet_b7_expect.pkl
new file mode 100644
index 00000000000..965ee61a2ef
Binary files /dev/null and b/test/expect/ModelTester.test_efficientnet_b7_expect.pkl differ
diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py
index 283e544e98e..3c1519c1b42 100644
--- a/torchvision/models/__init__.py
+++ b/torchvision/models/__init__.py
@@ -8,6 +8,7 @@
 from .mobilenet import *
 from .mnasnet import *
 from .shufflenetv2 import *
+from .efficientnet import *
 from . import segmentation
 from . import detection
 from . import video
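With the builders re-exported in hubconf.py, the new variants should also be reachable through torch.hub. A sketch (not part of the patch), assuming a torchvision checkout or release that contains this change:

    import torch

    # loads the entrypoint added to hubconf.py; weights download on first use
    model = torch.hub.load('pytorch/vision', 'efficientnet_b3', pretrained=True)
    model.eval()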
diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py
new file mode 100644
index 00000000000..06b2a301b6d
--- /dev/null
+++ b/torchvision/models/efficientnet.py
@@ -0,0 +1,369 @@
+import copy
+import math
+import torch
+
+from functools import partial
+from torch import nn, Tensor
+from torch.nn import functional as F
+from typing import Any, Callable, List, Optional, Sequence
+
+from .._internally_replaced_utils import load_state_dict_from_url
+from torchvision.ops import StochasticDepth
+
+from torchvision.models.mobilenetv2 import ConvBNActivation, _make_divisible
+
+
+__all__ = ["EfficientNet", "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", "efficientnet_b3",
+           "efficientnet_b4", "efficientnet_b5", "efficientnet_b6", "efficientnet_b7"]
+
+
+model_urls = {
+    # Weights ported from https://github.com/rwightman/pytorch-image-models/
+    "efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
+    "efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
+    "efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
+    "efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
+    "efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
+    # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
+    "efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
+    "efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
+    "efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
+}
+
+
+class SqueezeExcitation(nn.Module):
+    def __init__(self, input_channels: int, squeeze_channels: int):
+        super().__init__()
+        self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
+        self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
+
+    def _scale(self, input: Tensor) -> Tensor:
+        scale = F.adaptive_avg_pool2d(input, 1)
+        scale = self.fc1(scale)
+        scale = F.silu(scale, inplace=True)
+        scale = self.fc2(scale)
+        return scale.sigmoid()
+
+    def forward(self, input: Tensor) -> Tensor:
+        scale = self._scale(input)
+        return scale * input
+
+
+class MBConvConfig:
+    # Stores information listed at Table 1 of the EfficientNet paper
+    def __init__(self,
+                 expand_ratio: float, kernel: int, stride: int,
+                 input_channels: int, out_channels: int, num_layers: int,
+                 width_mult: float, depth_mult: float) -> None:
+        self.expand_ratio = expand_ratio
+        self.kernel = kernel
+        self.stride = stride
+        self.input_channels = self.adjust_channels(input_channels, width_mult)
+        self.out_channels = self.adjust_channels(out_channels, width_mult)
+        self.num_layers = self.adjust_depth(num_layers, depth_mult)
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + '('
+        s += 'expand_ratio={expand_ratio}'
+        s += ', kernel={kernel}'
+        s += ', stride={stride}'
+        s += ', input_channels={input_channels}'
+        s += ', out_channels={out_channels}'
+        s += ', num_layers={num_layers}'
+        s += ')'
+        return s.format(**self.__dict__)
+
+    @staticmethod
+    def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
+        return _make_divisible(channels * width_mult, 8, min_value)
+
+    @staticmethod
+    def adjust_depth(num_layers: int, depth_mult: float) -> int:
+        return int(math.ceil(num_layers * depth_mult))
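# Usage sketch (not part of the patch): a quick sanity check of the compound
# scaling above. adjust_channels() rounds width_mult-scaled channels to a
# multiple of 8 via _make_divisible (mirrored below from mobilenetv2), and
# adjust_depth() ceils the scaled layer count. For the B4 coefficients
# (width_mult=1.4, depth_mult=1.8):
import math

def make_divisible(v, divisor=8, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # rounding down must never remove more than 10%
        new_v += divisor
    return new_v

assert make_divisible(32 * 1.4) == 48   # adjust_channels(32, 1.4)
assert math.ceil(3 * 1.8) == 6          # adjust_depth(3, 1.8)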
+
+
+class MBConv(nn.Module):
+    def __init__(self, cnf: MBConvConfig, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module],
+                 se_layer: Callable[..., nn.Module] = SqueezeExcitation) -> None:
+        super().__init__()
+
+        if not (1 <= cnf.stride <= 2):
+            raise ValueError('illegal stride value')
+
+        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
+
+        layers: List[nn.Module] = []
+        activation_layer = nn.SiLU
+
+        # expand
+        expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
+        if expanded_channels != cnf.input_channels:
+            layers.append(ConvBNActivation(cnf.input_channels, expanded_channels, kernel_size=1,
+                                           norm_layer=norm_layer, activation_layer=activation_layer))
+
+        # depthwise
+        layers.append(ConvBNActivation(expanded_channels, expanded_channels, kernel_size=cnf.kernel,
+                                       stride=cnf.stride, groups=expanded_channels,
+                                       norm_layer=norm_layer, activation_layer=activation_layer))
+
+        # squeeze and excitation
+        squeeze_channels = max(1, cnf.input_channels // 4)
+        layers.append(se_layer(expanded_channels, squeeze_channels))
+
+        # project
+        layers.append(ConvBNActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
+                                       activation_layer=nn.Identity))
+
+        self.block = nn.Sequential(*layers)
+        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
+        self.out_channels = cnf.out_channels
+
+    def forward(self, input: Tensor) -> Tensor:
+        result = self.block(input)
+        if self.use_res_connect:
+            result = self.stochastic_depth(result)
+            result += input
+        return result
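# Usage sketch (not part of the patch), assuming this module is importable:
# exercising a single MBConv block. With stride 1 and matching channels,
# use_res_connect is True and the output keeps the input's shape.
import torch
from torch import nn
from torchvision.models.efficientnet import MBConv, MBConvConfig

cnf = MBConvConfig(expand_ratio=6, kernel=3, stride=1,
                   input_channels=24, out_channels=24, num_layers=1,
                   width_mult=1.0, depth_mult=1.0)
block = MBConv(cnf, stochastic_depth_prob=0.1, norm_layer=nn.BatchNorm2d)
x = torch.randn(2, 24, 56, 56)
assert block.use_res_connect
assert block(x).shape == x.shape  # residual blocks preserve the shape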
+
+
+class EfficientNet(nn.Module):
+    def __init__(
+            self,
+            inverted_residual_setting: List[MBConvConfig],
+            dropout: float,
+            stochastic_depth_prob: float = 0.2,
+            num_classes: int = 1000,
+            block: Optional[Callable[..., nn.Module]] = None,
+            norm_layer: Optional[Callable[..., nn.Module]] = None,
+            **kwargs: Any
+    ) -> None:
+        """
+        EfficientNet main class
+
+        Args:
+            inverted_residual_setting (List[MBConvConfig]): Network structure
+            dropout (float): The dropout probability
+            stochastic_depth_prob (float): The stochastic depth probability
+            num_classes (int): Number of classes
+            block (Optional[Callable[..., nn.Module]]): Module specifying the inverted residual building block
+            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
+        """
+        super().__init__()
+
+        if not inverted_residual_setting:
+            raise ValueError("The inverted_residual_setting should not be empty")
+        elif not (isinstance(inverted_residual_setting, Sequence) and
+                  all([isinstance(s, MBConvConfig) for s in inverted_residual_setting])):
+            raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")
+
+        if block is None:
+            block = MBConv
+
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+
+        layers: List[nn.Module] = []
+
+        # building first layer
+        firstconv_output_channels = inverted_residual_setting[0].input_channels
+        layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
+                                       activation_layer=nn.SiLU))
+
+        # building inverted residual blocks
+        total_stage_blocks = sum([cnf.num_layers for cnf in inverted_residual_setting])
+        stage_block_id = 0
+        for cnf in inverted_residual_setting:
+            stage: List[nn.Module] = []
+            for _ in range(cnf.num_layers):
+                # copy to avoid modifications. shallow copy is enough
+                block_cnf = copy.copy(cnf)
+
+                # overwrite info if not the first conv in the stage
+                if stage:
+                    block_cnf.input_channels = block_cnf.out_channels
+                    block_cnf.stride = 1
+
+                # adjust stochastic depth probability based on the depth of the stage block
+                sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks
+
+                stage.append(block(block_cnf, sd_prob, norm_layer))
+                stage_block_id += 1
+
+            layers.append(nn.Sequential(*stage))
+
+        # building last several layers
+        lastconv_input_channels = inverted_residual_setting[-1].out_channels
+        lastconv_output_channels = 4 * lastconv_input_channels
+        layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
+                                       norm_layer=norm_layer, activation_layer=nn.SiLU))
+
+        self.features = nn.Sequential(*layers)
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.classifier = nn.Sequential(
+            nn.Dropout(p=dropout, inplace=True),
+            nn.Linear(lastconv_output_channels, num_classes),
+        )
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out')
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                init_range = 1.0 / math.sqrt(m.out_features)
+                nn.init.uniform_(m.weight, -init_range, init_range)
+                nn.init.zeros_(m.bias)
+
+    def _forward_impl(self, x: Tensor) -> Tensor:
+        x = self.features(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+
+        x = self.classifier(x)
+
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self._forward_impl(x)
+
+
+def _efficientnet_conf(width_mult: float, depth_mult: float, **kwargs: Any) -> List[MBConvConfig]:
+    bneck_conf = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult)
+    inverted_residual_setting = [
+        bneck_conf(1, 3, 1, 32, 16, 1),
+        bneck_conf(6, 3, 2, 16, 24, 2),
+        bneck_conf(6, 5, 2, 24, 40, 2),
+        bneck_conf(6, 3, 2, 40, 80, 3),
+        bneck_conf(6, 5, 1, 80, 112, 3),
+        bneck_conf(6, 5, 2, 112, 192, 4),
+        bneck_conf(6, 3, 1, 192, 320, 1),
+    ]
+    return inverted_residual_setting
+
+
+def _efficientnet_model(
+    arch: str,
+    inverted_residual_setting: List[MBConvConfig],
+    dropout: float,
+    pretrained: bool,
+    progress: bool,
+    **kwargs: Any
+) -> EfficientNet:
+    model = EfficientNet(inverted_residual_setting, dropout, **kwargs)
+    if pretrained:
+        if model_urls.get(arch, None) is None:
+            raise ValueError("No checkpoint is available for model type {}".format(arch))
+        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def efficientnet_b0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
+    """
+    Constructs an EfficientNet B0 architecture from
+    `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    inverted_residual_setting = _efficientnet_conf(width_mult=1.0, depth_mult=1.0, **kwargs)
+    return _efficientnet_model("efficientnet_b0", inverted_residual_setting, 0.2, pretrained, progress, **kwargs)
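# Usage sketch (not part of the patch): the builders forward **kwargs to
# EfficientNet, so a different head can be requested up front, or the
# pretrained classifier (Dropout + Linear; 1280 = 4 * 320 input features
# for B0) can be swapped afterwards for fine-tuning.
from torch import nn
from torchvision.models import efficientnet_b0

model = efficientnet_b0(num_classes=10)    # randomly initialized 10-way head
model = efficientnet_b0(pretrained=True)   # ImageNet weights...
model.classifier[1] = nn.Linear(1280, 10)  # ...with only the head replaced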
+
+
+def efficientnet_b1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
+    """
+    Constructs an EfficientNet B1 architecture from
+    `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    inverted_residual_setting = _efficientnet_conf(width_mult=1.0, depth_mult=1.1, **kwargs)
+    return _efficientnet_model("efficientnet_b1", inverted_residual_setting, 0.2, pretrained, progress, **kwargs)
+
+
+def efficientnet_b2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
+    """
+    Constructs an EfficientNet B2 architecture from
+    `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    inverted_residual_setting = _efficientnet_conf(width_mult=1.1, depth_mult=1.2, **kwargs)
+    return _efficientnet_model("efficientnet_b2", inverted_residual_setting, 0.3, pretrained, progress, **kwargs)
+
+
+def efficientnet_b3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
+    """
+    Constructs an EfficientNet B3 architecture from
+    `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    inverted_residual_setting = _efficientnet_conf(width_mult=1.2, depth_mult=1.4, **kwargs)
+    return _efficientnet_model("efficientnet_b3", inverted_residual_setting, 0.3, pretrained, progress, **kwargs)
+
+
+def efficientnet_b4(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
+    """
+    Constructs an EfficientNet B4 architecture from
+    `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    inverted_residual_setting = _efficientnet_conf(width_mult=1.4, depth_mult=1.8, **kwargs)
+    return _efficientnet_model("efficientnet_b4", inverted_residual_setting, 0.4, pretrained, progress, **kwargs)
+
+
+def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
+    """
+    Constructs an EfficientNet B5 architecture from
+    `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    inverted_residual_setting = _efficientnet_conf(width_mult=1.6, depth_mult=2.2, **kwargs)
+    return _efficientnet_model("efficientnet_b5", inverted_residual_setting, 0.4, pretrained, progress,
+                               norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs)
+
+
+def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
+    """
+    Constructs an EfficientNet B6 architecture from
+    `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    inverted_residual_setting = _efficientnet_conf(width_mult=1.8, depth_mult=2.6, **kwargs)
+    return _efficientnet_model("efficientnet_b6", inverted_residual_setting, 0.5, pretrained, progress,
+                               norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs)
+
+
+def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
+    """
+    Constructs an EfficientNet B7 architecture from
+    `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" <https://arxiv.org/abs/1905.11946>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    inverted_residual_setting = _efficientnet_conf(width_mult=2.0, depth_mult=3.1, **kwargs)
+    return _efficientnet_model("efficientnet_b7", inverted_residual_setting, 0.5, pretrained, progress,
+                               norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs)
diff --git a/torchvision/ops/stochastic_depth.py b/torchvision/ops/stochastic_depth.py
index f3338242a76..0b95e7cca67 100644
--- a/torchvision/ops/stochastic_depth.py
+++ b/torchvision/ops/stochastic_depth.py
@@ -22,12 +22,12 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True)
     """
     if p < 0.0 or p > 1.0:
         raise ValueError("drop probability has to be between 0 and 1, but got {}".format(p))
+    if mode not in ["batch", "row"]:
+        raise ValueError("mode has to be either 'batch' or 'row', but got {}".format(mode))
     if not training or p == 0.0:
         return input
 
     survival_rate = 1.0 - p
-    if mode not in ["batch", "row"]:
-        raise ValueError("mode has to be either 'batch' or 'row', but got {}".format(mode))
     size = [1] * input.ndim
     if mode == "row":
         size[0] = input.shape[0]
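Since the mode check now runs before the early-return paths, an invalid mode fails fast even when the call would otherwise be a no-op. A small sketch of the resulting behavior (not part of the patch):

    import torch
    from torchvision.ops import stochastic_depth

    x = torch.ones(4, 3, 8, 8)
    # "row" mode zeroes whole samples and rescales survivors by 1 / (1 - p),
    # so each sample below is either all zeros or all 2.0
    out = stochastic_depth(x, p=0.5, mode="row", training=True)

    # previously a bad mode was silently ignored here because of the early
    # return on p == 0.0; after this patch it raises ValueError
    stochastic_depth(x, p=0.0, mode="col")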