-
Notifications
You must be signed in to change notification settings - Fork 0
/
losses.py
29 lines (23 loc) · 1.03 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Dec 20 17:16:55 2020
@author: user1
"""
import sys
import torch
import torch.nn as nn
class UncertaintyLoss(nn.Module):
    """Combine two task losses using learnable log-variance weights.

    Each task loss ``L_i`` is scaled by a learned precision ``exp(-s_i)`` and
    regularised by adding ``s_i``, where ``s_i`` is an entry of ``log_vars``::

        total = mean(exp(-s_0) * L_pathology + s_0
                     + exp(-s_1) * L_distortion + s_1)

    This is the form commonly used for homoscedastic-uncertainty task
    weighting (Kendall et al., 2018), reading ``s_i`` as ``log(sigma_i**2)``.

    Args:
        ctype: Name of a loss class in ``torch.nn`` (e.g. ``"MSELoss"``); the
            same criterion instance is applied to both tasks.
        copt: Passed as the single *positional* argument of the criterion's
            constructor. NOTE(review): for most ``torch.nn`` losses the first
            positional parameter is ``weight`` (or the deprecated
            ``size_average``) -- confirm this is what callers intend.
        device: Device on which ``log_vars`` is created. ``None`` (default)
            selects CUDA when it is available, otherwise CPU, so the module
            can also be built on CPU-only machines. The module can still be
            moved afterwards with ``.to(...)``.
    """

    def __init__(self, ctype, copt, device=None):
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # One learnable log-variance per task, initialised to 0, i.e. both
        # task losses start with precision exp(-0) == 1.
        self.log_vars = nn.Parameter(
            torch.zeros(2, dtype=torch.float32, device=device)
        )
        # Resolve the criterion class by name, e.g. "MSELoss" -> nn.MSELoss.
        self.crit = getattr(nn, ctype)(copt)

    @property
    def std_vars(self):
        """Return the current per-task standard deviations.

        Each entry is ``exp(log_var) ** 0.5``. It is recomputed on every access
        so it follows ``log_vars`` as they are trained, instead of freezing
        the initial values.
        """
        return [torch.exp(log_var) ** 0.5 for log_var in self.log_vars]

    def forward(self, net_output, gt_pathology, gt_distortion):
        """Compute the uncertainty-weighted sum of the two task losses.

        Args:
            net_output: Mapping holding the model predictions under the keys
                ``'pathology'`` and ``'distortion_1'``.
            gt_pathology: Target passed to ``self.crit`` with
                ``net_output['pathology']``.
            gt_distortion: Target passed to ``self.crit`` with
                ``net_output['distortion_1']``.

        Returns:
            A tuple ``(loss, pathology_loss, d1_loss, log_vars)`` where
            ``loss`` is ``torch.mean`` of the two weighted terms' sum,
            ``pathology_loss`` / ``d1_loss`` are the weighted per-task terms
            (still attached to the graph), and ``log_vars`` is a plain Python
            list of the current log-variance values.
        """
        log_var_p, log_var_d = self.log_vars[0], self.log_vars[1]
        precision_p = torch.exp(-log_var_p)
        precision_d = torch.exp(-log_var_d)

        loss_p = self.crit(net_output['pathology'], gt_pathology)
        loss_d = self.crit(net_output['distortion_1'], gt_distortion)

        # The "+ log_var" term penalises inflating the variance just to
        # shrink a task's weight toward zero.
        pathology_loss = precision_p * loss_p + log_var_p
        d1_loss = precision_d * loss_d + log_var_d
        loss = torch.mean(pathology_loss + d1_loss)
        return loss, pathology_loss, d1_loss, self.log_vars.detach().tolist()