-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathlosses.py
28 lines (24 loc) · 874 Bytes
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from torch.nn.modules.loss import _Loss
def completion_network_loss(input, output, mask):
bs = input.size(0)
loss = torch.sum(torch.abs(output * mask - input * mask)) / bs
#return mse_loss(output * mask, input * mask)
return loss
def noise_loss(V, img1, img2):
feat1 = V(img1)[0]
feat2 = V(img2)[0]
loss = torch.mean(torch.abs(feat1 - feat2))
return loss
class ContextLoss(_Loss):
def forward(self, mask, gen, images):
bs = gen.size(0)
context_loss = torch.sum(torch.abs(torch.mul(mask, gen) - torch.mul(mask, images))) / bs
return context_loss
class CrossEntropyLoss(_Loss):
def forward(self, out, gt):
bs = out.size(0)
#print(out.size(), gt.size())
loss = - torch.mul(gt.float(), torch.log(out.float() + 1e-7))
loss = torch.sum(loss) / bs
return loss