-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcutpaste.py
154 lines (119 loc) · 5.56 KB
/
cutpaste.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# From https://github.com/Runinho/pytorch-cutpaste
import random
import math
from torchvision import transforms
import torch
def cut_paste_collate_fn(batch):
    """Collate per-sample view tuples into one stacked tensor per view.

    Args:
        batch: sequence of samples, each an iterable of tensors (one entry per
            view). Every sample must have the same number of views, and the
            tensors of a given view must share a shape so they can be stacked.

    Returns:
        list of tensors: entry ``i`` stacks view ``i`` of every sample along a
        new leading dimension. An empty batch yields an empty list.
    """
    # zip(*batch) transposes "samples x views" into "views x samples".
    return list(map(torch.stack, zip(*batch)))
class CutPaste(object):
    """Shared base for the CutPaste augmentations.

    Stores an optional colour-jitter transform (subclasses apply it to the cut
    patch) and an optional post-transform applied to the final image.

    Args:
        colorJitter (float or None): strength passed as ``brightness``,
            ``contrast``, ``saturation`` and ``hue`` to
            ``transforms.ColorJitter``. ``None`` disables jitter.
        transform (callable or None): applied by ``__call__`` to the image it
            receives. A falsy value means no transform.
    """

    def __init__(self, colorJitter=0.1, transform=None):
        self.transform = transform
        self.colorJitter = (
            None
            if colorJitter is None
            else transforms.ColorJitter(
                brightness=colorJitter,
                contrast=colorJitter,
                saturation=colorJitter,
                hue=colorJitter,
            )
        )

    def __call__(self, img):
        """Return *img* passed through ``self.transform``, or *img* itself if unset."""
        return self.transform(img) if self.transform else img
class CutPasteNormal(CutPaste):
    """Randomly copy one rectangular patch of the image and paste it somewhere else.

    Args:
        area_ratio (sequence of 2 floats): ``(min, max)`` fraction of the image
            area covered by the patch.
        aspect_ratio (float): minimum aspect ratio. The patch aspect ratio is
            sampled log-uniformly between ``aspect_ratio`` and ``1/aspect_ratio``.
        **kwags: forwarded to ``CutPaste`` (``colorJitter``, ``transform``).
    """

    def __init__(self, area_ratio=(0.02, 0.15), aspect_ratio=0.3, **kwags):
        super(CutPasteNormal, self).__init__(**kwags)
        self.area_ratio = area_ratio
        self.aspect_ratio = aspect_ratio

    def __call__(self, img):
        """Return a copy of PIL image *img* with one patch cut, optionally jittered, and pasted.

        The input image is not modified. The result is passed through the base
        class ``__call__``, which applies ``self.transform`` if set.
        """
        # TODO: we might want to use the pytorch implementation to calculate the
        # patches from https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#RandomErasing

        # PIL's Image.size is (width, height); the original code had these swapped,
        # which mis-bounded the sampled offsets on non-square images.
        w, h = img.size

        # Patch area in pixels: a fraction between area_ratio[0] and area_ratio[1].
        ratio_area = random.uniform(self.area_ratio[0], self.area_ratio[1]) * w * h

        # Sample the aspect ratio in log space so that r and 1/r are equally likely.
        log_ratio = torch.log(torch.tensor((self.aspect_ratio, 1 / self.aspect_ratio)))
        aspect = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()

        # Clamp so the patch never exceeds the image. Elongated images combined
        # with extreme aspect ratios could otherwise yield negative offsets below.
        cut_w = min(int(round(math.sqrt(ratio_area * aspect))), w)
        cut_h = min(int(round(math.sqrt(ratio_area / aspect))), h)

        # One might also want to sample from other images; currently we only
        # sample from the image itself.
        from_location_h = int(random.uniform(0, h - cut_h))
        from_location_w = int(random.uniform(0, w - cut_w))

        # PIL crop box is (left, upper, right, lower).
        box = [from_location_w, from_location_h, from_location_w + cut_w, from_location_h + cut_h]
        patch = img.crop(box)

        if self.colorJitter:
            patch = self.colorJitter(patch)

        to_location_h = int(random.uniform(0, h - cut_h))
        to_location_w = int(random.uniform(0, w - cut_w))

        # Same size as the crop box, so paste() accepts the 4-tuple region.
        insert_box = [to_location_w, to_location_h, to_location_w + cut_w, to_location_h + cut_h]
        augmented = img.copy()
        augmented.paste(patch, insert_box)

        return super().__call__(augmented)
