From 45020831f07af541844e967afcf2b3ab284687d9 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 19 Aug 2021 20:12:07 +0100 Subject: [PATCH] Implement MBConv. --- torchvision/models/efficientnet.py | 64 +++++++++++++++++++++++++----- 1 file changed, 54 insertions(+), 10 deletions(-) diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 2d5fd219288..d4a15e0ca60 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -2,7 +2,7 @@ from torch import nn, Tensor from torch.nn import functional as F -from typing import Any, Optional +from typing import Any, Callable, List, Optional from .._internally_replaced_utils import load_state_dict_from_url @@ -11,25 +11,31 @@ from torchvision.models.mobilenetv3 import SqueezeExcitation -__all__ = [] +__all__ = ["EfficientNet"] -model_urls = {} +model_urls = { + "efficientnet_b0": "", # TOD: Add weights +} + + +def drop_connect(x: Tensor, rate: float): + keep = torch.rand(size=(x.size(0), ), dtype=x.dtype, device=x.device) > rate + keep = keep[(None, ) * (x.ndim - 1)].T + return (x / (1.0 - rate)) * keep class MBConvConfig: - # TODO: Add dilation for supporting detection and segmentation pipelines def __init__(self, - kernel: int, stride: int, - input_channels: int, out_channels: int, expand_ratio: float, se_ratio: float, - skip: bool, width_mult: float): + kernel: int, stride: int, dilation: int, + input_channels: int, out_channels: int, expand_ratio: float, + width_mult: float): self.kernel = kernel self.stride = stride + self.dilation = dilation self.input_channels = self.adjust_channels(input_channels, width_mult) self.out_channels = self.adjust_channels(out_channels, width_mult) self.expanded_channels = self.adjust_channels(input_channels, expand_ratio * width_mult) - self.se_channels = self.adjust_channels(input_channels, se_ratio * width_mult, 1) - self.skip = skip @staticmethod def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None): @@ -37,7 +43,45 @@ def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = class MBConv(nn.Module): - pass + def __init__(self, cnf: MBConvConfig, norm_layer: Callable[..., nn.Module], + se_layer: Callable[..., nn.Module] = SqueezeExcitation): + super().__init__() + if not (1 <= cnf.stride <= 2): + raise ValueError('illegal stride value') + + self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels + + layers: List[nn.Module] = [] + activation_layer = nn.SiLU + + # expand + if cnf.expanded_channels != cnf.input_channels: + layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1, + norm_layer=norm_layer, activation_layer=activation_layer)) + + # depthwise + stride = 1 if cnf.dilation > 1 else cnf.stride + layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, + stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels, + norm_layer=norm_layer, activation_layer=activation_layer)) + + # squeeze and excitation + layers.append(se_layer(cnf.expanded_channels, min_value=1, activation_fn=F.sigmoid)) + + # project + layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, + activation_layer=nn.Identity)) + + self.block = nn.Sequential(*layers) + self.out_channels = cnf.out_channels + + def forward(self, input: Tensor, drop_connect_rate: Optional[float] = None) -> Tensor: + result = self.block(input) + if self.use_res_connect: + if self.training and drop_connect_rate: + result = drop_connect(result, drop_connect_rate) + result += input + return result class EfficientNet(nn.Module):