diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 6f403fb5e30..b1da02f4697 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -235,8 +235,20 @@ class MNASNet0_5_Weights(WeightsEnum): class MNASNet0_75_Weights(WeightsEnum): - # If a default model is added here the corresponding changes need to be done in mnasnet0_75 - pass + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mnasnet0_75-7090bc5f.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/pull/6019", + "num_params": 3170208, + "metrics": { + "acc@1": 71.180, + "acc@5": 90.496, + }, + }, + ) + DEFAULT = IMAGENET1K_V1 class MNASNet1_0_Weights(WeightsEnum): @@ -256,8 +268,20 @@ class MNASNet1_0_Weights(WeightsEnum): class MNASNet1_3_Weights(WeightsEnum): - # If a default model is added here the corresponding changes need to be done in mnasnet1_3 - pass + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mnasnet1_3-a4c69d6f.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "recipe": "https://github.com/pytorch/vision/pull/6019", + "num_params": 6282256, + "metrics": { + "acc@1": 76.506, + "acc@5": 93.522, + }, + }, + ) + DEFAULT = IMAGENET1K_V1 def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet: @@ -299,15 +323,17 @@ def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = return _mnasnet(0.5, weights, progress, **kwargs) -@handle_legacy_interface(weights=("pretrained", None)) +@handle_legacy_interface(weights=("pretrained", MNASNet0_75_Weights.IMAGENET1K_V1)) def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: """MNASNet with depth multiplier of 0.75 from `MnasNet: Platform-Aware Neural Architecture Search for Mobile `_ paper. Args: - weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): Currently - no pre-trained weights are available and by default no pre-trained + weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.MNASNet0_75_Weights` below for + more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. @@ -351,15 +377,17 @@ def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = return _mnasnet(1.0, weights, progress, **kwargs) -@handle_legacy_interface(weights=("pretrained", None)) +@handle_legacy_interface(weights=("pretrained", MNASNet1_3_Weights.IMAGENET1K_V1)) def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: """MNASNet with depth multiplier of 1.3 from `MnasNet: Platform-Aware Neural Architecture Search for Mobile `_ paper. Args: - weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): Currently - no pre-trained weights are available and by default no pre-trained + weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.MNASNet1_3_Weights` below for + more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.