class CutPasteScar(CutPaste):
    """Randomly copy one thin, rotated ("scar") patch of the image and paste it elsewhere.

    Args:
        width (sequence of 2): ``(min, max)`` patch width in pixels.
        height (sequence of 2): ``(min, max)`` patch height in pixels.
        rotation (sequence of 2): ``(min, max)`` rotation angle in degrees.
        **kwags: forwarded to ``CutPaste`` (``colorJitter``, ``transform``).
    """

    def __init__(self, width=(2, 16), height=(10, 25), rotation=(-45, 45), **kwags):
        super(CutPasteScar, self).__init__(**kwags)
        self.width = width
        self.height = height
        self.rotation = rotation

    def __call__(self, img):
        """Return a copy of PIL image *img* with a rotated scar patch pasted onto it.

        The input image is not modified. The result is passed through the base
        class ``__call__``, which applies ``self.transform`` if set.
        """
        # PIL's Image.size is (width, height); the original code had these swapped.
        w, h = img.size

        # Cut region (sizes are sampled as floats).
        cut_w = random.uniform(*self.width)
        cut_h = random.uniform(*self.height)

        # max(0, ...) keeps offsets non-negative if the patch is larger than the image.
        from_location_h = int(random.uniform(0, max(0, h - cut_h)))
        from_location_w = int(random.uniform(0, max(0, w - cut_w)))

        # PIL crop box is (left, upper, right, lower).
        box = [from_location_w, from_location_h, from_location_w + cut_w, from_location_h + cut_h]
        patch = img.crop(box)

        if self.colorJitter:
            patch = self.colorJitter(patch)

        # Rotate. expand=True enlarges the canvas to hold the whole rotated patch;
        # the alpha channel is used below as the paste mask.
        rot_deg = random.uniform(*self.rotation)
        patch = patch.convert("RGBA").rotate(rot_deg, expand=True)

        # Paste. patch.size is (width, height), so bound the x offset by the image
        # width and the y offset by the image height. The original code mixed these
        # axes and could push the scar past the right/bottom edge.
        patch_w, patch_h = patch.size
        to_location_h = int(random.uniform(0, max(0, h - patch_h)))
        to_location_w = int(random.uniform(0, max(0, w - patch_w)))

        mask = patch.split()[-1]
        patch = patch.convert("RGB")
        augmented = img.copy()
        augmented.paste(patch, (to_location_w, to_location_h), mask=mask)

        return super().__call__(augmented)
class CutPasteUnion(object):
    """Apply either CutPasteNormal or CutPasteScar, each chosen with probability 0.5.

    The input is converted with ``transforms.ToPILImage`` before augmentation,
    and the augmented image is converted back with ``transforms.ToTensor``.

    Args:
        **kwags: forwarded unchanged to both CutPasteNormal and CutPasteScar.
    """

    def __init__(self, **kwags):
        self.normal = CutPasteNormal(**kwags)
        self.scar = CutPasteScar(**kwags)

    def __call__(self, img):
        # NOTE(review): ToPILImage expects a tensor or ndarray here -- confirm callers
        # do not pass PIL images, and that the sub-augmenters' own `transform`
        # still returns a PIL image that ToTensor can accept.
        to_pil = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()
        pil_img = to_pil(img)

        augmenter = self.normal if random.uniform(0, 1) < 0.5 else self.scar
        return to_tensor(augmenter(pil_img))
class CutPaste3Way(object):
    """Produce the original image plus one CutPasteNormal view and one CutPasteScar view.

    Args:
        **kwags: forwarded unchanged to both CutPasteNormal and CutPasteScar.
    """

    def __init__(self, **kwags):
        self.normal = CutPasteNormal(**kwags)
        self.scar = CutPasteScar(**kwags)

    def __call__(self, img):
        """Return ``(org, cutpaste_normal, cutpaste_scar)`` for *img*.

        Each augmenter returns a single image, so it cannot be unpacked into
        ``(org, augmented)`` as the previous code tried to do. Instead, the
        original view is built here by applying the same post-``transform``
        the augmenters apply. All three outputs then go through the same
        conversion and can be stacked together.
        """
        transform = self.normal.transform
        org = transform(img) if transform else img
        cutpaste_normal = self.normal(img)
        cutpaste_scar = self.scar(img)
        return org, cutpaste_normal, cutpaste_scar