Skip to content

Commit

Permalink
Added batched cutout instead of RandomErasing
Browse files Browse the repository at this point in the history
  • Loading branch information
ancestor-mithril committed Nov 30, 2024
1 parent 9274e6f commit ce832b9
Showing 1 changed file with 40 additions and 2 deletions.
42 changes: 40 additions & 2 deletions utils/transforms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
from abc import ABC, abstractmethod
from typing import Sequence, Callable
from typing import Sequence, Callable, Union

import torch
from torch import nn, Tensor
Expand All @@ -16,6 +17,42 @@ def forward(self, x: Tensor) -> Tensor:
return torch.where(flip_mask, x.flip(-1), x)


def make_random_square_masks(x: Tensor, cutout_size: int) -> Tensor:
batch_size, _, h, w = x.shape

# seed top-left corners of squares to cutout boxes from, in one dimension each
corner_y = torch.randint(0, h - cutout_size + 1, size=(batch_size,), device=x.device)
corner_x = torch.randint(0, w - cutout_size + 1, size=(batch_size,), device=x.device)

# measure distance, using the center as a reference point
corner_y_dists = torch.arange(h, device=x.device).view(1, 1, h, 1) - corner_y.view(-1, 1, 1, 1)
corner_x_dists = torch.arange(w, device=x.device).view(1, 1, 1, w) - corner_x.view(-1, 1, 1, 1)

mask_y = (corner_y_dists >= 0) * (corner_y_dists < cutout_size)
mask_x = (corner_x_dists >= 0) * (corner_x_dists < cutout_size)

final_mask = mask_y * mask_x

return final_mask


class BatchCutout(nn.Module):
# Inspired from https://github.com/KellerJordan/cifar10-airbench
# https://arxiv.org/abs/2404.00498: 94% on CIFAR-10 in 3.29 Seconds on a Single GPU
def __init__(self, p: float = 0.5, scale: Sequence[float] = (0.0, 0.33), value: float = 0.0):
# TODO: Add ratio
super().__init__()
self.p = p
self.scale = scale
self.value = value

def forward(self, x: Tensor) -> Tensor:
size = x.shape[-1]
cutout_size = torch.randint(math.floor(self.scale[0] * size), math.ceil(self.scale[1] * size), size=(1,)).item()
cutout_mask = make_random_square_masks(x, cutout_size)
return x.masked_fill_(cutout_mask, self.value)


class AlternativeHorizontalFlip(nn.Module):
# Inspired from https://github.com/KellerJordan/cifar10-airbench
# https://arxiv.org/abs/2404.00498: 94% on CIFAR-10 in 3.29 Seconds on a Single GPU
Expand Down Expand Up @@ -153,7 +190,8 @@ def test_cached(self):

def create_cutout(self):
fill_value = 0 if self.args.fill is None else self.args.fill
return v2.RandomErasing(scale=(0.05, 0.15), value=fill_value, inplace=True)
# return v2.RandomErasing(scale=(0.05, 0.15), value=fill_value, inplace=True)
return torch.jit.script(BatchCutout(value=fill_value))

def batch_transforms_cpu(self):
transforms = [
Expand Down

0 comments on commit ce832b9

Please sign in to comment.