-
Notifications
You must be signed in to change notification settings - Fork 13
/
styleTransfer.py
124 lines (85 loc) · 3.03 KB
/
styleTransfer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
from torchvision import transforms , models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
# Run on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Convolutional feature stack of a pretrained VGG-19. The weights are
# frozen because only the generated image is optimised, never the network.
model = models.vgg19(pretrained=True).features
for param in model.parameters():
    param.requires_grad = False
model.to(device)
def model_activations(input, model):
    """Run *input* through *model*, capturing activations at selected layers.

    Parameters
    ----------
    input : torch.Tensor
        Image tensor of shape (C, H, W); a batch dimension is added here.
    model : torch.nn.Module
        Sequential-style module whose children are applied in order
        (intended to be the VGG-19 ``features`` stack).

    Returns
    -------
    dict
        Maps readable layer names (e.g. ``'conv4_2'``) to feature tensors.
    """
    # Indices within VGG-19's ``features`` that style transfer samples.
    tracked = {
        '0': 'conv1_1',
        '5': 'conv2_1',
        '10': 'conv3_1',
        '19': 'conv4_1',
        '21': 'conv4_2',
        '28': 'conv5_1',
    }
    captured = {}
    out = input.unsqueeze(0)  # add a batch dimension
    for idx, layer in model._modules.items():
        out = layer(out)
        if idx in tracked:
            captured[tracked[idx]] = out
    return captured
# Resize so the shorter side is 300 px, convert to a tensor, and map pixel
# values from [0, 1] to [-1, 1] via (x - 0.5) / 0.5 per channel.
transform = transforms.Compose([
    transforms.Resize(300),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load the two input images; convert("RGB") guards against
# grayscale / RGBA files slipping through.
content = Image.open("content.jpg").convert("RGB")
content = transform(content).to(device)
# BUG FIX: message read "COntent" — typo in the printed text.
print("Content shape => ", content.shape)

style = Image.open("style.jpg").convert("RGB")
style = transform(style).to(device)
def imcnvt(image):
    """Convert a normalised (C, H, W) tensor into a displayable HWC array.

    Undoes the Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) transform and
    clips the result into [0, 1] so matplotlib can render it.
    """
    arr = image.to("cpu").clone().detach().numpy().squeeze()
    arr = arr.transpose(1, 2, 0)  # CHW -> HWC
    mean = np.array((0.5, 0.5, 0.5))
    std = np.array((0.5, 0.5, 0.5))
    arr = arr * std + mean  # invert the normalisation
    return np.clip(arr, 0, 1)
# Show the content and style images side by side before optimisation starts.
fig, axes = plt.subplots(1, 2)
axes[0].imshow(imcnvt(content), label="Content")
axes[1].imshow(imcnvt(style), label="Style")
plt.show()
def gram_matrix(imgfeature):
    """Return the (d x d) Gram matrix of a (1, d, h, w) feature tensor.

    Each channel is flattened into a row vector; the matrix of channel
    inner products is what characterises "style" at a given layer.
    """
    _, depth, height, width = imgfeature.size()
    flat = imgfeature.view(depth, height * width)
    return torch.mm(flat, flat.t())
# The generated image starts as a copy of the content image and is the
# only tensor being optimised.
target = content.clone().requires_grad_(True).to(device)
print("device = ", device)

# Activations of the two fixed reference images are computed once, up front.
style_features = model_activations(style, model)
content_features = model_activations(content, model)

# Per-layer style weights: earlier (finer-texture) layers count for more.
style_wt_meas = {
    "conv1_1": 1.0,
    "conv2_1": 0.8,
    "conv3_1": 0.4,
    "conv4_1": 0.2,
    "conv5_1": 0.1,
}

# The style image never changes, so its Gram matrices are precomputed.
style_grams = {name: gram_matrix(feat) for name, feat in style_features.items()}

content_wt = 100      # weight of the content-reconstruction term
style_wt = 1e8        # weight of the style term
print_after = 500     # show/save an intermediate image every N epochs
epochs = 4000
optimizer = torch.optim.Adam([target], lr=0.007)
# Optimisation loop: adjust `target` so its VGG activations match the
# content image at conv4_2 and the style image's Gram matrices elsewhere.
for i in range(1, epochs + 1):
    target_features = model_activations(target, model)

    # Content loss: mean squared difference of deep-layer activations.
    content_loss = torch.mean((content_features['conv4_2'] - target_features['conv4_2']) ** 2)

    style_loss = 0
    for layer in style_wt_meas:
        style_gram = style_grams[layer]
        target_feat = target_features[layer]
        _, d, h, w = target_feat.shape
        target_gram = gram_matrix(target_feat)
        # BUG FIX: the original wrote `... / d * w * h`, which divides by d
        # and then MULTIPLIES by w*h due to operator precedence. The intended
        # normalisation divides by the full feature-map size d*h*w.
        # NOTE(review): style_wt was presumably tuned against the old
        # (wrong) scale, so it may need re-tuning after this fix.
        style_loss += (style_wt_meas[layer] * torch.mean((target_gram - style_gram) ** 2)) / (d * h * w)

    total_loss = content_wt * content_loss + style_wt * style_loss

    if i % 10 == 0:
        print("epoch ", i, " ", total_loss)

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Periodically display and save the work-in-progress image.
    if i % print_after == 0:
        plt.imshow(imcnvt(target), label="Epoch " + str(i))
        plt.show()
        plt.imsave(str(i) + '.png', imcnvt(target), format='png')