Adding RandAugment implementation #4348
@@ -7,7 +7,7 @@

 from . import functional as F, InterpolationMode

-__all__ = ["AutoAugmentPolicy", "AutoAugment"]
+__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment"]


 def _apply_op(img: Tensor, op_name: str, magnitude: float,
@@ -58,6 +58,7 @@ class AutoAugmentPolicy(Enum):
     SVHN = "svhn"


+# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
 class AutoAugment(torch.nn.Module):
     r"""AutoAugment data augmentation method based on
     `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
@@ -85,9 +86,9 @@ def __init__(
         self.policy = policy
         self.interpolation = interpolation
         self.fill = fill
-        self.transforms = self._get_transforms(policy)
+        self.policies = self._get_policies(policy)

-    def _get_transforms(
+    def _get_policies(
         self,
         policy: AutoAugmentPolicy
     ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
@@ -178,9 +179,9 @@ def _get_transforms(
         else:
             raise ValueError("The provided policy {} is not recognized.".format(policy))

-    def _get_magnitudes(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
+    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
         return {
-            # name: (magnitudes, signed)
+            # op_name: (magnitudes, signed)
             "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
             "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
             "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
@@ -224,11 +225,11 @@ def forward(self, img: Tensor) -> Tensor:
             elif fill is not None:
                 fill = [float(f) for f in fill]

-        transform_id, probs, signs = self.get_params(len(self.transforms))
+        transform_id, probs, signs = self.get_params(len(self.policies))

-        for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]):
+        for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
             if probs[i] <= p:
-                op_meta = self._get_magnitudes(10, F.get_image_size(img))
+                op_meta = self._augmentation_space(10, F.get_image_size(img))
                 magnitudes, signed = op_meta[op_name]
                 magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
                 if signed and signs[i] == 0:
@@ -239,3 +240,87 @@ def forward(self, img: Tensor) -> Tensor:

     def __repr__(self) -> str:
         return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)
+
+
+class RandAugment(torch.nn.Module):
I wonder if it makes sense to inherit from
Yes indeed. It's a direct copy-paste. The only reason I didn't make it static or inherit from AutoAugment is that I think we haven't nailed the API of the base class yet. I was thinking of keeping only the public parts visible and making changes once we add a couple of methods.
+    r"""RandAugment data augmentation method based on
+    `"RandAugment: Practical automated data augmentation with a reduced search space"
+    <https://arxiv.org/abs/1909.13719>`_.
+    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
+    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If img is PIL Image, it is expected to be in mode "L" or "RGB".
+
+    Args:
+        num_ops (int): Number of augmentation transformations to apply sequentially.
+        magnitude (int): Magnitude for all the transformations.
+        num_magnitude_bins (int): The number of different magnitude values.
+        interpolation (InterpolationMode): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
+            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
+        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
+            image. If given a number, the value is used for all bands respectively.
+    """
+
+    def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 30,
The num_magnitude_bins should be 31, as for TA (TrivialAugment), since 0 is also a bin and the maximal value in the paper is 30. That they tried level 31 is definitely weird and, I guess, a typo.
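The off-by-one raised in this comment can be checked with a small sketch. It assumes a hypothetical pure-Python `linspace` helper standing in for `torch.linspace`, and reuses the `Rotate` range (0 to 30 degrees) from this diff; `rotate_magnitude` is an illustrative name, not part of the PR.

```python
from typing import List


def linspace(start: float, end: float, steps: int) -> List[float]:
    """Hypothetical stand-in for torch.linspace: `steps` evenly spaced values, endpoints included."""
    if steps == 1:
        return [start]
    step = (end - start) / (steps - 1)
    return [start + i * step for i in range(steps)]


def rotate_magnitude(level: int, num_bins: int) -> float:
    """Look up the Rotate magnitude (in degrees) for a level, using the 0..30 range from the diff."""
    return linspace(0.0, 30.0, num_bins)[level]


if __name__ == "__main__":
    # With 31 bins, levels 0..30 are all valid and level 30 is the maximum.
    print(rotate_magnitude(30, 31))
    # With 30 bins, the highest valid index is 29, so level 30 raises IndexError.
    try:
        rotate_magnitude(30, 30)
    except IndexError as exc:
        print("level 30 is out of range with 30 bins:", exc)
```

Under these assumptions, a default of 30 bins cannot represent level 30, which is the point the reviewer makes.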
+                 interpolation: InterpolationMode = InterpolationMode.NEAREST,
+                 fill: Optional[List[float]] = None) -> None:
+        super().__init__()
+        self.num_ops = num_ops
+        self.magnitude = magnitude
+        self.num_magnitude_bins = num_magnitude_bins
+        self.interpolation = interpolation
+        self.fill = fill
+
+    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
+        return {
+            # op_name: (magnitudes, signed)
+            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
+            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
+            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
@SamuelGabriel I noticed that in your implementation you restrict the maximum value of Translate to 14.4. I couldn't find a reference to that in the RandAugment paper. Could you provide some background info?
I did not use this setting for my RA experiments actually. See for example here: https://github.com/automl/trivialaugment/blob/master/confs/wresnet28x10_svhncore_b128_maxlr.1_ra_fixed.yaml I used
I am actually not sure what the best strategy to follow is here. Any idea? The same problem arises for AutoAugment actually. Should we ask the authors, or focus only on 32x32 images, or only on ImageNet? For TA it is simpler: we use the same setting across datasets.
Not 100% sure either. Here I decided to use this approach because of their comment on
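For context on the thread above, the scaling in this diff ties the maximum translation to the image size through the factor 150/331. A sketch of that arithmetic, where `max_translate` is a hypothetical name used only for illustration:

```python
def max_translate(image_dim: int) -> float:
    """Largest translation in pixels under the 150/331 scaling used in the diff."""
    return 150.0 / 331.0 * image_dim


if __name__ == "__main__":
    # Compare a CIFAR/SVHN-sized side, a common ImageNet crop size, and the 331 reference size.
    for side in (32, 224, 331):
        print(side, round(max_translate(side), 2))
```

Under this scaling a 32-pixel side gives roughly 14.5 pixels and a 331-pixel side gives 150 pixels, so the bound depends on the dataset, which is the open question the reviewers discuss.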
+            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
+            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
+            "Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
+            "AutoContrast": (torch.tensor(0.0), False),
+            "Equalize": (torch.tensor(0.0), False),
+            "Invert": (torch.tensor(0.0), False),
Invert should not be here, but Identity should be. I believe you replicated the mistake I made in the TA implementation for Vision. Sorry for that. I'll fix it there, too. TA and RA use the same augmentation operations.
Ah, thanks! Indeed this was copied from you. Your implementation heavily inspired how the entire code was refactored here, so thanks a lot for the contribution.
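A sketch of the op list this comment asks for, assuming the only change is swapping `Invert` for `Identity`. The names `OPS_IN_DIFF`, `swap_invert_for_identity` and `OPS_SUGGESTED` are illustrative and not part of the PR, and the sketch assumes `_apply_op` would treat `Identity` as a no-op.

```python
from typing import List

# Op names as they appear in _augmentation_space() in this diff.
OPS_IN_DIFF: List[str] = [
    "ShearX", "ShearY", "TranslateX", "TranslateY", "Rotate",
    "Brightness", "Color", "Contrast", "Sharpness",
    "Posterize", "Solarize", "AutoContrast", "Equalize", "Invert",
]


def swap_invert_for_identity(ops: List[str]) -> List[str]:
    """Return a copy of `ops` with "Invert" replaced by "Identity" (hypothetical helper)."""
    return ["Identity" if name == "Invert" else name for name in ops]


# The op set suggested in the review: same length, Identity in place of Invert.
OPS_SUGGESTED: List[str] = swap_invert_for_identity(OPS_IN_DIFF)

if __name__ == "__main__":
    print(OPS_SUGGESTED)
```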
+        }
+
+    def forward(self, img: Tensor) -> Tensor:
+        """
+            img (PIL Image or Tensor): Image to be transformed.
+        Returns:
+            PIL Image or Tensor: Transformed image.
+        """
+        fill = self.fill
+        if isinstance(img, Tensor):
+            if isinstance(fill, (int, float)):
+                fill = [float(fill)] * F.get_image_num_channels(img)
+            elif fill is not None:
+                fill = [float(f) for f in fill]
Comment on lines +300 to +304
nit: might be worth putting this in a helper function, or push it directly to
Good call, I'll add this in a helper method for now. This can move to a base class once we nail the API.
The JIT is giving me headaches if I move it on ops. I'll add a TODO to remove duplicate code once we have a base class.
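The helper discussed in this thread could look roughly like the sketch below. It is a minimal pure-Python version under stated assumptions: `standardize_fill` is a hypothetical name, the channel count is passed in rather than read from the image, and the `isinstance(img, Tensor)` check from `forward()` is left to the caller. It is not the helper that was eventually written for the PR.

```python
from typing import List, Optional, Sequence, Union


def standardize_fill(
    fill: Optional[Union[int, float, Sequence[float]]],
    num_channels: int,
) -> Optional[List[float]]:
    """Hypothetical helper mirroring the fill handling in forward() for tensor inputs."""
    if isinstance(fill, (int, float)):
        # A single number is repeated once per channel.
        return [float(fill)] * num_channels
    if fill is not None:
        # A sequence is kept as-is, with every entry cast to float.
        return [float(f) for f in fill]
    # None means "no fill specified" and is passed through unchanged.
    return None


if __name__ == "__main__":
    print(standardize_fill(0, 3))
    print(standardize_fill([1, 2, 3], 3))
    print(standardize_fill(None, 3))
```

Sharing one such function between `AutoAugment` and `RandAugment` would address the duplication noted in the FIXME, subject to the JIT constraints the author mentions.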
+
+        for _ in range(self.num_ops):
+            op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
+            op_index = int(torch.randint(len(op_meta), (1,)).item())
+            op_name = list(op_meta.keys())[op_index]
+            magnitudes, signed = op_meta[op_name]
+            magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
+            if signed and torch.randint(2, (1,)):
+                magnitude *= -1.0
+            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
+
+        return img
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + '('
+        s += 'num_ops={num_ops}'
+        s += ', magnitude={magnitude}'
+        s += ', num_magnitude_bins={num_magnitude_bins}'
+        s += ', interpolation={interpolation}'
+        s += ', fill={fill}'
+        s += ')'
+        return s.format(**self.__dict__)
Even though this is a private method, I decided to use the terminology of the TrivialAugment paper, as I think it better describes what we get back (the combination of permitted ops and magnitudes) for the given augmentation.
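To illustrate what "augmentation space" means here, the sketch below mimics the sampling loop of `forward()` in pure Python. Everything in it is an assumption made for illustration: `sample_ops`, `Space` and `toy_space` are hypothetical names, Python's `random` module stands in for `torch.randint`, and an empty list marks ops that take no magnitude (the diff uses a 0-dimensional tensor for that).

```python
import random
from typing import Dict, List, Optional, Tuple

# op_name -> (magnitudes, signed); an empty list marks ops that take no magnitude.
Space = Dict[str, Tuple[List[float], bool]]


def sample_ops(space: Space, num_ops: int, magnitude: int,
               rng: Optional[random.Random] = None) -> List[Tuple[str, float]]:
    """Pick `num_ops` ops uniformly at random and resolve each magnitude at bin `magnitude`."""
    if rng is None:
        rng = random.Random()
    names = list(space.keys())
    picked: List[Tuple[str, float]] = []
    for _ in range(num_ops):
        op_name = names[rng.randrange(len(names))]
        magnitudes, signed = space[op_name]
        value = magnitudes[magnitude] if magnitudes else 0.0
        # Signed ops get their magnitude negated with probability 1/2.
        if signed and rng.randrange(2):
            value *= -1.0
        picked.append((op_name, value))
    return picked


# A toy three-op space with three bins each, not the one from the diff.
toy_space: Space = {
    "Rotate": ([0.0, 15.0, 30.0], True),
    "Solarize": ([256.0, 128.0, 0.0], False),
    "Equalize": ([], False),
}

if __name__ == "__main__":
    print(sample_ops(toy_space, num_ops=2, magnitude=1, rng=random.Random(0)))
```

The dictionary plays the role of the value returned by `_augmentation_space()`: it lists which ops are permitted and which magnitudes each one can take.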