diff --git a/references/classification/presets.py b/references/classification/presets.py index ce5a6fe414f..c289c3b1c8b 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -9,8 +9,11 @@ def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.2 if hflip_prob > 0: trans.append(transforms.RandomHorizontalFlip(hflip_prob)) if auto_augment_policy is not None: - aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) - trans.append(autoaugment.AutoAugment(policy=aa_policy)) + if auto_augment_policy == "ra": + trans.append(autoaugment.RandAugment()) + else: + aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) + trans.append(autoaugment.AutoAugment(policy=aa_policy)) trans.extend([ transforms.ToTensor(), transforms.Normalize(mean=mean, std=std), diff --git a/test/test_transforms.py b/test/test_transforms.py index c5cc80ef87e..ca11bf664c1 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1490,6 +1490,18 @@ def test_autoaugment(policy, fill): transform.__repr__() +@pytest.mark.parametrize('num_ops', [1, 2, 3]) +@pytest.mark.parametrize('magnitude', [7, 9, 11]) +@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)]) +def test_randaugment(num_ops, magnitude, fill): + random.seed(42) + img = Image.open(GRACE_HOPPER) + transform = transforms.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill) + for _ in range(100): + img = transform(img) + transform.__repr__() + + def test_random_crop(): height = random.randint(10, 32) * 2 width = random.randint(10, 32) * 2 diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 5081626fec4..c0669987213 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -525,7 +525,6 @@ def test_autoaugment(device, policy, fill): tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) - s_transform = None transform = T.AutoAugment(policy=policy, fill=fill) s_transform = torch.jit.script(transform) for _ in range(25): @@ -533,8 +532,24 @@ def test_autoaugment(device, policy, fill): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) -def test_autoaugment_save(tmpdir): - transform = T.AutoAugment() +@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize('num_ops', [1, 2, 3]) +@pytest.mark.parametrize('magnitude', [7, 9, 11]) +@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) +def test_randaugment(device, num_ops, magnitude, fill): + tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) + batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) + + transform = T.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill) + s_transform = torch.jit.script(transform) + for _ in range(25): + _test_transform_vs_scripted(transform, s_transform, tensor) + _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) + + +@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment]) +def test_autoaugment_save(augmentation, tmpdir): + transform = augmentation() s_transform = torch.jit.script(transform) s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt")) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 9b2e1ac212e..c8b6a543722 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -7,7 +7,7 @@ from . import functional as F, InterpolationMode -__all__ = ["AutoAugmentPolicy", "AutoAugment"] +__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment"] def _apply_op(img: Tensor, op_name: str, magnitude: float, @@ -58,6 +58,7 @@ class AutoAugmentPolicy(Enum): SVHN = "svhn" +# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class class AutoAugment(torch.nn.Module): r"""AutoAugment data augmentation method based on `"AutoAugment: Learning Augmentation Strategies from Data" `_. @@ -85,9 +86,9 @@ def __init__( self.policy = policy self.interpolation = interpolation self.fill = fill - self.transforms = self._get_transforms(policy) + self.policies = self._get_policies(policy) - def _get_transforms( + def _get_policies( self, policy: AutoAugmentPolicy ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: @@ -178,9 +179,9 @@ def _get_transforms( else: raise ValueError("The provided policy {} is not recognized.".format(policy)) - def _get_magnitudes(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: return { - # name: (magnitudes, signed) + # op_name: (magnitudes, signed) "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), @@ -224,11 +225,11 @@ def forward(self, img: Tensor) -> Tensor: elif fill is not None: fill = [float(f) for f in fill] - transform_id, probs, signs = self.get_params(len(self.transforms)) + transform_id, probs, signs = self.get_params(len(self.policies)) - for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]): + for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]): if probs[i] <= p: - op_meta = self._get_magnitudes(10, F.get_image_size(img)) + op_meta = self._augmentation_space(10, F.get_image_size(img)) magnitudes, signed = op_meta[op_name] magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0 if signed and signs[i] == 0: @@ -239,3 +240,87 @@ def forward(self, img: Tensor) -> Tensor: def __repr__(self) -> str: return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill) + + +class RandAugment(torch.nn.Module): + r"""RandAugment data augmentation method based on + `"RandAugment: Practical automated data augmentation with a reduced search space" + `. + If the image is torch Tensor, it should be of type torch.uint8, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + num_ops (int): Number of augmentation transformations to apply sequentially. + magnitude (int): Magnitude for all the transformations. + num_magnitude_bins (int): The number of different magnitude values. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + """ + + def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 30, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None) -> None: + super().__init__() + self.num_ops = num_ops + self.magnitude = magnitude + self.num_magnitude_bins = num_magnitude_bins + self.interpolation = interpolation + self.fill = fill + + def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + return { + # op_name: (magnitudes, signed) + "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), + "Color": (torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), + "Solarize": (torch.linspace(256.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(0.0), False), + "Equalize": (torch.tensor(0.0), False), + "Invert": (torch.tensor(0.0), False), + } + + def forward(self, img: Tensor) -> Tensor: + """ + img (PIL Image or Tensor): Image to be transformed. + Returns: + PIL Image or Tensor: Transformed image. + """ + fill = self.fill + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * F.get_image_num_channels(img) + elif fill is not None: + fill = [float(f) for f in fill] + + for _ in range(self.num_ops): + op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img)) + op_index = int(torch.randint(len(op_meta), (1,)).item()) + op_name = list(op_meta.keys())[op_index] + magnitudes, signed = op_meta[op_name] + magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0 + if signed and torch.randint(2, (1,)): + magnitude *= -1.0 + img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + + return img + + def __repr__(self) -> str: + s = self.__class__.__name__ + '(' + s += 'num_ops={num_ops}' + s += ', magnitude={magnitude}' + s += ', num_magnitude_bins={num_magnitude_bins}' + s += ', interpolation={interpolation}' + s += ', fill={fill}' + s += ')' + return s.format(**self.__dict__)