import tensorflow as tf
from tensorflow import keras
import random
import numpy as np
"""
this py define the loss function that compute the NL and PL
cp_label means the complementary label(random generated)
"""
class CustomLoss(keras.losses.Loss):
    def __init__(self, class_num, param=0.0, name="custom_loss"):
        super().__init__(name=name)
        self.class_num = class_num
        self.th = 1 / self.class_num   # per-class confidence threshold used by PL
        self.param = param             # weight of the PL term in the joint loss
        # per-class sentinel values (not used by call/NL/PL below)
        self.l = [-1e14]
        for _ in range(self.class_num - 1):
            self.l.append(1e14)
    def call(self, y_pred, y_true2):
        # y_true2 stacks the true labels (first batch_size entries) and the
        # complementary labels (last batch_size entries) along the batch axis
        self.batch_size = tf.shape(y_pred)[0]
        self.cp_label = y_true2[self.batch_size:]
        y_true = y_true2[:self.batch_size]
        # cast the label slices to integer indices for tf.one_hot / tf.gather_nd
        y_true = tf.cast(tf.reshape(y_true, [self.batch_size, 1]), dtype='int64')
        self.cp_label = tf.cast(tf.reshape(self.cp_label, [self.batch_size, 1]), dtype='int64')
        self.cp_label_onehot = tf.reshape(
            tf.one_hot(self.cp_label, axis=1, depth=self.class_num),
            [self.batch_size, self.class_num])               # shape = (batch_size, class_num)
        self.y_true_onehot = tf.reshape(
            tf.one_hot(y_true, axis=1, depth=self.class_num),
            [self.batch_size, self.class_num])               # shape = (batch_size, class_num)
        self.predict_label = tf.reshape(tf.argmax(y_pred, axis=1),
                                        [self.batch_size, 1]) # shape = (batch_size, 1)
        NL_score = self.NL(y_pred, self.cp_label)
        PL_score = self.PL(y_pred, self.predict_label)
        score = NL_score + self.param * PL_score
        #max = tf.constant([5],dtype='float32')
        #out = tf.where(score < 5, x=score, y=)
        return score
    def NL(self, y_pred, cp_label):
        # build an index to gather each sample's score at its complementary label
        index = tf.reshape(tf.range(self.batch_size, dtype=tf.int64),
                           [self.batch_size, 1])             # shape = (batch_size, 1)
        index = tf.concat((index, cp_label), axis=1)          # shape = (batch_size, 2)
        py = tf.gather_nd(y_pred, index)                      # shape = (batch_size,)
        # negative-learning cross entropy between the complementary label and 1 - score
        cp_label_onehot = tf.reshape(
            tf.one_hot(cp_label, axis=-1, depth=self.class_num),
            [self.batch_size, self.class_num])                # shape = (batch_size, class_num)
        cross_entropy = cp_label_onehot * tf.math.log(
            tf.clip_by_value(1 - y_pred, 1e-3, 1.0))          # shape = (batch_size, class_num)
        cross_entropy = tf.reduce_sum(cross_entropy, axis=1)  # shape = (batch_size,)
        # weighted form; its result is overwritten below, so the unweighted sum is returned
        weight = -(1 - py)                                    # shape = (batch_size,)
        out = tf.matmul(tf.reshape(weight, [1, self.batch_size]),
                        tf.reshape(cross_entropy, [self.batch_size, 1]))  # shape = (1, 1)
        out = -tf.reduce_sum(cross_entropy, axis=-1)
        return tf.reshape(out, [1]) / tf.cast(self.batch_size, dtype='float32')  # shape = (1,)
    def PL(self, y_pred, pred_label):
        # select samples whose score exceeds the threshold for at most one class
        one = tf.ones_like(y_pred)
        zero = tf.zeros_like(y_pred)
        label = tf.where(y_pred < self.th, x=zero, y=one)     # shape = (batch_size, class_num)
        label = tf.reduce_sum(label, axis=-1)                 # shape = (batch_size,)
        one = tf.ones_like(label)
        zero = tf.zeros_like(label)
        label = tf.where(label < 2, x=one, y=zero)            # 1 = selected sample, 0 = dropped
        D = y_pred * tf.reshape(label, [self.batch_size, 1])  # shape = (batch_size, class_num)
        # calculate the PL term on the selected scores
        num = self.batch_size
        index = tf.reshape(tf.range(num, dtype=tf.int64),
                           [num, 1])                          # shape = (n, 1)
        D_label = tf.reshape(tf.argmax(D, axis=1),
                             [num, 1])                        # shape = (n, 1)
        index = tf.concat((index, D_label), axis=1)           # shape = (n, 2)
        py = tf.gather_nd(D, index)                           # shape = (n,)
        py = 1 - tf.math.square(py)                           # shape = (n,)
        py = tf.reshape(py, [1, num])                         # shape = (1, n)
        weight = tf.reduce_prod(py)                           # scalar
        one_hot = self.y_true_onehot * tf.reshape(label, [self.batch_size, 1])
        # clip for numerical stability, mirroring the NL term
        cross_entropy = one_hot * tf.math.log(
            tf.clip_by_value(y_pred, 1e-3, 1.0))              # shape = (batch_size, class_num)
        cross_entropy = tf.reduce_sum(cross_entropy, axis=1)  # shape = (batch_size,)
        out = -weight * cross_entropy
        out = tf.reduce_sum(out, axis=0)                      # scalar
        return out
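

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the shapes, the random complementary
# label construction, and the direct-call pattern below are assumptions, not
# taken from this repository's training code). It shows how the stacked label
# tensor expected by CustomLoss.call can be built, i.e. the true labels
# followed by randomly generated complementary labels, and how the joint
# NL/PL loss can be evaluated on a batch of softmax scores.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, class_num = 4, 10

    # fake softmax predictions, shape (batch_size, class_num)
    y_pred = tf.nn.softmax(tf.random.normal([batch_size, class_num]), axis=-1)

    # ground-truth labels, shape (batch_size,)
    y_true = tf.random.uniform([batch_size], maxval=class_num, dtype=tf.int64)

    # complementary labels: any class other than the true one, drawn uniformly
    offset = tf.random.uniform([batch_size], minval=1, maxval=class_num,
                               dtype=tf.int64)
    cp_label = (y_true + offset) % class_num

    # CustomLoss.call expects the true labels and the complementary labels
    # stacked along the batch dimension
    stacked_labels = tf.concat([y_true, cp_label], axis=0)

    loss_fn = CustomLoss(class_num=class_num, param=1.0)
    # called so that the first argument reaching call() is the prediction
    # matrix, matching the parameter order defined above
    loss_value = loss_fn(y_pred, stacked_labels)
    print("joint NL/PL loss:", float(loss_value))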