Skip to content

Commit

Permalink
Adding Weights classes for Resnet classification models (#4655)
Browse files Browse the repository at this point in the history
* adding Weights classes for Resnet classification models

* Replacing BasicBlock by Bottleneck in all but 3 model contructors

* adding tests for prototype models

* fixing typo in environment variable

* Update test/test_prototype_models.py

Co-authored-by: Vasilis Vryniotis <[email protected]>

* changing default value for PYTORCH_TEST_WITH_PROTOTYPE

* adding checks to compare outputs of the prototype vs old models

* refactoring prototype tests

* removing unused imports

* applying ufmt

* Update test/test_prototype_models.py

Co-authored-by: Vasilis Vryniotis <[email protected]>

* Update test/test_prototype_models.py

Co-authored-by: Vasilis Vryniotis <[email protected]>

* Update test/test_prototype_models.py

Co-authored-by: Vasilis Vryniotis <[email protected]>

* Update test/test_prototype_models.py

Co-authored-by: Vasilis Vryniotis <[email protected]>

* Update test/test_prototype_models.py

Co-authored-by: Vasilis Vryniotis <[email protected]>

* Update test/test_prototype_models.py

Co-authored-by: Vasilis Vryniotis <[email protected]>

* Update test/test_prototype_models.py

Co-authored-by: Vasilis Vryniotis <[email protected]>

Co-authored-by: Vasilis Vryniotis <[email protected]>
  • Loading branch information
jdsgomes and datumbox authored Oct 20, 2021
1 parent d18c487 commit 80e6aff
Show file tree
Hide file tree
Showing 2 changed files with 261 additions and 1 deletion.
54 changes: 54 additions & 0 deletions test/test_prototype_models.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,56 @@
import os

import pytest
import torch
from common_utils import set_rng_seed, cpu_and_gpu
from test_models import _assert_expected, _model_params
from torchvision import models as original_models
from torchvision.prototype import models


def get_available_classification_models():
return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


@pytest.mark.parametrize("model_name", get_available_classification_models())
@pytest.mark.parametrize("dev", cpu_and_gpu())
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
def test_classification_model(model_name, dev):
set_rng_seed(0)
defaults = {
"num_classes": 50,
"input_shape": (1, 3, 224, 224),
}
kwargs = {**defaults, **_model_params.get(model_name, {})}
input_shape = kwargs.pop("input_shape")
model = models.__dict__[model_name](**kwargs)
model.eval().to(device=dev)
x = torch.rand(input_shape).to(device=dev)
out = model(x)
_assert_expected(out.cpu(), model_name, prec=0.1)
assert out.shape[-1] == 50


@pytest.mark.parametrize("model_name", get_available_classification_models())
@pytest.mark.parametrize("dev", cpu_and_gpu())
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
def test_old_vs_new_classification_factory(model_name, dev):
defaults = {
"pretrained": True,
"input_shape": (1, 3, 224, 224),
}
kwargs = {**defaults, **_model_params.get(model_name, {})}
input_shape = kwargs.pop("input_shape")
model_old = original_models.__dict__[model_name](**kwargs)
model_old.eval().to(device=dev)
x = torch.rand(input_shape).to(device=dev)
out_old = model_old(x)
# compare with new model builder parameterized in the old fashion way
model_new = models.__dict__[model_name](**kwargs)
model_new.eval().to(device=dev)
out_new = model_new(x)
torch.testing.assert_close(out_new, out_old, rtol=0.0, atol=0.0, check_dtype=False)


def test_smoke():
import torchvision.prototype.models # noqa: F401
208 changes: 207 additions & 1 deletion torchvision/prototype/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,27 @@
from ._meta import _IMAGENET_CATEGORIES


__all__ = ["ResNet", "ResNet50Weights", "resnet50"]
__all__ = [
"ResNet",
"ResNet18Weights",
"ResNet34Weights",
"ResNet50Weights",
"ResNet101Weights",
"ResNet152Weights",
"ResNeXt50_32x4dWeights",
"ResNeXt101_32x8dWeights",
"WideResNet50_2Weights",
"WideResNet101_2Weights",
"resnet18",
"resnet34",
"resnet50",
"resnet101",
"resnet152",
"resnext50_32x4d",
"resnext101_32x8d",
"wide_resnet50_2",
"wide_resnet101_2",
]


def _resnet(
Expand All @@ -35,6 +55,32 @@ def _resnet(
}


class ResNet18Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"acc@1": 69.758,
"acc@5": 89.078,
},
)


class ResNet34Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/resnet34-b627a593.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"acc@1": 73.314,
"acc@5": 91.420,
},
)


class ResNet50Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
Expand All @@ -58,10 +104,170 @@ class ResNet50Weights(Weights):
)


class ResNet101Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"acc@1": 77.374,
"acc@5": 93.546,
},
)


class ResNet152Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"acc@1": 78.312,
"acc@5": 94.046,
},
)


class ResNeXt50_32x4dWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"acc@1": 77.618,
"acc@5": 93.698,
},
)


class ResNeXt101_32x8dWeights(Weights):
ImageNet1K_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"acc@1": 79.312,
"acc@5": 94.526,
},
)


class WideResNet50_2Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"acc@1": 78.468,
"acc@5": 94.086,
},
)


class WideResNet101_2Weights(Weights):
ImageNet1K_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
transforms=partial(ImageNetEval, crop_size=224),
meta={
**_common_meta,
"recipe": "",
"acc@1": 78.848,
"acc@5": 94.284,
},
)


def resnet18(weights: Optional[ResNet18Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = ResNet18Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None

weights = ResNet18Weights.verify(weights)

return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)


def resnet34(weights: Optional[ResNet34Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = ResNet34Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None

weights = ResNet34Weights.verify(weights)

return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)


def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = ResNet50Weights.verify(weights)

return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)


def resnet101(weights: Optional[ResNet101Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = ResNet101Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None

weights = ResNet101Weights.verify(weights)

return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)


def resnet152(weights: Optional[ResNet152Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = ResNet152Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None

weights = ResNet152Weights.verify(weights)

return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)


def resnext50_32x4d(weights: Optional[ResNeXt50_32x4dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = ResNeXt50_32x4dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None

weights = ResNeXt50_32x4dWeights.verify(weights)
kwargs["groups"] = 32
kwargs["width_per_group"] = 4
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)


def resnext101_32x8d(weights: Optional[ResNeXt101_32x8dWeights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = ResNeXt101_32x8dWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None

weights = ResNeXt101_32x8dWeights.verify(weights)
kwargs["groups"] = 32
kwargs["width_per_group"] = 8
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)


def wide_resnet50_2(weights: Optional[WideResNet50_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = WideResNet50_2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None

weights = WideResNet50_2Weights.verify(weights)
kwargs["width_per_group"] = 64 * 2
return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)


def wide_resnet101_2(weights: Optional[WideResNet101_2Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = WideResNet101_2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None

weights = WideResNet101_2Weights.verify(weights)
kwargs["width_per_group"] = 64 * 2
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)

0 comments on commit 80e6aff

Please sign in to comment.