loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from train_util import *
'''
Loss function
author: Fengting Yang
Mar. 1st, 2019

Only the "compute_semantic_pos_loss" function is used in our final version;
the best result was achieved with pos_weight = 3e-3.
'''
def compute_semantic_pos_loss(prob_in, labxy_feat, pos_weight=0.003, kernel_size=16):
    # Positional term follows the SLIC paper, which uses the sqrt of the MSE.
    # labxy_feat: B x (50+2) x H x W -- semantic label features (one-hot) concatenated with xy coordinates
    # prob_in   : B x 9 x H x W     -- soft assignment of each pixel to its 9 surrounding superpixels
    # NOTE: this loss is designed for a single-level structure only.
    # TODO: we currently assume the downsampling scales in the x and y directions are the same.
    S = kernel_size
    m = pos_weight
    prob = prob_in.clone()

    b, c, h, w = labxy_feat.shape
    pooled_labxy = poolfeat(labxy_feat, prob, kernel_size, kernel_size)
    reconstr_feat = upfeat(pooled_labxy, prob, kernel_size, kernel_size)

    # Positional error between the reconstructed and the original xy coordinates.
    loss_map = reconstr_feat[:, -2:, :, :] - labxy_feat[:, -2:, :, :]

    # Self-defined cross entropy -- the official one combines a softmax, which is not wanted here.
    logit = torch.log(reconstr_feat[:, :-2, :, :] + 1e-8)
    loss_sem = - torch.sum(logit * labxy_feat[:, :-2, :, :]) / b
    loss_pos = torch.norm(loss_map, p=2, dim=1).sum() / b * m / S

    # Empirically, scaling by 0.005 tends to give better performance.
    loss_sum = 0.005 * (loss_sem + loss_pos)
    loss_sem_sum = 0.005 * loss_sem
    loss_pos_sum = 0.005 * loss_pos

    return loss_sum, loss_sem_sum, loss_pos_sum
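

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): it only illustrates the
# tensor shapes the loss expects, based on the shape comments above. It assumes
# that `poolfeat`/`upfeat` are available from train_util, that H and W are
# divisible by kernel_size, and that there are 50 semantic classes and a 9-way
# soft assignment; these numbers may differ in practice.
if __name__ == '__main__':
    B, n_classes, H, W = 2, 50, 208, 208
    kernel_size = 16

    # Soft assignment of every pixel to its 9 surrounding superpixels.
    prob_in = F.softmax(torch.randn(B, 9, H, W), dim=1)

    # One-hot semantic labels concatenated with (x, y) pixel coordinates.
    labels = torch.randint(0, n_classes, (B, H, W))
    onehot = F.one_hot(labels, n_classes).permute(0, 3, 1, 2).float()
    xs = torch.arange(W).view(1, 1, 1, W).expand(B, 1, H, W).float()
    ys = torch.arange(H).view(1, 1, H, 1).expand(B, 1, H, W).float()
    labxy_feat = torch.cat([onehot, xs, ys], dim=1)  # B x (50+2) x H x W

    loss_sum, loss_sem, loss_pos = compute_semantic_pos_loss(
        prob_in, labxy_feat, pos_weight=3e-3, kernel_size=kernel_size)
    print(loss_sum.item(), loss_sem.item(), loss_pos.item())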