loss_dic.py
import torch
from segmentation_models_pytorch.losses import DiceLoss
from torch.nn import BCEWithLogitsLoss


def get_loss_function_by_name(name, from_logits=True):
    """Return a callable loss for the given name; raise ValueError for unknown names."""
    if name == "binary_cross_entropy_with_logits":
        # reduction='none' keeps one loss value per element, so the caller can mask pixels.
        return BCEWithLogitsLoss(reduction='none')
    if name == "dice_loss":
        def loss(y_hat, y, ignore_index=None):
            # DiceLoss expects a channel dimension on the prediction, so add one.
            y_hat = torch.unsqueeze(y_hat, 1)
            return DiceLoss('binary', from_logits=from_logits, ignore_index=ignore_index)(y_hat, y)
        return loss
    if name == "focal_dice_loss":
        def dice_loss(input, target):
            # Soft Dice on sharpened probabilities: p^2 / (p^2 + (1 - p)^2)
            # pushes predictions toward 0 or 1 before computing the overlap.
            smooth = 1e-7
            if from_logits:
                input = torch.sigmoid(input)
            iflat = input.view(-1)
            iflat2 = iflat * iflat
            iflat_in = 1 - iflat
            iflat = iflat2 / (iflat2 + iflat_in * iflat_in)
            tflat = target.view(-1)
            intersection = (iflat * tflat).sum()
            return 1 - ((2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))
        return dice_loss
    raise ValueError("Loss function {} not found".format(name))


def loss_support_mask(name):
    """Whether the loss returns per-element values that the caller can mask."""
    if name == "binary_cross_entropy_with_logits":
        return True
    if name == "dice_loss":
        return False
    if name == "focal_dice_loss":
        return False
    raise ValueError("Loss function {} not found".format(name))


def loss_support_ignore_index(name):
    """Whether the loss accepts an ignore_index argument."""
    if name == "binary_cross_entropy_with_logits":
        return False
    if name == "dice_loss":
        return True
    if name == "focal_dice_loss":
        return False
    raise ValueError("Loss function {} not found".format(name))
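

A minimal usage sketch (not part of the original module) follows: it assumes (batch, H, W) logits and a same-shaped binary target, and only illustrates how the helper predicates might gate masking and ignore_index handling; the validity mask and shapes are assumptions, not taken from the original code.

# --- Usage sketch (assumptions noted above; not from the original file) ---
if __name__ == "__main__":
    logits = torch.randn(2, 64, 64)                      # assumed shape: (batch, H, W)
    target = torch.randint(0, 2, (2, 64, 64)).float()    # binary ground truth
    mask = torch.ones_like(target)                       # hypothetical per-pixel validity mask

    name = "binary_cross_entropy_with_logits"
    loss_fn = get_loss_function_by_name(name)
    per_pixel = loss_fn(logits, target)                  # reduction='none' -> one value per pixel
    if loss_support_mask(name):
        per_pixel = per_pixel * mask                     # caller applies the mask before reducing
    print(name, per_pixel.mean().item())

    name = "dice_loss"
    loss_fn = get_loss_function_by_name(name, from_logits=True)
    if loss_support_ignore_index(name):
        print(name, loss_fn(logits, target, ignore_index=-1).item())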