Skip to content

Commit

Permalink
Implement MBConv.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Aug 19, 2021
1 parent e173b8f commit 4502083
Showing 1 changed file with 54 additions and 10 deletions.
64 changes: 54 additions & 10 deletions torchvision/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torch import nn, Tensor
from torch.nn import functional as F
from typing import Any, Optional
from typing import Any, Callable, List, Optional

from .._internally_replaced_utils import load_state_dict_from_url

Expand All @@ -11,33 +11,77 @@
from torchvision.models.mobilenetv3 import SqueezeExcitation


# Public names exported by this module.
__all__ = ["EfficientNet"]


# Maps a model name to the URL of its weights. The entry is an empty
# placeholder until weights are available (see the TODO below).
model_urls = {
    "efficientnet_b0": "",  # TODO: Add weights
}


def drop_connect(x: Tensor, rate: float) -> Tensor:
    """Randomly zero whole entries along the first dimension of ``x``.

    One keep/drop decision is drawn per index of the first dimension and
    broadcast over all remaining dimensions. Entries that are kept are
    divided by ``1 - rate``.

    Args:
        x: Input tensor with at least one dimension.
        rate: Probability of dropping an entry, in ``[0, 1]``.

    Returns:
        A tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``rate`` is outside ``[0, 1]``.
    """
    if rate < 0.0 or rate > 1.0:
        raise ValueError(f"drop connect rate must be in [0, 1], got {rate}")
    if rate == 1.0:
        # Everything is dropped. Return early: dividing by (1 - rate) == 0
        # would produce inf, and inf * 0 is nan rather than 0.
        return torch.zeros_like(x)

    batch = x.size(0)
    # torch.rand samples from [0, 1), so `>=` keeps an entry with probability
    # 1 - rate and never drops anything when rate == 0.
    keep = torch.rand(size=(batch,), dtype=x.dtype, device=x.device) >= rate
    # Shape the mask as (N, 1, ..., 1) so it broadcasts over trailing dims.
    # An explicit reshape avoids `.T`, which is deprecated for ndim != 2.
    keep = keep.reshape((batch,) + (1,) * (x.ndim - 1))
    return (x / (1.0 - rate)) * keep


class MBConvConfig:
    """Hyper-parameters describing a single MBConv block.

    The channel counts passed in are base values; they are scaled by
    ``width_mult`` and passed through ``adjust_channels`` before being stored.

    Attributes set by ``__init__``:
        kernel, stride, dilation: stored unchanged.
        input_channels: ``adjust_channels(input_channels, width_mult)``.
        out_channels: ``adjust_channels(out_channels, width_mult)``.
        expanded_channels: ``adjust_channels(input_channels, expand_ratio * width_mult)``,
            computed from the *base* ``input_channels`` argument, not from the
            already adjusted ``self.input_channels``.
    """

    def __init__(self,
                 kernel: int, stride: int, dilation: int,
                 input_channels: int, out_channels: int, expand_ratio: float,
                 width_mult: float) -> None:
        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.input_channels = self.adjust_channels(input_channels, width_mult)
        self.out_channels = self.adjust_channels(out_channels, width_mult)
        self.expanded_channels = self.adjust_channels(input_channels, expand_ratio * width_mult)

    @staticmethod
    def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None):
        """Scale ``channels`` by ``width_mult`` and round via ``_make_divisible``.

        The scaled value is forwarded to ``_make_divisible`` with a divisor of
        8 and the optional ``min_value``.
        NOTE(review): ``_make_divisible`` is defined outside this block;
        presumably it returns an int that is a multiple of 8 - confirm.
        """
        return _make_divisible(channels * width_mult, 8, min_value)


class MBConv(nn.Module):
    """Block built from an ``MBConvConfig``.

    The layers, in order, are:
      1. a 1x1 ``ConvBNActivation`` expansion, only when the expanded channel
         count differs from the input channel count;
      2. a ``ConvBNActivation`` with ``groups == expanded_channels``;
      3. a squeeze-and-excitation layer built by ``se_layer``;
      4. a 1x1 ``ConvBNActivation`` projection using ``nn.Identity`` as its
         activation.

    The input is added back to the output when the configured stride is 1 and
    the input and output channel counts match.

    Args:
        cnf: Block configuration.
        norm_layer: Factory for the normalisation layer, forwarded to every
            ``ConvBNActivation``.
        se_layer: Factory for the squeeze-and-excitation layer.

    Raises:
        ValueError: If ``cnf.stride`` is not within ``[1, 2]``.
    """

    def __init__(self, cnf: MBConvConfig, norm_layer: Callable[..., nn.Module],
                 se_layer: Callable[..., nn.Module] = SqueezeExcitation):
        super().__init__()
        if not (1 <= cnf.stride <= 2):
            raise ValueError('illegal stride value')

        in_ch = cnf.input_channels
        mid_ch = cnf.expanded_channels
        out_ch = cnf.out_channels
        act = nn.SiLU

        self.use_res_connect = cnf.stride == 1 and in_ch == out_ch

        # Modules are created in the same order they run, so parameter
        # initialisation and Sequential indices are unaffected by this layout.
        stages: List[nn.Module] = []

        # Expansion (skipped when it would not change the channel count).
        if mid_ch != in_ch:
            expand = ConvBNActivation(in_ch, mid_ch, kernel_size=1,
                                      norm_layer=norm_layer, activation_layer=act)
            stages.append(expand)

        # Depthwise stage: a dilated block never downsamples.
        depthwise_stride = cnf.stride
        if cnf.dilation > 1:
            depthwise_stride = 1
        depthwise = ConvBNActivation(mid_ch, mid_ch, kernel_size=cnf.kernel,
                                     stride=depthwise_stride, dilation=cnf.dilation, groups=mid_ch,
                                     norm_layer=norm_layer, activation_layer=act)
        stages.append(depthwise)

        # Squeeze and excitation.
        squeeze_excite = se_layer(mid_ch, min_value=1, activation_fn=F.sigmoid)
        stages.append(squeeze_excite)

        # Projection back to the output width, with no non-linearity.
        project = ConvBNActivation(mid_ch, out_ch, kernel_size=1, norm_layer=norm_layer,
                                   activation_layer=nn.Identity)
        stages.append(project)

        self.block = nn.Sequential(*stages)
        self.out_channels = out_ch

    def forward(self, input: Tensor, drop_connect_rate: Optional[float] = None) -> Tensor:
        """Run the block, adding the residual connection when enabled.

        ``drop_connect`` is applied to the block output only when the residual
        connection is active, the module is in training mode, and
        ``drop_connect_rate`` is truthy.
        """
        out = self.block(input)
        if not self.use_res_connect:
            return out

        if self.training and drop_connect_rate:
            out = drop_connect(out, drop_connect_rate)
        out += input
        return out


class EfficientNet(nn.Module):
Expand Down

0 comments on commit 4502083

Please sign in to comment.