Skip to content

Commit

Permalink
[fbsync] Add weight for mnasnet0_75 and mnasnet1_3 (#6019)
Browse files Browse the repository at this point in the history
Summary:
* Add weight for mnasnet0_75 and mnasnet1_3

* Fix missing comma

* Add PR url as recipe, and update the metrics

* Add weights to legacy handler

* Update docs to specify there are weights available

Reviewed By: NicolasHug

Differential Revision: D36760931

fbshipit-source-id: 00211a6dd22b4b42a9845b7b4d25337ed14a6349
  • Loading branch information
YosuaMichael authored and facebook-github-bot committed Jun 1, 2022
1 parent dd44586 commit 5fe8887
Showing 1 changed file with 38 additions and 10 deletions.
48 changes: 38 additions & 10 deletions torchvision/models/mnasnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,20 @@ class MNASNet0_5_Weights(WeightsEnum):


class MNASNet0_75_Weights(WeightsEnum):
# If a default model is added here the corresponding changes need to be done in mnasnet0_75
pass
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/mnasnet0_75-7090bc5f.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/pull/6019",
"num_params": 3170208,
"metrics": {
"acc@1": 71.180,
"acc@5": 90.496,
},
},
)
DEFAULT = IMAGENET1K_V1


class MNASNet1_0_Weights(WeightsEnum):
Expand All @@ -256,8 +268,20 @@ class MNASNet1_0_Weights(WeightsEnum):


class MNASNet1_3_Weights(WeightsEnum):
# If a default model is added here the corresponding changes need to be done in mnasnet1_3
pass
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/mnasnet1_3-a4c69d6f.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/pull/6019",
"num_params": 6282256,
"metrics": {
"acc@1": 76.506,
"acc@5": 93.522,
},
},
)
DEFAULT = IMAGENET1K_V1


def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet:
Expand Down Expand Up @@ -299,15 +323,17 @@ def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool =
return _mnasnet(0.5, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", None))
@handle_legacy_interface(weights=("pretrained", MNASNet0_75_Weights.IMAGENET1K_V1))
def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
"""MNASNet with depth multiplier of 0.75 from
`MnasNet: Platform-Aware Neural Architecture Search for Mobile
<https://arxiv.org/pdf/1807.11626.pdf>`_ paper.
Args:
weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): Currently
no pre-trained weights are available and by default no pre-trained
weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.MNASNet0_75_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
Expand Down Expand Up @@ -351,15 +377,17 @@ def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool =
return _mnasnet(1.0, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", None))
@handle_legacy_interface(weights=("pretrained", MNASNet1_3_Weights.IMAGENET1K_V1))
def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
"""MNASNet with depth multiplier of 1.3 from
`MnasNet: Platform-Aware Neural Architecture Search for Mobile
<https://arxiv.org/pdf/1807.11626.pdf>`_ paper.
Args:
weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): Currently
no pre-trained weights are available and by default no pre-trained
weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): The
pretrained weights to use. See
:class:`~torchvision.models.MNASNet1_3_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used.
progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True.
Expand Down

0 comments on commit 5fe8887

Please sign in to comment.