-
Notifications
You must be signed in to change notification settings - Fork 0
/
MultiSSimFocaldicePixelClassificationLayer.m
62 lines (45 loc) · 1.84 KB
/
MultiSSimFocaldicePixelClassificationLayer.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
classdef MultiSSimFocaldicePixelClassificationLayer < nnet.layer.ClassificationLayer
    % MultiSSimFocaldicePixelClassificationLayer  Pixel classification
    % output layer for training semantic segmentation networks.
    %
    % The loss is the sum of two terms, each averaged over observations:
    %   * a focal generalized Dice loss, (1 - Dice)^(1/Gamma), where each
    %     class is weighted by the inverse square of its target area;
    %   * a multiscale structural dissimilarity, 1 - multissim(...),
    %     computed on channel 1 of the targets and predictions only.

    properties
        % Gamma - focal exponent; the per-observation Dice loss is raised
        % to the power 1/Gamma. Default 4/3.
        Gamma = 4/3;
    end

    methods
        function layer = MultiSSimFocaldicePixelClassificationLayer(name, gamma)
            % layer = MultiSSimFocaldicePixelClassificationLayer(name, gamma)
            % creates the layer with the given name and focal exponent.
            % Both arguments are optional: without name the Name property
            % is left unset, and without gamma Gamma keeps its default 4/3.
            if nargin >= 1
                layer.Name = name;
            end
            if nargin >= 2
                layer.Gamma = gamma;
            end
            layer.Description = 'MultiSSim + focaldice loss';
        end

        function loss = forwardLoss(layer, Y, T)
            % loss = forwardLoss(layer, Y, T) returns the combined
            % MS-SSIM + focal generalized Dice loss between the
            % predictions Y and the training targets T.
            %
            % Dimension 3 is summed as the class dimension and dimension 4
            % is iterated/averaged as the observation dimension.
            % NOTE(review): assumes Y and T are H-by-W-by-C-by-N with T
            % one-hot encoded - confirm against the network output.
            T = dlarray(T);
            numObs = size(Y, 4);

            % Multiscale SSIM dissimilarity on channel 1 only, one image at
            % a time; accumulate instead of growing an array in the loop.
            msDist = 0;
            for n = 1:numObs
                msDist = msDist + (1 - multissim(T(:,:,1,n), Y(:,:,1,n)));
            end
            msDist = msDist / numObs;

            % Generalized Dice class weights: inverse squared target area
            % per class and observation. They depend only on the targets,
            % so compute them on plain data. A class absent from an
            % observation (area 0) would get weight Inf, and Inf*0 = NaN
            % would poison the loss; give such classes zero weight instead.
            areaSq = extractdata(sum(sum(T.*T, 1), 2)).^2;
            w = zeros(size(areaSq), 'like', areaSq);
            present = areaSq > 0;
            w(present) = 1 ./ areaSq(present);

            % Weighted intersection and union, summed over classes.
            intersection = sum(sum(Y.*T, 1), 2);
            union = sum(sum(Y.*Y + T.*T, 1), 2);
            numer = 2*sum(w.*intersection, 3);
            denom = sum(w.*union, 3);

            % Focal Dice loss per observation. Since Y^2 + T^2 >= 2*Y*T and
            % w >= 0, numer <= denom in exact arithmetic, but rounding can
            % push 1 - numer./denom slightly below zero, and a negative
            % base with a fractional exponent 1/Gamma gives a complex
            % result; clamp the base at zero.
            diceLoss = max(1 - numer./denom, 0).^(1./layer.Gamma);

            % Average the Dice term over observations and add MS-SSIM term.
            loss = sum(diceLoss, 4)/numObs + msDist;
        end
    end
end