Adding RandAugment implementation (#4348)
* Adding randaugment implementation

* Refactoring.

* Adding num_magnitude_bins.

* Adding FIXME.
Authored by datumbox on Sep 2, 2021 (commit 5a81554, 1 parent f52ddb0).
Showing 4 changed files with 128 additions and 13 deletions.
references/classification/presets.py (7 changes: 5 additions & 2 deletions)
@@ -9,8 +9,11 @@ def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.2
         if hflip_prob > 0:
             trans.append(transforms.RandomHorizontalFlip(hflip_prob))
         if auto_augment_policy is not None:
-            aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
-            trans.append(autoaugment.AutoAugment(policy=aa_policy))
+            if auto_augment_policy == "ra":
+                trans.append(autoaugment.RandAugment())
+            else:
+                aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
+                trans.append(autoaugment.AutoAugment(policy=aa_policy))
         trans.extend([
             transforms.ToTensor(),
             transforms.Normalize(mean=mean, std=std),
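
For context, a stand-alone sketch of the transform pipeline this branch assembles when the policy string is "ra". The helper name build_train_transform, the default crop size, and the ImageNet normalization constants are illustrative assumptions, not part of the commit:

    from torchvision.transforms import autoaugment, transforms

    def build_train_transform(crop_size=224, auto_augment_policy="ra", hflip_prob=0.5):
        # Same branching as the updated preset: "ra" selects RandAugment,
        # any other string is treated as an AutoAugment policy name.
        trans = [transforms.RandomResizedCrop(crop_size)]
        if hflip_prob > 0:
            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
        if auto_augment_policy is not None:
            if auto_augment_policy == "ra":
                trans.append(autoaugment.RandAugment())
            else:
                aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
                trans.append(autoaugment.AutoAugment(policy=aa_policy))
        trans.extend([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])
        return transforms.Compose(trans)
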
test/test_transforms.py (12 changes: 12 additions & 0 deletions)
@@ -1490,6 +1490,18 @@ def test_autoaugment(policy, fill):
     transform.__repr__()
 
 
+@pytest.mark.parametrize('num_ops', [1, 2, 3])
+@pytest.mark.parametrize('magnitude', [7, 9, 11])
+@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
+def test_randaugment(num_ops, magnitude, fill):
+    random.seed(42)
+    img = Image.open(GRACE_HOPPER)
+    transform = transforms.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
+    for _ in range(100):
+        img = transform(img)
+    transform.__repr__()
+
+
 def test_random_crop():
     height = random.randint(10, 32) * 2
     width = random.randint(10, 32) * 2
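
The new test exercises the PIL code path. Outside pytest, the equivalent usage looks like the following sketch; the image path is a placeholder and the fill tuple is just one of the parametrized values:

    from PIL import Image
    import torchvision.transforms as transforms

    # Placeholder input path; any "L" or "RGB" image works for this sketch.
    img = Image.open("example.jpg")
    augmenter = transforms.RandAugment(num_ops=2, magnitude=9, fill=(128, 128, 128))
    out = augmenter(img)  # applies num_ops randomly chosen ops at the given magnitude
    print(augmenter)      # __repr__ reports num_ops, magnitude, num_magnitude_bins, interpolation, fill
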
test/test_transforms_tensor.py (21 changes: 18 additions & 3 deletions)
@@ -525,16 +525,31 @@ def test_autoaugment(device, policy, fill):
     tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
     batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
 
-    s_transform = None
     transform = T.AutoAugment(policy=policy, fill=fill)
     s_transform = torch.jit.script(transform)
     for _ in range(25):
         _test_transform_vs_scripted(transform, s_transform, tensor)
         _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
 
-def test_autoaugment_save(tmpdir):
-    transform = T.AutoAugment()
+@pytest.mark.parametrize('device', cpu_and_gpu())
+@pytest.mark.parametrize('num_ops', [1, 2, 3])
+@pytest.mark.parametrize('magnitude', [7, 9, 11])
+@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1])
+def test_randaugment(device, num_ops, magnitude, fill):
+    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
+    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
+
+    transform = T.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
+    s_transform = torch.jit.script(transform)
+    for _ in range(25):
+        _test_transform_vs_scripted(transform, s_transform, tensor)
+        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
+
+
+@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment])
+def test_autoaugment_save(augmentation, tmpdir):
+    transform = augmentation()
     s_transform = torch.jit.script(transform)
     s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))
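
In compact form, what the tensor tests check is that RandAugment accepts uint8 tensors (including batches) and survives torch.jit.script. A minimal sketch, assuming a CPU-only run; the test helpers additionally re-seed the RNG before each call so eager and scripted outputs can be compared:

    import torch
    import torchvision.transforms as T

    transform = T.RandAugment(num_ops=2, magnitude=9)
    scripted = torch.jit.script(transform)  # the module is written to be scriptable

    img = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8)
    batch = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8)

    out_img = transform(img)     # eager call on a single CHW image
    out_batch = scripted(batch)  # scripted call on an NCHW batch
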
torchvision/transforms/autoaugment.py (101 changes: 93 additions & 8 deletions)
@@ -7,7 +7,7 @@
 
 from . import functional as F, InterpolationMode
 
-__all__ = ["AutoAugmentPolicy", "AutoAugment"]
+__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment"]
 
 
 def _apply_op(img: Tensor, op_name: str, magnitude: float,
@@ -58,6 +58,7 @@ class AutoAugmentPolicy(Enum):
     SVHN = "svhn"
 
 
+# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
 class AutoAugment(torch.nn.Module):
     r"""AutoAugment data augmentation method based on
     `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
@@ -85,9 +86,9 @@ def __init__(
         self.policy = policy
         self.interpolation = interpolation
         self.fill = fill
-        self.transforms = self._get_transforms(policy)
+        self.policies = self._get_policies(policy)
 
-    def _get_transforms(
+    def _get_policies(
         self,
         policy: AutoAugmentPolicy
     ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
@@ -178,9 +179,9 @@ def _get_transforms(
         else:
             raise ValueError("The provided policy {} is not recognized.".format(policy))
 
-    def _get_magnitudes(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
+    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
         return {
-            # name: (magnitudes, signed)
+            # op_name: (magnitudes, signed)
             "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
             "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
             "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
@@ -224,11 +225,11 @@ def forward(self, img: Tensor) -> Tensor:
             elif fill is not None:
                 fill = [float(f) for f in fill]
 
-        transform_id, probs, signs = self.get_params(len(self.transforms))
+        transform_id, probs, signs = self.get_params(len(self.policies))
 
-        for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]):
+        for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
             if probs[i] <= p:
-                op_meta = self._get_magnitudes(10, F.get_image_size(img))
+                op_meta = self._augmentation_space(10, F.get_image_size(img))
                 magnitudes, signed = op_meta[op_name]
                 magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
                 if signed and signs[i] == 0:
@@ -239,3 +240,87 @@ def forward(self, img: Tensor) -> Tensor:
 
     def __repr__(self) -> str:
         return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)
+
+
+class RandAugment(torch.nn.Module):
+    r"""RandAugment data augmentation method based on
+    `"RandAugment: Practical automated data augmentation with a reduced search space"
+    <https://arxiv.org/abs/1909.13719>`_.
+    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
+    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If img is PIL Image, it is expected to be in mode "L" or "RGB".
+
+    Args:
+        num_ops (int): Number of augmentation transformations to apply sequentially.
+        magnitude (int): Magnitude for all the transformations.
+        num_magnitude_bins (int): The number of different magnitude values.
+        interpolation (InterpolationMode): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
+            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
+        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
+            image. If given a number, the value is used for all bands respectively.
+    """
+
+    def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 30,
+                 interpolation: InterpolationMode = InterpolationMode.NEAREST,
+                 fill: Optional[List[float]] = None) -> None:
+        super().__init__()
+        self.num_ops = num_ops
+        self.magnitude = magnitude
+        self.num_magnitude_bins = num_magnitude_bins
+        self.interpolation = interpolation
+        self.fill = fill
+
+    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
+        return {
+            # op_name: (magnitudes, signed)
+            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
+            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
+            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
+            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
+            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
+            "Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
+            "AutoContrast": (torch.tensor(0.0), False),
+            "Equalize": (torch.tensor(0.0), False),
+            "Invert": (torch.tensor(0.0), False),
+        }
+
+    def forward(self, img: Tensor) -> Tensor:
+        """
+            img (PIL Image or Tensor): Image to be transformed.
+
+        Returns:
+            PIL Image or Tensor: Transformed image.
+        """
+        fill = self.fill
+        if isinstance(img, Tensor):
+            if isinstance(fill, (int, float)):
+                fill = [float(fill)] * F.get_image_num_channels(img)
+            elif fill is not None:
+                fill = [float(f) for f in fill]
+
+        for _ in range(self.num_ops):
+            op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
+            op_index = int(torch.randint(len(op_meta), (1,)).item())
+            op_name = list(op_meta.keys())[op_index]
+            magnitudes, signed = op_meta[op_name]
+            magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
+            if signed and torch.randint(2, (1,)):
+                magnitude *= -1.0
+            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
+
+        return img
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + '('
+        s += 'num_ops={num_ops}'
+        s += ', magnitude={magnitude}'
+        s += ', num_magnitude_bins={num_magnitude_bins}'
+        s += ', interpolation={interpolation}'
+        s += ', fill={fill}'
+        s += ')'
+        return s.format(**self.__dict__)
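
To make the magnitude semantics concrete: self.magnitude indexes into the per-op bin tables returned by _augmentation_space, so with the defaults (30 bins, magnitude 9) the linearly spaced ops land roughly 9/29 of the way along their range, while signed ops additionally flip the sign at random on each call. A small sketch that peeks at the private helper purely for illustration:

    import torch
    from torchvision.transforms.autoaugment import RandAugment

    ra = RandAugment()  # defaults: num_ops=2, magnitude=9, num_magnitude_bins=30
    # _augmentation_space is private; it is called here only to inspect the bins.
    space = ra._augmentation_space(ra.num_magnitude_bins, [224, 224])

    rotate_bins, rotate_signed = space["Rotate"]  # 30 values spanning 0..30 degrees
    print(float(rotate_bins[ra.magnitude]))       # ~9.31, the rotation angle used at magnitude 9
    print(rotate_signed)                          # True: the sign is chosen at random per call
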
