-
Notifications
You must be signed in to change notification settings - Fork 84
/
vgg_loss.py
32 lines (29 loc) · 1.06 KB
/
vgg_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.nn as nn
import torchvision
class VGGLoss(nn.Module):
    """
    Part of pre-trained VGG16. This is used in case we want perceptual loss instead of Mean Square Error loss.
    See for instance https://arxiv.org/abs/1603.08155

    The module keeps the prefix of ``vgg16.features`` (or ``vgg16_bn.features``)
    that ends, inclusively, at layer ``layer_within_block`` of block ``block_no``.
    Both indices are 1-based:

    - Blocks are delimited by ``nn.MaxPool2d`` layers. The pooling layer is the
      last layer of its block.
    - Every child layer counts toward ``layer_within_block``, including
      ReLU, pooling and (for the batch-norm variant) BatchNorm layers.

    The truncated network is frozen: its parameters do not require grad, and it
    is kept in eval mode even when the parent model calls ``.train()``. Gradients
    still flow back to the input image.

    ``forward`` returns the feature map only. Computing a distance between
    feature maps (e.g. MSE) is presumably left to the caller.
    """

    def __init__(self, block_no: int, layer_within_block: int, use_batch_norm_vgg: bool):
        """
        Args:
            block_no: 1-based index of the VGG block to stop in.
            layer_within_block: 1-based index of the last layer to keep inside
                that block.
            use_batch_norm_vgg: load ``vgg16_bn`` instead of plain ``vgg16``.

        Raises:
            ValueError: if the requested (block, layer) position does not exist
                in the VGG feature extractor.
        """
        super().__init__()
        vgg16 = self._load_pretrained_vgg16(use_batch_norm_vgg)
        self.vgg_loss = self._truncate_features(vgg16.features, block_no, layer_within_block)
        # The loss network is a fixed feature extractor. Freeze its weights so an
        # optimizer over a parent model's parameters never updates them.
        self.vgg_loss.requires_grad_(False)
        # Eval mode stops BatchNorm (vgg16_bn) from using batch statistics and
        # from updating its running stats.
        self.vgg_loss.eval()

    @staticmethod
    def _load_pretrained_vgg16(use_batch_norm_vgg: bool) -> nn.Module:
        """Build an ImageNet-pretrained VGG16 (optionally with batch norm)."""
        models = torchvision.models
        if use_batch_norm_vgg:
            builder = models.vgg16_bn
            weights_cls = getattr(models, "VGG16_BN_Weights", None)
        else:
            builder = models.vgg16
            weights_cls = getattr(models, "VGG16_Weights", None)
        if weights_cls is None:
            # torchvision < 0.13 only understands the legacy boolean flag.
            return builder(pretrained=True)
        # IMAGENET1K_V1 is the checkpoint that the deprecated ``pretrained=True``
        # resolves to, so both code paths load the same weights.
        return builder(weights=weights_cls.IMAGENET1K_V1)

    @staticmethod
    def _truncate_features(features: nn.Module, block_no: int, layer_within_block: int) -> nn.Sequential:
        """Return the prefix of ``features`` up to and including the requested layer."""
        curr_block = 1
        curr_layer = 1
        layers = []
        for layer in features.children():
            layers.append(layer)
            if curr_block == block_no and curr_layer == layer_within_block:
                return nn.Sequential(*layers)
            if isinstance(layer, nn.MaxPool2d):
                # A max-pool closes the current block; counting restarts at 1.
                curr_block += 1
                curr_layer = 1
            else:
                curr_layer += 1
        # Without this, a misconfigured position silently yields the full stack.
        raise ValueError(
            f"VGG16 features have no layer {layer_within_block} in block {block_no}"
        )

    def train(self, mode: bool = True) -> "VGGLoss":
        """Set the training mode, but always keep the frozen VGG layers in eval mode."""
        super().train(mode)
        self.vgg_loss.eval()
        return self

    def forward(self, img):
        """Return the VGG feature map of ``img``.

        NOTE(review): torchvision's pretrained VGG expects an ImageNet-normalized
        (N, 3, H, W) float batch. This module does not normalize — confirm at
        the caller.
        """
        return self.vgg_loss(img)