Adding RandAugment implementation #4348

Merged: 6 commits, Sep 2, 2021
7 changes: 5 additions & 2 deletions references/classification/presets.py
@@ -9,8 +9,11 @@ def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.2
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy))
if auto_augment_policy == "ra":
trans.append(autoaugment.RandAugment())
else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy))
trans.extend([
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
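For context: with this change, passing the string "ra" as the auto-augment policy selects RandAugment, while any other non-None value is still interpreted as an AutoAugmentPolicy. A rough sketch of the training pipeline the preset ends up building for the "ra" case (crop size, flip probability and normalization values below are illustrative placeholders, not taken from this diff):

from torchvision import transforms
from torchvision.transforms import autoaugment

# Sketch of the pipeline assembled by the preset when auto_augment_policy == "ra".
# All numeric parameters here are placeholder values.
trans = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(0.5),
    autoaugment.RandAugment(),          # chosen by the new "ra" branch above
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
]
pipeline = transforms.Compose(trans)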
12 changes: 12 additions & 0 deletions test/test_transforms.py
@@ -1490,6 +1490,18 @@ def test_autoaugment(policy, fill):
transform.__repr__()


@pytest.mark.parametrize('num_ops', [1, 2, 3])
@pytest.mark.parametrize('magnitude', [7, 9, 11])
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
def test_randaugment(num_ops, magnitude, fill):
random.seed(42)
img = Image.open(GRACE_HOPPER)
transform = transforms.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
for _ in range(100):
img = transform(img)
transform.__repr__()


def test_random_crop():
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
21 changes: 18 additions & 3 deletions test/test_transforms_tensor.py
@@ -525,16 +525,31 @@ def test_autoaugment(device, policy, fill):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

s_transform = None
transform = T.AutoAugment(policy=policy, fill=fill)
s_transform = torch.jit.script(transform)
for _ in range(25):
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


def test_autoaugment_save(tmpdir):
transform = T.AutoAugment()
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('num_ops', [1, 2, 3])
@pytest.mark.parametrize('magnitude', [7, 9, 11])
@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1])
def test_randaugment(device, num_ops, magnitude, fill):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

transform = T.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
s_transform = torch.jit.script(transform)
for _ in range(25):
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment])
def test_autoaugment_save(augmentation, tmpdir):
transform = augmentation()
s_transform = torch.jit.script(transform)
s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))

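The scripted-vs-eager checks above go through the suite's existing helpers; conceptually they reduce to seeding the RNG identically before each call and comparing outputs, roughly as in this sketch (not the actual helper):

import torch
import torchvision.transforms as T

transform = T.RandAugment()
s_transform = torch.jit.script(transform)       # the transform must be scriptable

tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8)

torch.manual_seed(12)
out_eager = transform(tensor)
torch.manual_seed(12)
out_scripted = s_transform(tensor)

torch.testing.assert_close(out_eager, out_scripted, rtol=0, atol=0)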
101 changes: 93 additions & 8 deletions torchvision/transforms/autoaugment.py
@@ -7,7 +7,7 @@

from . import functional as F, InterpolationMode

__all__ = ["AutoAugmentPolicy", "AutoAugment"]
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment"]


def _apply_op(img: Tensor, op_name: str, magnitude: float,
@@ -58,6 +58,7 @@ class AutoAugmentPolicy(Enum):
SVHN = "svhn"


# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
class AutoAugment(torch.nn.Module):
r"""AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
@@ -85,9 +86,9 @@ def __init__(
self.policy = policy
self.interpolation = interpolation
self.fill = fill
self.transforms = self._get_transforms(policy)
self.policies = self._get_policies(policy)

def _get_transforms(
def _get_policies(
self,
policy: AutoAugmentPolicy
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
@@ -178,9 +179,9 @@ def _get_transforms(
else:
raise ValueError("The provided policy {} is not recognized.".format(policy))

def _get_magnitudes(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
@datumbox (Contributor, Author) commented on Sep 1, 2021:

Even though this is a private method, I decided to use the terminology of the TrivialAugment paper, as I think it describes better what we get back (the combination of permitted ops and magnitudes) for the given augmentation.

return {
# name: (magnitudes, signed)
# op_name: (magnitudes, signed)
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
@@ -224,11 +225,11 @@ def forward(self, img: Tensor) -> Tensor:
elif fill is not None:
fill = [float(f) for f in fill]

transform_id, probs, signs = self.get_params(len(self.transforms))
transform_id, probs, signs = self.get_params(len(self.policies))

for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]):
for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
if probs[i] <= p:
op_meta = self._get_magnitudes(10, F.get_image_size(img))
op_meta = self._augmentation_space(10, F.get_image_size(img))
magnitudes, signed = op_meta[op_name]
magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
if signed and signs[i] == 0:
@@ -239,3 +240,87 @@ def forward(self, img: Tensor) -> Tensor:

def __repr__(self) -> str:
return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)


class RandAugment(torch.nn.Module):
[Member] I wonder if it makes sense to inherit from AutoAugment and override only the _augmentation_space method?

@datumbox (Contributor, Author): Yes indeed, it's a direct copy-paste. The only reason I didn't make it static or inherit from AutoAugment is that I think we haven't nailed the API of the base class yet. I was thinking of keeping only the public parts visible and making changes once we add a couple more methods.

r"""RandAugment data augmentation method based on
`"RandAugment: Practical automated data augmentation with a reduced search space"
<https://arxiv.org/abs/1909.13719>`_.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".

Args:
num_ops (int): Number of augmentation transformations to apply sequentially.
magnitude (int): Magnitude for all the transformations.
num_magnitude_bins (int): The number of different magnitude values.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""

def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 30,
@datumbox (Contributor, Author):
  • N=2, M=9 are the best ImageNet values for ResNet-50 (see section A.2.3 of the paper).
  • num_magnitude_bins=30 because the majority of the experiments in the paper used this value. Oddly, section A.2.3 mentions trying level 31 for EfficientNet-B7.

[Contributor] num_magnitude_bins should be 31, like for TA, as 0 is also a bin and in the paper the maximal value is 30. That they tried level 31 is definitely weird and, I guess, a typo.

interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None) -> None:
super().__init__()
self.num_ops = num_ops
self.magnitude = magnitude
self.num_magnitude_bins = num_magnitude_bins
self.interpolation = interpolation
self.fill = fill

def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
return {
# op_name: (magnitudes, signed)
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
@datumbox (Contributor, Author) commented on Sep 1, 2021:

@SamuelGabriel I noticed that in your implementation you restrict the maximum value of Translate to 14.4. I couldn't find a reference for that in the RandAugment paper; could you provide some background?

[Contributor] I did not actually use this setting for my RA experiments. See for example: https://github.com/automl/trivialaugment/blob/master/confs/wresnet28x10_svhncore_b128_maxlr.1_ra_fixed.yaml

I used fixed_standard, which is the search space described in the RA paper, but slightly different from the one in the AutoAugment implementation (https://github.com/tensorflow/models/tree/fd34f711f319d8c6fe85110d9df6e1784cc5a6ca/research/autoaugment) that RA follows with its augmentation space and on which the significant parts of this AutoAugment implementation are based; that is why I call it "fixed". The setting you mention is for their ImageNet experiment; since my re-implementations of AA/RA were on 32x32 images (CIFAR/SVHN), I followed the implementation above, where translation is set to 10: https://github.com/tensorflow/models/blob/fd34f711f319d8c6fe85110d9df6e1784cc5a6ca/research/autoaugment/augmentation_transforms.py#L319 They do not use the same augmentation space across datasets...

[Contributor] I am actually not sure what the best strategy to follow here is. Any ideas? The same problem arises for AutoAugment. Should we ask the authors, focus only on 32x32 images, or only on ImageNet? For TA it is simpler: we use the same setting across datasets.

@datumbox (Contributor, Author): Not 100% sure either. Here I decided to use this approach because of their comment on Table 6 of the AA paper. It also means that for a 32x32 image, as in the case of CIFAR, the Translate max value would be about 14.5, which is similar but not equal to yours, hence my interest in how you derived it. Not sure whether this sub-pixel difference matters here (see the quick check after the dictionary below).

"TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
"Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
"Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
"Color": (torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
"Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
"Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
"AutoContrast": (torch.tensor(0.0), False),
"Equalize": (torch.tensor(0.0), False),
"Invert": (torch.tensor(0.0), False),
[Contributor] Invert should not be here, but Identity should be. I believe you replicated the mistake I made in the TA implementation for Vision; sorry for that. I'll fix it there, too. TA and RA use the same augmentation operations. (See the snippet after the dictionary below.)

@datumbox (Contributor, Author): Ah, thanks! Indeed this was copied from you. Your implementation heavily inspired how the entire code was refactored here, so thanks a lot for the contribution.

}
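Two small illustrations of the threads above; neither is part of this diff:

import torch

# 1) The Translate bound under discussion: the AA paper caps ImageNet translation
#    at 150 px on 331x331 inputs; the same ratio applied to a 32x32 image gives
#    roughly 14.5 px, close to (but not equal to) the 14.4 mentioned above.
ratio = 150.0 / 331.0
print(round(ratio * 331, 1))  # 150.0
print(round(ratio * 32, 1))   # 14.5

# 2) The reviewer's suggested fix for the dictionary above: a no-op "Identity"
#    entry in place of "Invert" (hypothetical here; not what this commit contains).
identity_entry = {"Identity": (torch.tensor(0.0), False)}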

def forward(self, img: Tensor) -> Tensor:
"""
img (PIL Image or Tensor): Image to be transformed.
Returns:
PIL Image or Tensor: Transformed image.
"""
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F.get_image_num_channels(img)
elif fill is not None:
fill = [float(f) for f in fill]
Comment on lines +300 to +304

[Member] nit: might be worth putting this in a helper function, or maybe pushing it directly into _apply_op?

@datumbox (Contributor, Author): Good call, I'll add this in a helper method for now. It can move to a base class once we nail the API.

@datumbox (Contributor, Author): The JIT is giving me headaches if I move it into the ops. I'll add a TODO to remove the duplicated code once we have a base class. (A sketch of such a helper is shown after the forward method below.)


for _ in range(self.num_ops):
op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
op_index = int(torch.randint(len(op_meta), (1,)).item())
op_name = list(op_meta.keys())[op_index]
magnitudes, signed = op_meta[op_name]
magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
if signed and torch.randint(2, (1,)):
magnitude *= -1.0
img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

return img
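A rough sketch of the helper discussed in the thread above (a hypothetical refactor, not code from this PR); Tensor and F here are the module's existing imports:

# Hypothetical helper mirroring the fill standardization duplicated in
# AutoAugment.forward and RandAugment.forward. As noted above, TorchScript
# made it hard to push this into _apply_op, hence the plan for a base class.
def _standardize_fill(fill, img):
    if isinstance(img, Tensor):
        if isinstance(fill, (int, float)):
            fill = [float(fill)] * F.get_image_num_channels(img)
        elif fill is not None:
            fill = [float(f) for f in fill]
    return fill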

def __repr__(self) -> str:
s = self.__class__.__name__ + '('
s += 'num_ops={num_ops}'
s += ', magnitude={magnitude}'
s += ', num_magnitude_bins={num_magnitude_bins}'
s += ', interpolation={interpolation}'
s += ', fill={fill}'
s += ')'
return s.format(**self.__dict__)
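Finally, a minimal usage sketch of the new transform with the paper defaults (illustrative; any "L"/"RGB" PIL image or a uint8 tensor works as input):

import torch
from PIL import Image
from torchvision import transforms as T

augment = T.RandAugment(num_ops=2, magnitude=9)   # paper defaults for ImageNet/ResNet-50

pil_img = Image.new("RGB", (224, 224))            # placeholder image for the sketch
tensor_img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)

out_pil = augment(pil_img)        # returns a PIL Image
out_tensor = augment(tensor_img)  # returns a uint8 Tensor
print(augment)                    # RandAugment(num_ops=2, magnitude=9, ...)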