-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathDenseEnergyLoss.py
76 lines (61 loc) · 3.1 KB
/
DenseEnergyLoss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import sys
sys.path.append("./wrapper/bilateralfilter/build/lib.linux-x86_64-3.6")
from bilateralfilter import bilateralfilter, bilateralfilter_batch
class DenseEnergyLossFunction(Function):
    """Autograd function for a gated dense (bilateral-filter) energy.

    With ``S`` the ROI-masked segmentation and ``AS`` the output of
    ``bilateralfilter_batch`` applied to ``S``, forward computes

        loss = -dot(S, Gate * AS) / N

    where ``Gate = clamp(ROI - max_k S[:, k], min=0)``, forced to 1 on the
    pixels flagged by ``unlabel_region``.

    Backward returns ``-2 * grad_output * (Gate * AS) / N``, masked by the ROI.
    The factor 2 matches the derivative of ``S . A(S)`` only if the filter is
    symmetric (not verifiable here). ``Gate`` (which depends on ``S`` through
    the per-pixel max) is treated as a constant.
    """

    @staticmethod
    def forward(ctx, images, segmentations, sigma_rgb, sigma_xy, ROIs, unlabel_region):
        """Compute the gated dense energy.

        Args:
            images: tensor flattened and passed to ``bilateralfilter_batch``
                as the guidance image. Presumably (N, C, H, W); confirm
                against the extension.
            segmentations: 4-D tensor unpacked as (N, K, H, W).
            sigma_rgb: forwarded unchanged to ``bilateralfilter_batch``.
            sigma_xy: forwarded unchanged to ``bilateralfilter_batch``.
            ROIs: region mask, expected (N, H, W). It is subtracted from the
                (N, H, W) per-pixel max and broadcast over the K channels.
            unlabel_region: boolean mask over (N, H, W). The gate is set to 1
                there. ``None`` means no such pixels.

        Returns:
            A 1-element CPU tensor holding the loss averaged over the N images.
        """
        ctx.N, ctx.K, ctx.H, ctx.W = segmentations.shape
        # Follow the input's device instead of the implicit current CUDA device:
        # correct on multi-GPU, and backward returns grads on the input's device.
        device = segmentations.device
        ctx.device = device

        # Out-of-place ops only. The caller's ROIs tensor is never mutated, and the
        # gate no longer aliases it (``.cuda()`` is a no-op for CUDA tensors, so an
        # in-place ``unsqueeze_`` used to reshape the gate too).
        rois = ROIs.to(device)
        roi_mask = rois.unsqueeze(1)  # (N, 1, H, W); broadcasts over the K channels
        ctx.ROIs = roi_mask

        seg_max = torch.max(segmentations, dim=1)[0]  # (N, H, W)
        gate = torch.clamp(rois - seg_max, min=0)
        if unlabel_region is not None:
            gate[unlabel_region.to(device)] = 1

        masked_segs = segmentations * roi_mask

        # The C extension works on flat, contiguous host arrays.
        images_np = images.detach().cpu().numpy().flatten()
        segs_np = masked_segs.detach().cpu().numpy().flatten()
        AS = np.zeros(segs_np.shape, dtype=np.float32)
        bilateralfilter_batch(images_np, segs_np, AS, ctx.N, ctx.K, ctx.H, ctx.W, sigma_rgb, sigma_xy)

        # Apply the per-pixel gate to every class channel.
        gate_np = gate.detach().cpu().numpy()
        AS = AS.reshape(ctx.N, ctx.K, ctx.H, ctx.W) * gate_np[:, np.newaxis, :, :]

        # Averaged by the number of images.
        densecrf_loss = -np.dot(segs_np, AS.ravel()) / ctx.N
        ctx.AS = AS
        # Autograd attaches the graph to outputs of a Function itself, so the
        # deprecated Variable(requires_grad=True) wrapper is not needed.
        return torch.tensor([densecrf_loss])

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``segmentations`` only.

        The gradient is moved to the segmentation's device and masked by the
        ROI. Every other input gets ``None``.
        """
        grad_segmentation = -2 * grad_output * torch.from_numpy(ctx.AS) / ctx.N
        grad_segmentation = grad_segmentation.to(ctx.device) * ctx.ROIs
        return None, grad_segmentation, None, None, None, None
class DenseEnergyLoss(nn.Module):
    """Weighted dense energy loss evaluated on rescaled inputs.

    Every input is resized by ``scale_factor`` before the energy is computed:
    - images and ROIs use ``F.interpolate``'s default mode;
    - segmentations are resized bilinearly;
    - labels are resized with nearest-neighbour.

    ``sigma_xy`` is multiplied by the same factor. Label value 255 marks
    unlabeled pixels. The result is multiplied by ``weight``.
    """

    def __init__(self, weight, sigma_rgb, sigma_xy, scale_factor):
        super().__init__()
        self.weight = weight
        self.sigma_rgb = sigma_rgb
        self.sigma_xy = sigma_xy
        self.scale_factor = scale_factor

    def forward(self, images, segmentations, ROIs, seg_label):
        """Rescale the inputs and return ``weight * energy``.

        ``ROIs`` gets a channel axis added for the resize, which is removed
        afterwards. The code treats it as (N, H, W); this is assumed, so
        confirm against callers.
        """
        factor = self.scale_factor

        small_images = F.interpolate(images, scale_factor=factor)
        small_segs = F.interpolate(
            segmentations, scale_factor=factor, mode='bilinear', align_corners=False
        )
        small_rois = F.interpolate(ROIs.unsqueeze(1), scale_factor=factor).squeeze(1)
        small_labels = F.interpolate(seg_label, scale_factor=factor, mode='nearest')

        # Pixels labeled 255 are treated as unlabeled by the energy's gate.
        ignore_mask = (small_labels.long() == 255).squeeze(1)

        spatial_sigma = self.sigma_xy * factor
        energy = DenseEnergyLossFunction.apply(
            small_images, small_segs, self.sigma_rgb, spatial_sigma, small_rois, ignore_mask
        )
        return self.weight * energy

    def extra_repr(self):
        return (
            f'sigma_rgb={self.sigma_rgb}, sigma_xy={self.sigma_xy}, '
            f'weight={self.weight}, scale_factor={self.scale_factor}'
        )