Skip to content

Commit

Permalink
Adding typing.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Aug 19, 2021
1 parent 4502083 commit 95dedaf
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torchvision/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
}


def drop_connect(x: Tensor, rate: float):
def drop_connect(x: Tensor, rate: float) -> Tensor:
keep = torch.rand(size=(x.size(0), ), dtype=x.dtype, device=x.device) > rate
keep = keep[(None, ) * (x.ndim - 1)].T
return (x / (1.0 - rate)) * keep
Expand All @@ -29,7 +29,7 @@ class MBConvConfig:
def __init__(self,
kernel: int, stride: int, dilation: int,
input_channels: int, out_channels: int, expand_ratio: float,
width_mult: float):
width_mult: float) -> None:
self.kernel = kernel
self.stride = stride
self.dilation = dilation
Expand All @@ -38,13 +38,13 @@ def __init__(self,
self.expanded_channels = self.adjust_channels(input_channels, expand_ratio * width_mult)

@staticmethod
def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None):
def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
return _make_divisible(channels * width_mult, 8, min_value)


class MBConv(nn.Module):
def __init__(self, cnf: MBConvConfig, norm_layer: Callable[..., nn.Module],
se_layer: Callable[..., nn.Module] = SqueezeExcitation):
se_layer: Callable[..., nn.Module] = SqueezeExcitation) -> None:
super().__init__()
if not (1 <= cnf.stride <= 2):
raise ValueError('illegal stride value')
Expand Down

0 comments on commit 95dedaf

Please sign in to comment.