Doing batch hflip
ancestor-mithril committed Nov 29, 2024
1 parent 9f97b0c commit 1de605c
Showing 3 changed files with 42 additions and 12 deletions.
4 changes: 3 additions & 1 deletion utils/dataset.py
@@ -12,13 +12,14 @@ def identity(x):


class CachedDataset(Dataset):
def __init__(self, dataset, transforms=None, num_classes=None, cache=True):
def __init__(self, dataset, transforms=None, num_classes=None, cache=True, batch_transforms=None):
if cache:
self.data = tuple([x for x in dataset])
else:
self.data = dataset
self.transforms = transforms
self.num_classes = num_classes
self.batch_transforms = batch_transforms

def __len__(self):
return len(self.data)
@@ -75,6 +76,7 @@ def init_dataset(args):
transforms=transforms.train_runtime(),
num_classes=num_classes,
cache=cache_train_dataset,
batch_transforms=transforms.batch_transforms(),
)

test_dataset = dataset_fn(train=False, transform=transforms.test_cached())
20 changes: 13 additions & 7 deletions utils/trainer.py
@@ -5,7 +5,7 @@

import torch
from timed_decorator.simple_timed import timed
from torch import GradScaler
from torch import GradScaler, Tensor
from torch.backends import cudnn
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
@@ -168,9 +168,9 @@ def train(self):
total_loss = 0.0

for inputs, targets in self.train_loader:
inputs, targets = inputs.to(self.device, non_blocking=True), targets.to(
self.device, non_blocking=True
)
inputs = inputs.to(self.device, non_blocking=True)
targets = targets.to(self.device, non_blocking=True)
inputs = self.maybe_batch_transforms(inputs)
with torch.autocast(self.device.type, enabled=self.args.half):
outputs = self.model(inputs)
loss = self.criterion(outputs, targets)
@@ -195,10 +195,11 @@ def val(self):
total_loss = 0.0

for inputs, targets in self.test_loader:
inputs, targets = inputs.to(self.device, non_blocking=True), targets.to(
self.device, non_blocking=True
)
inputs = inputs.to(self.device, non_blocking=True)
targets = targets.to(self.device, non_blocking=True)

with torch.autocast(self.device.type, enabled=self.args.half):
# TODO: put TTA in batch
outputs = self.model(inputs)
loss = self.criterion(outputs, targets)
if self.args.tta:
@@ -282,3 +283,8 @@ def maybe_clip(self):
if self.args.clip_value is not None:
self.scaler.unscale_(self.optimizer)
clip_grad_norm_(self.model.parameters(), self.args.clip_value)

def maybe_batch_transforms(self, x: Tensor) -> Tensor:
if self.train_dataset.batch_transforms is not None:
return self.train_dataset.batch_transforms(x)
return x
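For context, a minimal sketch (not part of the commit) of what the new maybe_batch_transforms hook does: the trainer asks the training dataset for an optional batched transform and applies it to the whole input batch after it has been moved to the device, falling back to the identity when none is configured. The _ToyDataset class and the flip_all lambda below are placeholders for illustration, not repository code.

import torch
from torch import Tensor

class _ToyDataset:
    # Placeholder standing in for CachedDataset: the trainer only needs it
    # to expose a `batch_transforms` attribute, which may be None.
    def __init__(self, batch_transforms=None):
        self.batch_transforms = batch_transforms

def maybe_batch_transforms(dataset: _ToyDataset, x: Tensor) -> Tensor:
    # Mirrors Trainer.maybe_batch_transforms: apply the batched transform
    # only when the dataset provides one, otherwise return the batch as-is.
    if dataset.batch_transforms is not None:
        return dataset.batch_transforms(x)
    return x

batch = torch.randn(8, 3, 32, 32)
flip_all = lambda x: x.flip(-1)  # toy batch transform: flip every image

assert torch.equal(maybe_batch_transforms(_ToyDataset(), batch), batch)
assert torch.equal(maybe_batch_transforms(_ToyDataset(flip_all), batch), batch.flip(-1))

The apparent intent of the change is to flip once per batch on the device instead of once per sample in the dataloader workers, which is why RandomHorizontalFlip is dropped from the runtime transforms in utils/transforms.py below.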
30 changes: 26 additions & 4 deletions utils/transforms.py
@@ -1,9 +1,20 @@
from abc import ABC, abstractmethod

import torch
from torch import nn, Tensor
from torchvision.transforms import v2


class BatchHorizontalFlip(nn.Module):
def __init__(self, p: float = 0.5):
super().__init__()
self.p = p

def forward(self, x: Tensor) -> Tensor:
flip_mask = (torch.rand(len(x), device=x.device) < self.p).view(-1, 1, 1, 1)
return torch.where(flip_mask, x.flip(-1), x)


class DatasetTransforms(ABC):
@abstractmethod
def train_cached(self):
@@ -21,6 +32,9 @@ def test_cached(self):
def test_runtime(self):
pass

def batch_transforms(self):
return None


class MNISTTransforms(DatasetTransforms):
def __init__(self, args):
@@ -38,7 +52,6 @@ def train_cached(self):
def train_runtime(self):
return v2.Compose(
[
v2.RandomHorizontalFlip(),
v2.RandomAffine(degrees=20, translate=(0.1, 0.1), scale=(0.9, 1.1)),
v2.ColorJitter(brightness=0.2, contrast=0.2),
self.normalize,
@@ -53,6 +66,11 @@ def test_cached(self):
def test_runtime(self):
return None

def batch_transforms(self):
return v2.Compose([
BatchHorizontalFlip(),
])


class CifarTransforms(DatasetTransforms):
def __init__(self, args):
@@ -74,7 +92,6 @@ def train_runtime(self):
v2.RandomCrop(
32, padding=4, fill=0 if self.args.fill is None else self.args.fill
),
v2.RandomHorizontalFlip(),
]

if self.args.autoaug:
@@ -103,18 +120,23 @@ def test_cached(self):
def test_runtime(self):
return None

def batch_transforms(self):
return v2.Compose([
BatchHorizontalFlip(),
])


class FashionMNISTTransforms(CifarTransforms):
def __init__(self, args):
super().__init__(args)
self.normalize = v2.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
# FashionMNIST must not have normalize inplace, otherwise it fails because it is a single channel image
# FashionMNIST must not have inplace normalize, otherwise it fails because it is a single channel image


def init_transforms(args) -> DatasetTransforms:
if args.dataset in ("cifar10", "cifar100", "cifar100noisy"):
return CifarTransforms(args)
if args.dataset in ("FashionMNIST", ):
if args.dataset in ("FashionMNIST",):
return FashionMNISTTransforms(args)
if args.dataset in ("MNIST", "DirtyMNIST"):
return MNISTTransforms(args)
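As a quick sanity check (illustrative only, not part of the commit), the vectorized BatchHorizontalFlip matches the per-sample behaviour of the v2.RandomHorizontalFlip it replaces: each image in the batch is flipped along the last (width) dimension independently with probability p. A minimal sketch, repeating the module definition from utils/transforms.py above so it runs standalone:

import torch
from torch import nn, Tensor

class BatchHorizontalFlip(nn.Module):
    # Same module as added in utils/transforms.py above.
    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        # One Bernoulli(p) draw per sample, broadcast over the C, H, W dims.
        flip_mask = (torch.rand(len(x), device=x.device) < self.p).view(-1, 1, 1, 1)
        return torch.where(flip_mask, x.flip(-1), x)

x = torch.arange(2 * 1 * 2 * 2, dtype=torch.float32).view(2, 1, 2, 2)
assert torch.equal(BatchHorizontalFlip(p=1.0)(x), x.flip(-1))  # p=1.0 flips every image
assert torch.equal(BatchHorizontalFlip(p=0.0)(x), x)           # p=0.0 leaves the batch untouched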
