From 8fe72d131d6d2862b9db1efb3ffa2a6ded15efc8 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 14 Oct 2021 10:49:10 +0100 Subject: [PATCH] Multi-pretrained weight support - initial API + ResNet50 (#4610) * Adding lightweight API for models. * Adding resnet50. * Fix preset * Add fake categories. * Fixing mypy. * Add string=>weight conversion support on Enums. * Temporarily hardcoding imagenet categories. * Minor refactoring. --- torchvision/prototype/__init__.py | 2 + torchvision/prototype/models/__init__.py | 1 + torchvision/prototype/models/_api.py | 76 ++ torchvision/prototype/models/_meta.py | 1008 ++++++++++++++++++ torchvision/prototype/models/resnet.py | 67 ++ torchvision/prototype/transforms/__init__.py | 1 + torchvision/prototype/transforms/presets.py | 44 + 7 files changed, 1199 insertions(+) create mode 100644 torchvision/prototype/models/__init__.py create mode 100644 torchvision/prototype/models/_api.py create mode 100644 torchvision/prototype/models/_meta.py create mode 100644 torchvision/prototype/models/resnet.py create mode 100644 torchvision/prototype/transforms/__init__.py create mode 100644 torchvision/prototype/transforms/presets.py diff --git a/torchvision/prototype/__init__.py b/torchvision/prototype/__init__.py index 62a23dfafae..23a2912b00b 100644 --- a/torchvision/prototype/__init__.py +++ b/torchvision/prototype/__init__.py @@ -1 +1,3 @@ from . import datasets +from . import models +from . import transforms diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py new file mode 100644 index 00000000000..b792ca6ecf7 --- /dev/null +++ b/torchvision/prototype/models/__init__.py @@ -0,0 +1 @@ +from .resnet import * diff --git a/torchvision/prototype/models/_api.py b/torchvision/prototype/models/_api.py new file mode 100644 index 00000000000..bc12a9e2e31 --- /dev/null +++ b/torchvision/prototype/models/_api.py @@ -0,0 +1,76 @@ +from collections import OrderedDict +from dataclasses import dataclass, fields +from enum import Enum +from typing import Any, Callable, Dict + +from ..._internally_replaced_utils import load_state_dict_from_url + + +__all__ = ["Weights", "WeightEntry"] + + +@dataclass +class WeightEntry: + """ + This class is used to group important attributes associated with the pre-trained weights. + + Args: + url (str): The location where we find the weights. + transforms (Callable): A callable that constructs the preprocessing method (or validation preset transforms) + needed to use the model. The reason we attach a constructor method rather than an already constructed + object is because the specific object might have memory and thus we want to delay initialization until + needed. + meta (Dict[str, Any]): Stores meta-data related to the weights of the model and its configuration. These can be + informative attributes (for example the number of parameters/flops, recipe link/methods used in training + etc), configuration parameters (for example the `num_classes`) needed to construct the model or important + meta-data (for example the `classes` of a classification model) needed to use the model. + """ + + url: str + transforms: Callable + meta: Dict[str, Any] + + +class Weights(Enum): + """ + This class is the parent class of all model weights. Each model building method receives an optional `weights` + parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type + `WeightEntry`. + + Args: + value (WeightEntry): The data class entry with the weight information. + """ + + def __init__(self, value: WeightEntry): + self._value_ = value + + @classmethod + def verify(cls, obj: Any) -> Any: + if obj is not None: + if type(obj) is str: + obj = cls.from_str(obj) + elif not isinstance(obj, cls) and not isinstance(obj, WeightEntry): + raise TypeError( + f"Invalid Weight class provided; expected {cls.__name__} " f"but received {obj.__class__.__name__}." + ) + return obj + + @classmethod + def from_str(cls, value: str) -> "Weights": + for v in cls: + if v._name_ == value: + return v + raise ValueError(f"Invalid value {value} for enum {cls.__name__}.") + + def state_dict(self, progress: bool) -> OrderedDict: + return load_state_dict_from_url(self.url, progress=progress) + + def __repr__(self): + return f"{self.__class__.__name__}.{self._name_}" + + def __getattr__(self, name): + # Be able to fetch WeightEntry attributes directly + for f in fields(WeightEntry): + if f.name == name: + return object.__getattribute__(self.value, name) + return super().__getattr__(name) diff --git a/torchvision/prototype/models/_meta.py b/torchvision/prototype/models/_meta.py new file mode 100644 index 00000000000..87eb338d51a --- /dev/null +++ b/torchvision/prototype/models/_meta.py @@ -0,0 +1,1008 @@ +""" +This file is part of the private API. Please do not refer to any variables defined here directly as they will be +removed on future versions without warning. +""" + +# This will eventually be replaced with a call at torchvision.datasets.find("imagenet").info.categories +_IMAGENET_CATEGORIES = [ + "tench", + "goldfish", + "great white shark", + "tiger shark", + "hammerhead", + "electric ray", + "stingray", + "cock", + "hen", + "ostrich", + "brambling", + "goldfinch", + "house finch", + "junco", + "indigo bunting", + "robin", + "bulbul", + "jay", + "magpie", + "chickadee", + "water ouzel", + "kite", + "bald eagle", + "vulture", + "great grey owl", + "European fire salamander", + "common newt", + "eft", + "spotted salamander", + "axolotl", + "bullfrog", + "tree frog", + "tailed frog", + "loggerhead", + "leatherback turtle", + "mud turtle", + "terrapin", + "box turtle", + "banded gecko", + "common iguana", + "American chameleon", + "whiptail", + "agama", + "frilled lizard", + "alligator lizard", + "Gila monster", + "green lizard", + "African chameleon", + "Komodo dragon", + "African crocodile", + "American alligator", + "triceratops", + "thunder snake", + "ringneck snake", + "hognose snake", + "green snake", + "king snake", + "garter snake", + "water snake", + "vine snake", + "night snake", + "boa constrictor", + "rock python", + "Indian cobra", + "green mamba", + "sea snake", + "horned viper", + "diamondback", + "sidewinder", + "trilobite", + "harvestman", + "scorpion", + "black and gold garden spider", + "barn spider", + "garden spider", + "black widow", + "tarantula", + "wolf spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse", + "prairie chicken", + "peacock", + "quail", + "partridge", + "African grey", + "macaw", + "sulphur-crested cockatoo", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "drake", + "red-breasted merganser", + "goose", + "black swan", + "tusker", + "echidna", + "platypus", + "wallaby", + "koala", + "wombat", + "jellyfish", + "sea anemone", + "brain coral", + "flatworm", + "nematode", + "conch", + "snail", + "slug", + "sea slug", + "chiton", + "chambered nautilus", + "Dungeness crab", + "rock crab", + "fiddler crab", + "king crab", + "American lobster", + "spiny lobster", + "crayfish", + "hermit crab", + "isopod", + "white stork", + "black stork", + "spoonbill", + "flamingo", + "little blue heron", + "American egret", + "bittern", + "crane", + "limpkin", + "European gallinule", + "American coot", + "bustard", + "ruddy turnstone", + "red-backed sandpiper", + "redshank", + "dowitcher", + "oystercatcher", + "pelican", + "king penguin", + "albatross", + "grey whale", + "killer whale", + "dugong", + "sea lion", + "Chihuahua", + "Japanese spaniel", + "Maltese dog", + "Pekinese", + "Shih-Tzu", + "Blenheim spaniel", + "papillon", + "toy terrier", + "Rhodesian ridgeback", + "Afghan hound", + "basset", + "beagle", + "bloodhound", + "bluetick", + "black-and-tan coonhound", + "Walker hound", + "English foxhound", + "redbone", + "borzoi", + "Irish wolfhound", + "Italian greyhound", + "whippet", + "Ibizan hound", + "Norwegian elkhound", + "otterhound", + "Saluki", + "Scottish deerhound", + "Weimaraner", + "Staffordshire bullterrier", + "American Staffordshire terrier", + "Bedlington terrier", + "Border terrier", + "Kerry blue terrier", + "Irish terrier", + "Norfolk terrier", + "Norwich terrier", + "Yorkshire terrier", + "wire-haired fox terrier", + "Lakeland terrier", + "Sealyham terrier", + "Airedale", + "cairn", + "Australian terrier", + "Dandie Dinmont", + "Boston bull", + "miniature schnauzer", + "giant schnauzer", + "standard schnauzer", + "Scotch terrier", + "Tibetan terrier", + "silky terrier", + "soft-coated wheaten terrier", + "West Highland white terrier", + "Lhasa", + "flat-coated retriever", + "curly-coated retriever", + "golden retriever", + "Labrador retriever", + "Chesapeake Bay retriever", + "German short-haired pointer", + "vizsla", + "English setter", + "Irish setter", + "Gordon setter", + "Brittany spaniel", + "clumber", + "English springer", + "Welsh springer spaniel", + "cocker spaniel", + "Sussex spaniel", + "Irish water spaniel", + "kuvasz", + "schipperke", + "groenendael", + "malinois", + "briard", + "kelpie", + "komondor", + "Old English sheepdog", + "Shetland sheepdog", + "collie", + "Border collie", + "Bouvier des Flandres", + "Rottweiler", + "German shepherd", + "Doberman", + "miniature pinscher", + "Greater Swiss Mountain dog", + "Bernese mountain dog", + "Appenzeller", + "EntleBucher", + "boxer", + "bull mastiff", + "Tibetan mastiff", + "French bulldog", + "Great Dane", + "Saint Bernard", + "Eskimo dog", + "malamute", + "Siberian husky", + "dalmatian", + "affenpinscher", + "basenji", + "pug", + "Leonberg", + "Newfoundland", + "Great Pyrenees", + "Samoyed", + "Pomeranian", + "chow", + "keeshond", + "Brabancon griffon", + "Pembroke", + "Cardigan", + "toy poodle", + "miniature poodle", + "standard poodle", + "Mexican hairless", + "timber wolf", + "white wolf", + "red wolf", + "coyote", + "dingo", + "dhole", + "African hunting dog", + "hyena", + "red fox", + "kit fox", + "Arctic fox", + "grey fox", + "tabby", + "tiger cat", + "Persian cat", + "Siamese cat", + "Egyptian cat", + "cougar", + "lynx", + "leopard", + "snow leopard", + "jaguar", + "lion", + "tiger", + "cheetah", + "brown bear", + "American black bear", + "ice bear", + "sloth bear", + "mongoose", + "meerkat", + "tiger beetle", + "ladybug", + "ground beetle", + "long-horned beetle", + "leaf beetle", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant", + "grasshopper", + "cricket", + "walking stick", + "cockroach", + "mantis", + "cicada", + "leafhopper", + "lacewing", + "dragonfly", + "damselfly", + "admiral", + "ringlet", + "monarch", + "cabbage butterfly", + "sulphur butterfly", + "lycaenid", + "starfish", + "sea urchin", + "sea cucumber", + "wood rabbit", + "hare", + "Angora", + "hamster", + "porcupine", + "fox squirrel", + "marmot", + "beaver", + "guinea pig", + "sorrel", + "zebra", + "hog", + "wild boar", + "warthog", + "hippopotamus", + "ox", + "water buffalo", + "bison", + "ram", + "bighorn", + "ibex", + "hartebeest", + "impala", + "gazelle", + "Arabian camel", + "llama", + "weasel", + "mink", + "polecat", + "black-footed ferret", + "otter", + "skunk", + "badger", + "armadillo", + "three-toed sloth", + "orangutan", + "gorilla", + "chimpanzee", + "gibbon", + "siamang", + "guenon", + "patas", + "baboon", + "macaque", + "langur", + "colobus", + "proboscis monkey", + "marmoset", + "capuchin", + "howler monkey", + "titi", + "spider monkey", + "squirrel monkey", + "Madagascar cat", + "indri", + "Indian elephant", + "African elephant", + "lesser panda", + "giant panda", + "barracouta", + "eel", + "coho", + "rock beauty", + "anemone fish", + "sturgeon", + "gar", + "lionfish", + "puffer", + "abacus", + "abaya", + "academic gown", + "accordion", + "acoustic guitar", + "aircraft carrier", + "airliner", + "airship", + "altar", + "ambulance", + "amphibian", + "analog clock", + "apiary", + "apron", + "ashcan", + "assault rifle", + "backpack", + "bakery", + "balance beam", + "balloon", + "ballpoint", + "Band Aid", + "banjo", + "bannister", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel", + "barrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "bathing cap", + "bath towel", + "bathtub", + "beach wagon", + "beacon", + "beaker", + "bearskin", + "beer bottle", + "beer glass", + "bell cote", + "bib", + "bicycle-built-for-two", + "bikini", + "binder", + "binoculars", + "birdhouse", + "boathouse", + "bobsled", + "bolo tie", + "bonnet", + "bookcase", + "bookshop", + "bottlecap", + "bow", + "bow tie", + "brass", + "brassiere", + "breakwater", + "breastplate", + "broom", + "bucket", + "buckle", + "bulletproof vest", + "bullet train", + "butcher shop", + "cab", + "caldron", + "candle", + "cannon", + "canoe", + "can opener", + "cardigan", + "car mirror", + "carousel", + "carpenter's kit", + "carton", + "car wheel", + "cash machine", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello", + "cellular telephone", + "chain", + "chainlink fence", + "chain mail", + "chain saw", + "chest", + "chiffonier", + "chime", + "china cabinet", + "Christmas stocking", + "church", + "cinema", + "cleaver", + "cliff dwelling", + "cloak", + "clog", + "cocktail shaker", + "coffee mug", + "coffeepot", + "coil", + "combination lock", + "computer keyboard", + "confectionery", + "container ship", + "convertible", + "corkscrew", + "cornet", + "cowboy boot", + "cowboy hat", + "cradle", + "crane", + "crash helmet", + "crate", + "crib", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam", + "desk", + "desktop computer", + "dial telephone", + "diaper", + "digital clock", + "digital watch", + "dining table", + "dishrag", + "dishwasher", + "disk brake", + "dock", + "dogsled", + "dome", + "doormat", + "drilling platform", + "drum", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso maker", + "face powder", + "feather boa", + "file", + "fireboat", + "fire engine", + "fire screen", + "flagpole", + "flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster", + "freight car", + "French horn", + "frying pan", + "fur coat", + "garbage truck", + "gasmask", + "gas pump", + "goblet", + "go-kart", + "golf ball", + "golfcart", + "gondola", + "gong", + "gown", + "grand piano", + "greenhouse", + "grille", + "grocery store", + "guillotine", + "hair slide", + "hair spray", + "half track", + "hammer", + "hamper", + "hand blower", + "hand-held computer", + "handkerchief", + "hard disc", + "harmonica", + "harp", + "harvester", + "hatchet", + "holster", + "home theater", + "honeycomb", + "hook", + "hoopskirt", + "horizontal bar", + "horse cart", + "hourglass", + "iPod", + "iron", + "jack-o'-lantern", + "jean", + "jeep", + "jersey", + "jigsaw puzzle", + "jinrikisha", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat", + "ladle", + "lampshade", + "laptop", + "lawn mower", + "lens cap", + "letter opener", + "library", + "lifeboat", + "lighter", + "limousine", + "liner", + "lipstick", + "Loafer", + "lotion", + "loudspeaker", + "loupe", + "lumbermill", + "magnetic compass", + "mailbag", + "mailbox", + "maillot", + "maillot", + "manhole cover", + "maraca", + "marimba", + "mask", + "matchstick", + "maypole", + "maze", + "measuring cup", + "medicine chest", + "megalith", + "microphone", + "microwave", + "military uniform", + "milk can", + "minibus", + "miniskirt", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home", + "Model T", + "modem", + "monastery", + "monitor", + "moped", + "mortar", + "mortarboard", + "mosque", + "mosquito net", + "motor scooter", + "mountain bike", + "mountain tent", + "mouse", + "mousetrap", + "moving van", + "muzzle", + "nail", + "neck brace", + "necklace", + "nipple", + "notebook", + "obelisk", + "oboe", + "ocarina", + "odometer", + "oil filter", + "organ", + "oscilloscope", + "overskirt", + "oxcart", + "oxygen mask", + "packet", + "paddle", + "paddlewheel", + "padlock", + "paintbrush", + "pajama", + "palace", + "panpipe", + "paper towel", + "parachute", + "parallel bars", + "park bench", + "parking meter", + "passenger car", + "patio", + "pay-phone", + "pedestal", + "pencil box", + "pencil sharpener", + "perfume", + "Petri dish", + "photocopier", + "pick", + "pickelhaube", + "picket fence", + "pickup", + "pier", + "piggy bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate", + "pitcher", + "plane", + "planetarium", + "plastic bag", + "plate rack", + "plow", + "plunger", + "Polaroid camera", + "pole", + "police van", + "poncho", + "pool table", + "pop bottle", + "pot", + "potter's wheel", + "power drill", + "prayer rug", + "printer", + "prison", + "projectile", + "projector", + "puck", + "punching bag", + "purse", + "quill", + "quilt", + "racer", + "racket", + "radiator", + "radio", + "radio telescope", + "rain barrel", + "recreational vehicle", + "reel", + "reflex camera", + "refrigerator", + "remote control", + "restaurant", + "revolver", + "rifle", + "rocking chair", + "rotisserie", + "rubber eraser", + "rugby ball", + "rule", + "running shoe", + "safe", + "safety pin", + "saltshaker", + "sandal", + "sarong", + "sax", + "scabbard", + "scale", + "school bus", + "schooner", + "scoreboard", + "screen", + "screw", + "screwdriver", + "seat belt", + "sewing machine", + "shield", + "shoe shop", + "shoji", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "ski mask", + "sleeping bag", + "slide rule", + "sliding door", + "slot", + "snorkel", + "snowmobile", + "snowplow", + "soap dispenser", + "soccer ball", + "sock", + "solar dish", + "sombrero", + "soup bowl", + "space bar", + "space heater", + "space shuttle", + "spatula", + "speedboat", + "spider web", + "spindle", + "sports car", + "spotlight", + "stage", + "steam locomotive", + "steel arch bridge", + "steel drum", + "stethoscope", + "stole", + "stone wall", + "stopwatch", + "stove", + "strainer", + "streetcar", + "stretcher", + "studio couch", + "stupa", + "submarine", + "suit", + "sundial", + "sunglass", + "sunglasses", + "sunscreen", + "suspension bridge", + "swab", + "sweatshirt", + "swimming trunks", + "swing", + "switch", + "syringe", + "table lamp", + "tank", + "tape player", + "teapot", + "teddy", + "television", + "tennis ball", + "thatch", + "theater curtain", + "thimble", + "thresher", + "throne", + "tile roof", + "toaster", + "tobacco shop", + "toilet seat", + "torch", + "totem pole", + "tow truck", + "toyshop", + "tractor", + "trailer truck", + "tray", + "trench coat", + "tricycle", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus", + "trombone", + "tub", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle", + "upright", + "vacuum", + "vase", + "vault", + "velvet", + "vending machine", + "vestment", + "viaduct", + "violin", + "volleyball", + "waffle iron", + "wall clock", + "wallet", + "wardrobe", + "warplane", + "washbasin", + "washer", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "wing", + "wok", + "wooden spoon", + "wool", + "worm fence", + "wreck", + "yawl", + "yurt", + "web site", + "comic book", + "crossword puzzle", + "street sign", + "traffic light", + "book jacket", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot", + "trifle", + "ice cream", + "ice lolly", + "French loaf", + "bagel", + "pretzel", + "cheeseburger", + "hotdog", + "mashed potato", + "head cabbage", + "broccoli", + "cauliflower", + "zucchini", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber", + "artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple", + "banana", + "jackfruit", + "custard apple", + "pomegranate", + "hay", + "carbonara", + "chocolate sauce", + "dough", + "meat loaf", + "pizza", + "potpie", + "burrito", + "red wine", + "espresso", + "cup", + "eggnog", + "alp", + "bubble", + "cliff", + "coral reef", + "geyser", + "lakeside", + "promontory", + "sandbar", + "seashore", + "valley", + "volcano", + "ballplayer", + "groom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper", + "corn", + "acorn", + "hip", + "buckeye", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn", + "earthstar", + "hen-of-the-woods", + "bolete", + "ear", + "toilet tissue", +] diff --git a/torchvision/prototype/models/resnet.py b/torchvision/prototype/models/resnet.py new file mode 100644 index 00000000000..aaa02d5d407 --- /dev/null +++ b/torchvision/prototype/models/resnet.py @@ -0,0 +1,67 @@ +import warnings +from functools import partial +from typing import Any, List, Optional, Type, Union + +from ...models.resnet import BasicBlock, Bottleneck, ResNet +from ..transforms.presets import ImageNetEval +from ._api import Weights, WeightEntry +from ._meta import _IMAGENET_CATEGORIES + + +__all__ = ["ResNet", "ResNet50Weights", "resnet50"] + + +def _resnet( + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + weights: Optional[Weights], + progress: bool, + **kwargs: Any, +) -> ResNet: + if weights is not None: + kwargs["num_classes"] = len(weights.meta["categories"]) + + model = ResNet(block, layers, **kwargs) + + if weights is not None: + model.load_state_dict(weights.state_dict(progress=progress)) + + return model + + +_common_meta = { + "size": (224, 224), + "categories": _IMAGENET_CATEGORIES, +} + + +class ResNet50Weights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/resnet50-0676ba61.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification", + "acc@1": 76.130, + "acc@5": 92.862, + }, + ) + ImageNet1K_RefV2 = WeightEntry( + url="https://download.pytorch.org/models/resnet50-tmp.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "https://github.com/pytorch/vision/issues/3995", + "acc@1": 80.352, + "acc@5": 95.148, + }, + ) + + +def resnet50(weights: Optional[ResNet50Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + weights = ResNet50Weights.verify(weights) + + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py new file mode 100644 index 00000000000..061b9381e43 --- /dev/null +++ b/torchvision/prototype/transforms/__init__.py @@ -0,0 +1 @@ +from .presets import * diff --git a/torchvision/prototype/transforms/presets.py b/torchvision/prototype/transforms/presets.py new file mode 100644 index 00000000000..56d6fc60e02 --- /dev/null +++ b/torchvision/prototype/transforms/presets.py @@ -0,0 +1,44 @@ +from typing import Tuple + +import torch +from torch import Tensor, nn + +from ... import transforms as T +from ...transforms import functional as F + + +__all__ = ["ConvertImageDtype", "ImageNetEval"] + + +# Allows handling of both PIL and Tensor images +class ConvertImageDtype(nn.Module): + def __init__(self, dtype: torch.dtype) -> None: + super().__init__() + self.dtype = dtype + + def forward(self, img: Tensor) -> Tensor: + if not isinstance(img, Tensor): + img = F.pil_to_tensor(img) + return F.convert_image_dtype(img, self.dtype) + + +class ImageNetEval: + def __init__( + self, + crop_size: int, + resize_size: int = 256, + mean: Tuple[float, ...] = (0.485, 0.456, 0.406), + std: Tuple[float, ...] = (0.229, 0.224, 0.225), + interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, + ) -> None: + self.transforms = T.Compose( + [ + T.Resize(resize_size, interpolation=interpolation), + T.CenterCrop(crop_size), + ConvertImageDtype(dtype=torch.float), + T.Normalize(mean=mean, std=std), + ] + ) + + def __call__(self, img: Tensor) -> Tensor: + return self.transforms(img)