-
Notifications
You must be signed in to change notification settings - Fork 5
/
projection_layer_CSPN.py
92 lines (74 loc) · 2.86 KB
/
projection_layer_CSPN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# the simplex projection algorithm implemented as a layer, while using the saliency maps to obtain object size estimates
import sys
sys.path.insert(0,'/home/briq/libs/caffe/python')
import caffe
import random
import numpy as np
import scipy.misc
import imageio
import cv2
import scipy.ndimage as nd
import os.path
import scipy.io as sio
class SimplexProjectionLayer(caffe.Layer):
    """Caffe Python layer that rescales per-class score maps to a target size.

    For every image in the batch the layer copies ``bottom[0]`` to ``top[0]``
    and then, for each class flagged as present in ``bottom[1]``, replaces the
    class map with the output of ``simplexProjectionLinear``. The target size
    ``nu`` is the number of pixels in the image's saliency file whose value
    equals the class index.

    Inputs as used by the code:
      * ``bottom[0]`` -- score maps, indexed as ``[image][class]``.
      * ``bottom[1]`` -- per-image class indicators, read at ``[i, 0, 0, c]``;
        a value above 0.5 means the class is treated as present.
      * ``bottom[2]`` -- per-image index into the list read from
        ``input_list_path``.
    """

    # Directory searched for ``<image name>.mat`` saliency files.
    saliency_path = '/media/VOC/saliency/thresholded_saliency_images/'
    # Text file with one image entry per line; ``bottom[2]`` indexes its lines.
    input_list_path = '/home/briq/libs/CSPN/training/input_list.txt'

    def simplexProjectionLinear(self, data_ind, class_ind, V_im, nu):
        """Shift the map ``V_im`` so that its mass matches the size ``nu``.

        Args:
            data_ind: index of the image in the batch (unused here).
            class_ind: class index of the map (unused here).
            V_im: array with at least two dimensions holding the class scores.
            nu: target size.

        Returns:
            ``V_im`` itself when ``nu < 1`` or when ``V_im`` already sums to
            ``nu``; otherwise a new array equal to ``V_im`` plus a constant
            offset.

        Raises:
            ValueError: if the pivot search ends without accumulating any
                element (for example when ``V_im`` contains non-finite values).
        """
        if nu < 1:
            return V_im
        heatmap_size = V_im.shape[0] * V_im.shape[1]
        theta = np.sum(V_im)
        if theta == nu:  # the size constraint is already satisfied
            return V_im
        if theta < nu:
            # Too little mass: spread the deficit uniformly over all pixels.
            return V_im + (nu - theta) / heatmap_size

        # Too much mass: randomized-pivot search for the threshold. This looks
        # like the linear-time simplex projection of Duchi et al. (2008).
        s = 0.0  # running sum of the elements accepted so far
        p = 0.0  # running count of the elements accepted so far
        U = V_im.flatten()
        while len(U) > 0:
            uk = U[random.randint(0, len(U) - 1)]
            UG = U[U >= uk]
            delta_p = len(UG)
            delta_s = np.sum(UG)
            if (s + delta_s) - (p + delta_p) * uk < nu:
                # Everything >= pivot is accepted; continue with the rest.
                s = s + delta_s
                p = p + delta_p
                U = U[U < uk]
            else:
                # The pivot value is rejected: drop every element equal to it
                # and keep searching among the strictly larger ones.
                U = UG[UG != uk]
        if p < 0.000001:
            raise ValueError('rho is too small, apparently something went wrong in the CNN') # happens when nu<1 or V_im=infinity for example
        theta = (s - nu) / p
        # NOTE(review): the textbook algorithm returns max(V_im - theta, 0);
        # the result here is not clipped at zero. Confirm whether this is
        # intentional before changing it.
        return V_im - theta

    def setup(self, bottom, top):
        """Record the number of classes and load the image list from disk."""
        self.num_labels = bottom[0].shape[1]
        with open(self.input_list_path) as fp:
            self.images = fp.readlines()
        # Re-seed the pivot selection used by simplexProjectionLinear.
        random.seed()

    def reshape(self, bottom, top):
        """Give ``top[0]`` the same shape as ``bottom[0]``."""
        top[0].reshape(*bottom[0].data.shape)

    def forward(self, bottom, top):
        """Copy the scores to the output and project the present classes."""
        for i in range(bottom[0].num):
            im_id = int(bottom[2].data[i])
            # Image name = first space-separated field, cut at the first '.'.
            im_name = self.images[im_id].split(' ')[0].split('.')[0]
            # Start from an unmodified copy; only selected maps are replaced.
            top[0].data[i] = bottom[0].data[i]
            saliency_name = self.saliency_path + im_name + '.mat'
            if not os.path.isfile(saliency_name):
                # No saliency file for this image: leave the scores untouched.
                continue
            saliency_im = sio.loadmat(saliency_name, squeeze_me=True)['data']
            # Class 0 is never projected (presumably background -- confirm).
            for c in range(1, self.num_labels):
                if bottom[1].data[i, 0, 0, c] > 0.5:  # the label is there
                    # Size estimate: saliency pixels whose value equals c.
                    nu = np.sum(saliency_im == c)
                    if nu > 1:
                        top[0].data[i][c] = self.simplexProjectionLinear(
                            i, c, bottom[0].data[i][c], nu)

    def backward(self, top, propagate_down, bottom):
        """No gradient is propagated through this layer."""
        pass