From 95dedaf3668fb7891152f6afcd846d6b038add79 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 19 Aug 2021 20:13:33 +0100 Subject: [PATCH] Adding typing. --- torchvision/models/efficientnet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index d4a15e0ca60..0f53b0ee04c 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -19,7 +19,7 @@ } -def drop_connect(x: Tensor, rate: float): +def drop_connect(x: Tensor, rate: float) -> Tensor: keep = torch.rand(size=(x.size(0), ), dtype=x.dtype, device=x.device) > rate keep = keep[(None, ) * (x.ndim - 1)].T return (x / (1.0 - rate)) * keep @@ -29,7 +29,7 @@ class MBConvConfig: def __init__(self, kernel: int, stride: int, dilation: int, input_channels: int, out_channels: int, expand_ratio: float, - width_mult: float): + width_mult: float) -> None: self.kernel = kernel self.stride = stride self.dilation = dilation @@ -38,13 +38,13 @@ def __init__(self, self.expanded_channels = self.adjust_channels(input_channels, expand_ratio * width_mult) @staticmethod - def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None): + def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int: return _make_divisible(channels * width_mult, 8, min_value) class MBConv(nn.Module): def __init__(self, cnf: MBConvConfig, norm_layer: Callable[..., nn.Module], - se_layer: Callable[..., nn.Module] = SqueezeExcitation): + se_layer: Callable[..., nn.Module] = SqueezeExcitation) -> None: super().__init__() if not (1 <= cnf.stride <= 2): raise ValueError('illegal stride value')