-
Notifications
You must be signed in to change notification settings - Fork 1
/
transform.py
65 lines (55 loc) · 2.14 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
from torch import nn
from kornia.augmentation import RandomAffine, RandomCrop, RandomResizedCrop, Resize
from kornia.filters import GaussianBlur2d
from config import args
class Transforms(object):
    """Batched image augmentations applied in a fixed order.

    Reference: "Data-Efficient Reinforcement Learning with Self-Predictive
    Representations" (SPR). Adapted from the GitHub repository mila-iqia/spr,
    the code release accompanying that paper.

    Supported augmentation names (applied in the order they are given):
        "affine"    -- kornia ``RandomAffine(5, (.14, .14), (.9, 1.1), (-5, 5))``
        "crop"      -- kornia ``RandomCrop(image_shape)``
        "rrc"       -- kornia ``RandomResizedCrop((100, 100), (0.8, 1))``;
                       NOTE: output size is hard-coded to 100x100 and does
                       not follow ``image_shape``
        "blur"      -- kornia ``GaussianBlur2d((5, 5), (1.5, 1.5))``
        "shift"     -- replicate-pad every side by ``shift_delta`` pixels, then
                       ``RandomCrop(image_shape)``; when the input is already
                       ``image_shape`` this is a random shift of up to
                       ``shift_delta`` pixels
        "intensity" -- ``Intensity(scale=0.05)``, a random per-sample
                       multiplicative brightness change
        "resize"    -- kornia ``Resize(image_shape)``
        "none"      -- identity
    """

    def __init__(self, augmentation=None, shift_delta=4, image_shape=args.resolution):
        """Build the ordered list of transform modules.

        Args:
            augmentation: Iterable of augmentation names (see class docstring).
                Defaults to ``['shift', 'intensity']``.
            shift_delta: Padding in pixels added on each side before the random
                crop used by ``"shift"``.
            image_shape: Output size for ``"crop"``, ``"shift"`` and
                ``"resize"``. Defaults to ``args.resolution`` from ``config``;
                that default is evaluated once, when the class is defined.

        Raises:
            NotImplementedError: If an augmentation name is not recognised.
        """
        if augmentation is None:
            augmentation = ['shift', 'intensity']
        self.augmentation = augmentation
        self.transforms = []
        for aug in self.augmentation:
            if aug == "affine":
                transformation = RandomAffine(5, (.14, .14), (.9, 1.1), (-5, 5))
            elif aug == "crop":
                transformation = RandomCrop(image_shape)
            elif aug == "rrc":
                transformation = RandomResizedCrop((100, 100), (0.8, 1))
            elif aug == "blur":
                transformation = GaussianBlur2d((5, 5), (1.5, 1.5))
            elif aug == "shift":
                # Pad then crop back to image_shape: a random translation.
                transformation = nn.Sequential(nn.ReplicationPad2d(shift_delta), RandomCrop(image_shape))
            elif aug == "intensity":
                transformation = Intensity(scale=0.05)
            elif aug == "resize":
                transformation = Resize(image_shape)
            elif aug == "none":
                transformation = nn.Identity()
            else:
                # Name the offending entry so config typos are easy to spot.
                raise NotImplementedError(f"Unknown augmentation: {aug!r}")
            self.transforms.append(transformation)

    @staticmethod
    def apply_transforms(transforms, image):
        """Feed ``image`` through each callable in ``transforms`` in order."""
        for transform in transforms:
            image = transform(image)
        return image

    @torch.no_grad()
    def __call__(self, images):
        """Augment a (possibly multi-level) batch of images without gradients.

        Args:
            images: Tensor of rank >= 3. The last three dims are treated as one
                image (presumably ``(C, H, W)`` — confirm with callers); all
                leading dims are folded into a single batch dim for the
                transforms and restored afterwards.

        Returns:
            Tensor with the same leading dims as ``images`` and the per-image
            shape produced by the transforms.
        """
        # NOTE(review): no uint8 -> float rescaling is done here; callers
        # presumably pass float images already — confirm.
        flat_images = images.reshape(-1, *images.shape[-3:])
        processed_images = self.apply_transforms(self.transforms, flat_images)
        # reshape (not view): a transform may return a non-contiguous tensor,
        # on which view() would raise; reshape copies only when it must.
        return processed_images.reshape(*images.shape[:-3], *processed_images.shape[1:])
class Intensity(nn.Module):
    """Randomly rescale the brightness of each sample in a batch.

    Every sample ``x[i]`` is multiplied by one factor
    ``1 + scale * clamp(eps_i, -2, 2)`` with ``eps_i ~ N(0, 1)``, so the factor
    always lies in ``[1 - 2 * scale, 1 + 2 * scale]``.

    Args:
        scale: Standard deviation of the multiplicative noise before clipping.
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        """Return ``x`` multiplied by one random factor per sample.

        Args:
            x: Tensor of rank >= 1 whose first dimension indexes samples.

        Returns:
            Tensor with the same shape as ``x``. Floating-point inputs keep
            their dtype; non-floating inputs are promoted to float32.
        """
        # One noise value per sample, shaped to broadcast over all other dims
        # (for 4-D input this is (N, 1, 1, 1)).
        noise_shape = (x.size(0),) + (1,) * (x.dim() - 1)
        r = torch.randn(noise_shape, device=x.device)
        noise = 1.0 + (self.scale * r.clamp(-2.0, 2.0))
        if x.is_floating_point():
            # Avoid silently upcasting half/bfloat16 inputs to float32.
            noise = noise.to(x.dtype)
        return x * noise