Add EfficientNet Architecture in TorchVision (#4293)
* Adding code skeleton

* Adding MBConvConfig.

* Extend SqueezeExcitation to support custom min_value and activation.

* Implement MBConv.

* Replace stochastic_depth with operator.

* Adding the rest of the EfficientNet implementation

* Update torchvision/models/efficientnet.py

* Replacing 1st activation of SE with SiLU.

* Adding efficientnet_b3.

* Replace mobilenetv3 assets with custom.

* Switch to standard sigmoid and reconfiguring BN.

* Reconfiguration of efficientnet.

* Add repr

* Add weights.

* Update weights.

* Adding B5-B7 weights.

* Update docs and hubconf.

* Fix doc link.

* Fix typo on comment.
datumbox authored Aug 26, 2021
1 parent d004d77 commit 37a9ee5
Showing 16 changed files with 441 additions and 7 deletions.
43 changes: 42 additions & 1 deletion docs/source/models.rst
@@ -27,6 +27,7 @@ architectures for image classification:
- `ResNeXt`_
- `Wide ResNet`_
- `MNASNet`_
- `EfficientNet`_

You can construct a model with random weights by calling its constructor:

@@ -47,6 +48,14 @@ You can construct a model with random weights by calling its constructor:
resnext50_32x4d = models.resnext50_32x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()
efficientnet_b0 = models.efficientnet_b0()
efficientnet_b1 = models.efficientnet_b1()
efficientnet_b2 = models.efficientnet_b2()
efficientnet_b3 = models.efficientnet_b3()
efficientnet_b4 = models.efficientnet_b4()
efficientnet_b5 = models.efficientnet_b5()
efficientnet_b6 = models.efficientnet_b6()
efficientnet_b7 = models.efficientnet_b7()
We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
These can be constructed by passing ``pretrained=True``:
@@ -68,6 +77,14 @@ These can be constructed by passing ``pretrained=True``:
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)
efficientnet_b0 = models.efficientnet_b0(pretrained=True)
efficientnet_b1 = models.efficientnet_b1(pretrained=True)
efficientnet_b2 = models.efficientnet_b2(pretrained=True)
efficientnet_b3 = models.efficientnet_b3(pretrained=True)
efficientnet_b4 = models.efficientnet_b4(pretrained=True)
efficientnet_b5 = models.efficientnet_b5(pretrained=True)
efficientnet_b6 = models.efficientnet_b6(pretrained=True)
efficientnet_b7 = models.efficientnet_b7(pretrained=True)
Instancing a pre-trained model will download its weights to a cache directory.
This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
@@ -113,7 +130,10 @@ Unfortunately, the concrete `subset` that was used is lost. For more
information see `this discussion <https://github.com/pytorch/vision/issues/1439>`_
or `these experiments <https://github.com/pytorch/vision/pull/1965>`_.

-ImageNet 1-crop error rates (224x224)
+The sizes of the EfficientNet models depend on the variant. For the exact input sizes
+`check here <https://github.com/pytorch/vision/blob/d2bfd639e46e1c5dc3c177f889dc7750c8d137c7/references/classification/train.py#L92-L93>`_
+
+ImageNet 1-crop error rates

================================ ============= =============
Model                            Acc@1         Acc@5
@@ -151,6 +171,14 @@ Wide ResNet-50-2 78.468 94.086
Wide ResNet-101-2                78.848        94.284
MNASNet 1.0                      73.456        91.510
MNASNet 0.5                      67.734        87.490
EfficientNet-B0                  77.692        93.532
EfficientNet-B1                  78.642        94.186
EfficientNet-B2                  80.608        95.310
EfficientNet-B3                  82.008        96.054
EfficientNet-B4                  83.384        96.594
EfficientNet-B5                  83.444        96.628
EfficientNet-B6                  84.008        96.916
EfficientNet-B7                  84.122        96.908
================================ ============= =============


@@ -166,6 +194,7 @@ MNASNet 0.5 67.734 87.490
.. _MobileNetV3: https://arxiv.org/abs/1905.02244
.. _ResNeXt: https://arxiv.org/abs/1611.05431
.. _MNASNet: https://arxiv.org/abs/1807.11626
.. _EfficientNet: https://arxiv.org/abs/1905.11946

.. currentmodule:: torchvision.models

@@ -267,6 +296,18 @@ MNASNet
.. autofunction:: mnasnet1_0
.. autofunction:: mnasnet1_3

EfficientNet
------------

.. autofunction:: efficientnet_b0
.. autofunction:: efficientnet_b1
.. autofunction:: efficientnet_b2
.. autofunction:: efficientnet_b3
.. autofunction:: efficientnet_b4
.. autofunction:: efficientnet_b5
.. autofunction:: efficientnet_b6
.. autofunction:: efficientnet_b7

Quantized Models
----------------

2 changes: 2 additions & 0 deletions hubconf.py
@@ -15,6 +15,8 @@
from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small
from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
    mnasnet1_3
from torchvision.models.efficientnet import efficientnet_b0, efficientnet_b1, efficientnet_b2, \
    efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7

# segmentation
from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \
6 changes: 6 additions & 0 deletions references/classification/README.md
@@ -68,6 +68,12 @@ Then we averaged the parameters of the last 3 checkpoints that improved the Acc@
and [#3354](https://github.com/pytorch/vision/pull/3354) for details.


### EfficientNet

The weights of the B0-B4 variants are ported from Ross Wightman's [timm repo](https://github.com/rwightman/pytorch-image-models/blob/01cb46a9a50e3ba4be167965b5764e9702f09b30/timm/models/efficientnet.py#L95-L108).

The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTorch repo](https://github.com/lukemelas/EfficientNet-PyTorch/blob/1039e009545d9329ea026c9f7541341439712b96/efficientnet_pytorch/utils.py#L562-L564).

## Mixed precision training
Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [NVIDIA Apex extension](https://github.com/NVIDIA/apex).

6 changes: 4 additions & 2 deletions references/classification/presets.py
@@ -1,4 +1,5 @@
from torchvision.transforms import autoaugment, transforms
from torchvision.transforms.functional import InterpolationMode


class ClassificationPresetTrain:
@@ -24,10 +25,11 @@ def __call__(self, img):


 class ClassificationPresetEval:
-    def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+    def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
+                 interpolation=InterpolationMode.BILINEAR):

         self.transforms = transforms.Compose([
-            transforms.Resize(resize_size),
+            transforms.Resize(resize_size, interpolation=interpolation),
             transforms.CenterCrop(crop_size),
             transforms.ToTensor(),
             transforms.Normalize(mean=mean, std=std),
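
To make the roles of `resize_size` and `crop_size` in the evaluation preset concrete, here is a rough sketch of the geometry only. It is not part of the commit. It assumes, per the torchvision documentation, that `Resize` given an integer matches the shorter image edge to that value while keeping the aspect ratio, and that `CenterCrop` takes a `crop_size` square from the middle. `eval_geometry` is a hypothetical name, and torchvision's exact rounding may differ from this sketch by a pixel.

```python
def eval_geometry(width, height, resize_size=256, crop_size=224):
    """Approximate Resize(resize_size) followed by CenterCrop(crop_size).

    Returns the resized (width, height) and the crop box
    (left, top, right, bottom). Assumes crop_size fits inside the resized image.
    """
    # Resize step: scale so the shorter edge equals resize_size.
    if width <= height:
        new_w = resize_size
        new_h = int(resize_size * height / width)
    else:
        new_h = resize_size
        new_w = int(resize_size * width / height)

    # CenterCrop step: take a crop_size square from the middle.
    left = (new_w - crop_size) // 2
    top = (new_h - crop_size) // 2
    return (new_w, new_h), (left, top, left + crop_size, top + crop_size)


if __name__ == "__main__":
    # Default preset sizes, then the EfficientNet-B3 sizes from train.py.
    print(eval_geometry(512, 256))
    print(eval_geometry(600, 800, resize_size=320, crop_size=300))
```

The interpolation mode does not change this geometry; it only affects how pixel values are resampled during the resize step.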
17 changes: 15 additions & 2 deletions references/classification/train.py
@@ -6,6 +6,7 @@
import torch.utils.data
from torch import nn
import torchvision
from torchvision.transforms.functional import InterpolationMode

import presets
import utils
@@ -82,7 +83,18 @@ def _get_cache_path(filepath):
 def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
-    resize_size, crop_size = (342, 299) if args.model == 'inception_v3' else (256, 224)
+    resize_size, crop_size = 256, 224
+    interpolation = InterpolationMode.BILINEAR
+    if args.model == 'inception_v3':
+        resize_size, crop_size = 342, 299
+    elif args.model.startswith('efficientnet_'):
+        sizes = {
+            'b0': (256, 224), 'b1': (256, 240), 'b2': (288, 288), 'b3': (320, 300),
+            'b4': (384, 380), 'b5': (456, 456), 'b6': (528, 528), 'b7': (600, 600),
+        }
+        e_type = args.model.replace('efficientnet_', '')
+        resize_size, crop_size = sizes[e_type]
+        interpolation = InterpolationMode.BICUBIC

     print("Loading training data")
     st = time.time()
@@ -113,7 +125,8 @@ def load_data(traindir, valdir, args):
     else:
         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size))
+            presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size,
+                                             interpolation=interpolation))
         if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
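
The size selection added to `load_data` in this file can be read in isolation. What follows is a minimal sketch, not the code from the commit: `eval_sizes` and `EFFICIENTNET_SIZES` are hypothetical names, and the interpolation mode is returned as a plain string rather than an `InterpolationMode` member so that the snippet runs without torchvision installed. The size table and the bilinear/bicubic split are copied from the diff.

```python
# Hypothetical standalone version of the branch added to load_data().
# (resize_size, crop_size) per EfficientNet variant, as listed in the diff.
EFFICIENTNET_SIZES = {
    'b0': (256, 224), 'b1': (256, 240), 'b2': (288, 288), 'b3': (320, 300),
    'b4': (384, 380), 'b5': (456, 456), 'b6': (528, 528), 'b7': (600, 600),
}


def eval_sizes(model_name):
    """Return (resize_size, crop_size, interpolation_name) for evaluation."""
    # Defaults used by most classification models.
    resize_size, crop_size = 256, 224
    interpolation = 'bilinear'
    if model_name == 'inception_v3':
        resize_size, crop_size = 342, 299
    elif model_name.startswith('efficientnet_'):
        variant = model_name[len('efficientnet_'):]
        resize_size, crop_size = EFFICIENTNET_SIZES[variant]
        # EfficientNet evaluation uses bicubic resizing in this commit.
        interpolation = 'bicubic'
    return resize_size, crop_size, interpolation


if __name__ == "__main__":
    for name in ('resnet50', 'inception_v3', 'efficientnet_b3'):
        print(name, eval_sizes(name))
```

As in the diff, an unknown EfficientNet suffix raises `KeyError` instead of silently falling back to the 256/224 defaults.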
Binary files not shown (8 files).
1 change: 1 addition & 0 deletions torchvision/models/__init__.py
@@ -8,6 +8,7 @@
from .mobilenet import *
from .mnasnet import *
from .shufflenetv2 import *
from .efficientnet import *
from . import segmentation
from . import detection
from . import video