-
Notifications
You must be signed in to change notification settings - Fork 2
/
style_utils.py
121 lines (101 loc) · 3.78 KB
/
style_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import transforms, datasets
# Gram Matrix
def gram(tensor):
    """Return the normalized Gram matrix of a batch of feature maps.

    Args:
        tensor: 4-D tensor of shape (B, C, H, W).

    Returns:
        Tensor of shape (B, C, C) whose entry [b, i, j] is the inner product
        of channels i and j of sample b, divided by C * H * W.
    """
    B, C, H, W = tensor.shape
    # reshape (not view) so non-contiguous inputs, e.g. permuted or sliced
    # feature maps, are accepted instead of raising a RuntimeError.
    x = tensor.reshape(B, C, H * W)
    return torch.bmm(x, x.transpose(1, 2)) / (C * H * W)
# Load image file
def load_image(path):
    """Read an image from disk with OpenCV.

    Args:
        path: filesystem path of the image file.

    Returns:
        The decoded image as returned by cv2.imread (OpenCV's BGR channel
        order for colour images).

    Raises:
        OSError: if OpenCV cannot read or decode the file.
    """
    img = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising;
    # fail here with the offending path rather than letting None crash a
    # later, unrelated step (cvtColor, itot, ...).
    if img is None:
        raise OSError(f"cv2.imread could not read image: {path!r}")
    return img
# Show image
def show(img):
    """Display an OpenCV-style BGR image in a 10x5 matplotlib figure.

    The image is converted to RGB, divided by 255 and clipped to [0, 1],
    because imshow() only accepts floats in [0, 1] or ints in [0, 255].
    """
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    scaled = np.clip(np.array(rgb / 255), 0, 1)
    plt.figure(figsize=(10, 5))
    plt.imshow(scaled)
    plt.show()
def saveimg(img, image_path):
    """Clip an image to [0, 255] and write it to disk with OpenCV.

    Args:
        img: image array; presumably BGR like the rest of this module's
            OpenCV-based helpers (TODO confirm for other callers).
        image_path: destination path; OpenCV picks the encoder from the
            file extension.

    Raises:
        OSError: if cv2.imwrite reports that the image was not written.
    """
    img = img.clip(0, 255)
    # cv2.imwrite returns False instead of raising on many failures (e.g. a
    # missing output directory); surface that rather than silently losing
    # the result.
    if not cv2.imwrite(image_path, img):
        raise OSError(f"cv2.imwrite could not write image: {image_path!r}")
# Preprocessing ~ Image to Tensor
def itot(img, max_size=None):
    """Convert an image into a float tensor with a leading batch dimension.

    The image goes through torchvision's ToTensor and is then multiplied by
    255. For uint8 input ToTensor scales to [0, 1], so the result holds the
    original 0-255 pixel values as floats. Channel order is left unchanged
    (no BGR->RGB swap happens here).

    Args:
        img: image accepted by torchvision's ToTensor, e.g. an H x W x C
            uint8 NumPy array from load_image.
        max_size: when not None, the image is first routed through
            ToPILImage. NOTE(review): resizing to max_size is currently
            disabled, so its value is otherwise ignored — confirm before
            relying on it.

    Returns:
        Tensor of shape (1, C, H, W).
    """
    steps = [
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ]
    if max_size is not None:
        steps.insert(0, transforms.ToPILImage())
    tensor = transforms.Compose(steps)(img)
    # Add the leading batch dimension
    return tensor.unsqueeze(dim=0)
# Preprocessing ~ Tensor to Image
def ttoi(tensor):
    """Convert a single image tensor into an H x W x C NumPy array.

    Accepts (C, H, W), or the same with leading size-1 dimensions such as
    (1, C, H, W). The result is detached from autograd and copied to the
    CPU. Pixel values and channel order are left unchanged.

    Raises:
        ValueError: if the tensor holds more than one image (batch size > 1)
            or is not 3-D after removing leading size-1 dimensions.
    """
    # Strip only *leading* singleton dims; a bare squeeze() would also drop a
    # size-1 channel/height/width and make the transpose below fail.
    while tensor.dim() > 3 and tensor.shape[0] == 1:
        tensor = tensor.squeeze(dim=0)
    if tensor.dim() != 3:
        raise ValueError(
            f"expected a single (C, H, W) image, got shape {tuple(tensor.shape)}"
        )
    # detach() so tensors produced with autograd enabled can be converted
    img = tensor.detach().cpu().numpy()
    # Transpose from [C, H, W] -> [H, W, C]
    return img.transpose(1, 2, 0)
def transfer_color(src, dest):
    """
    Keep the colours of `src` while taking the luminance of `dest`.

    Useful for preserving the original colours in style transfer. Both
    inputs are expected as [Height, Width, Channel] arrays in BGR order and
    are clipped to [0, 255]. `dest` is resized (bicubic) to `src`'s height
    and width. The chroma channels come from `src` converted to YCrCb (the
    code uses YCrCb, not YIQ). Its Y channel is replaced by `dest`'s
    grayscale image, and the result is converted back to BGR and clipped to
    [0, 255].
    """
    src = src.clip(0, 255)
    dest = dest.clip(0, 255)

    height, width, _ = src.shape
    dest = cv2.resize(dest, dsize=(width, height), interpolation=cv2.INTER_CUBIC)

    luminance = cv2.cvtColor(dest, cv2.COLOR_BGR2GRAY)
    ycrcb = cv2.cvtColor(src, cv2.COLOR_BGR2YCrCb)
    ycrcb[..., 0] = luminance

    merged = cv2.cvtColor(ycrcb, cv2.COLOR_YCrCb2BGR)
    return merged.clip(0, 255)
def plot_loss_hist(c_loss, s_loss, total_loss, title="Loss History"):
    """Plot the content, style and total loss curves on one 10x6 figure.

    The x-axis is the index into `total_loss`. It is labelled
    'Every 500 iterations', which presumably matches how often callers
    record the losses (caller-dependent, not enforced here).
    """
    steps = list(range(len(total_loss)))
    plt.figure(figsize=[10, 6])
    curves = (
        (c_loss, "Content Loss"),
        (s_loss, "Style Loss"),
        (total_loss, "Total Loss"),
    )
    for series, name in curves:
        plt.plot(steps, series, label=name)
    plt.legend()
    plt.xlabel('Every 500 iterations')
    plt.ylabel('Loss')
    plt.title(title)
    plt.show()
class ImageFolderWithPaths(datasets.ImageFolder):
    """ImageFolder variant whose items also carry the image file path.

    Each item is whatever torchvision.datasets.ImageFolder returns, with the
    file path appended as the last element.
    Reference: https://discuss.pytorch.org/t/dataloader-filenames-in-each-batch/4212/2
    """

    def __getitem__(self, index):
        """Return ImageFolder's usual item tuple followed by the file path."""
        # The parent builds the regular (sample, target) tuple
        item = super().__getitem__(index)
        # self.imgs holds one entry per sample; element 0 is its path
        file_path = self.imgs[index][0]
        return (*item, file_path)