-
Notifications
You must be signed in to change notification settings - Fork 0
/
helpers.py
117 lines (87 loc) · 3.04 KB
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import torch
import torch.nn.functional as F
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
import pydicom
from skimage.metrics import adapted_rand_error
from medpy.metric.binary import precision as mp_precision
from medpy.metric.binary import recall as mp_recall
from medpy.metric.binary import dc
def _thresh(img):
img[img > 0.5] = 1
img[img <= 0.5] = 0
return img
def dsc(y_pred, y_true):
    """Dice similarity coefficient of two masks, each binarized with ``_thresh``.

    The prediction is binarized first, then the ground truth. The pair is
    then scored with ``medpy.metric.binary.dc``.
    NOTE(review): whether the caller's arrays are modified depends on
    ``_thresh`` (it may binarize in place).
    """
    return dc(_thresh(y_pred), _thresh(y_true))
def dice_m(y_pred, y_true, mean=True):
    """Per-class (or class-averaged) Dice score from class scores and labels.

    Args:
        y_pred: tensor whose first dimension indexes the classes. The predicted
            class of each element is the argmax over that dimension.
        y_true: label tensor. It is cast to int64 and one-hot encoded, and a
            leading size-1 dimension is then squeezed away (presumably a
            channel dim; TODO confirm against callers).
        mean: if True return the mean over classes, otherwise the per-class
            tensor.

    Returns:
        Scalar tensor when ``mean`` is True, else a tensor indexed by class.
        A class absent from both prediction and target scores 0, because the
        1e-5 smoothing term appears only in the denominator.
    """
    num_classes = y_pred.shape[0]
    # one_hot appends the class axis last for both encodings.
    onehot_true = F.one_hot(y_true.to(torch.int64), num_classes).squeeze(0)
    onehot_pred = F.one_hot(y_pred.argmax(dim=0), num_classes)
    # Reduce over every axis except the trailing class axis.
    reduce_dims = tuple(range(onehot_true.dim() - 1))
    tp = (onehot_pred * onehot_true).sum(dim=reduce_dims)
    fp = (onehot_pred * (1 - onehot_true)).sum(dim=reduce_dims)
    fn = ((1 - onehot_pred) * onehot_true).sum(dim=reduce_dims)
    per_class = (2 * tp) / (2 * tp + fp + fn + 1e-5)
    return per_class.mean() if mean else per_class
def iou(y_pred, y_true):
    """Intersection-over-union of two masks after binarizing each with ``_thresh``.

    Both inputs are copied into NumPy arrays before thresholding, so the
    caller's arrays are never modified. Non-zero values (including NaN) count
    as foreground in the logical operations.

    Returns:
        |pred AND true| / |pred OR true| as a float. When both masks are empty
        the union is empty and the score is defined as 1.0.
    """
    pred = _thresh(np.array(y_pred, copy=True))
    true = _thresh(np.array(y_true, copy=True))
    intersection = np.logical_and(pred, true)
    union = np.logical_or(pred, true)
    if not union.any():
        # An empty union implies both masks are empty. The former
        # `0 if np.any(y_pred) else 1` could only ever produce 1.
        return 1.0
    return intersection.sum() / float(union.sum())
def precision(y_pred, y_true):
    """Precision of a binarized prediction against a binarized ground truth.

    Both inputs are copied into NumPy arrays and binarized with ``_thresh``,
    so the caller's arrays are not modified. A mask with 5 or fewer
    foreground elements is treated as empty, to avoid dividing by ~0:

    * nearly-empty ground truth -> 1.0 if the prediction is also nearly
      empty, otherwise 0.0;
    * nearly-empty prediction (ground truth not empty) -> 0.0.

    Otherwise the score comes from ``medpy.metric.binary.precision``.
    """
    # `np.int` was removed in NumPy 1.24; the builtin `int` gives the same dtype.
    pred = _thresh(np.array(y_pred, copy=True)).astype(int)
    true = _thresh(np.array(y_true, copy=True)).astype(int)
    if true.sum() <= 5:
        return 1.0 if pred.sum() <= 5 else 0.0
    if pred.sum() <= 5:
        return 0.0
    return mp_precision(pred, true)
def recall(y_pred, y_true):
    """Recall of a binarized prediction against a binarized ground truth.

    Both inputs are copied into NumPy arrays and binarized with ``_thresh``,
    so the caller's arrays are not modified. A mask with 5 or fewer
    foreground elements is treated as empty, to avoid dividing by ~0:

    * nearly-empty ground truth -> 1.0 if the prediction is also nearly
      empty, otherwise 0.0;
    * nearly-empty prediction (ground truth not empty) -> 0.0.

    Otherwise the score comes from ``medpy.metric.binary.recall``.
    """
    # `np.int` was removed in NumPy 1.24; the builtin `int` gives the same dtype.
    pred = _thresh(np.array(y_pred, copy=True)).astype(int)
    true = _thresh(np.array(y_true, copy=True)).astype(int)
    if true.sum() <= 5:
        return 1.0 if pred.sum() <= 5 else 0.0
    if pred.sum() <= 5:
        return 0.0
    return mp_recall(pred, true)
def listdir(path):
    """List the entries of ``path``, omitting hidden ones (names starting with '.').

    Works for both ``str`` and ``bytes`` paths. For a bytes path,
    ``os.listdir`` returns bytes names; these are decoded with
    ``os.fsdecode`` only for the dot check and returned unchanged.
    Order is whatever ``os.listdir`` yields.
    """
    return [name for name in os.listdir(path)
            if not os.fsdecode(name).startswith('.')]
def mkdir(path):
    """Create directory ``path``, including missing parents, if it does not exist.

    ``exist_ok=True`` replaces the separate existence check. That closes the
    race where another process creates the directory between the check and
    ``os.makedirs``, which used to raise FileExistsError.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
def show_images_row(imgs, titles=None, rows=1, figsize=(6.4, 4.8), **kwargs):
    '''
    Draw images in a grid of subplots on a new matplotlib figure.

    :param imgs: sequence of images accepted by ``imshow``
    :param titles: optional sequence of subplot titles, same length as
        ``imgs``; defaults to "Image (1)", "Image (2)", ...
    :param rows: number of grid rows; the column count is
        ceil(len(imgs) / rows)
    :param figsize: figure size passed to ``plt.figure``
    :param kwargs: forwarded to ``imshow`` (e.g. ``cmap``)
    :return: None (the figure is created but not shown or saved here)
    '''
    assert ((titles is None) or (len(imgs) == len(titles)))
    num_images = len(imgs)
    if titles is None:
        titles = ['Image (%d)' % i for i in range(1, num_images + 1)]
    # add_subplot requires an integer column count. np.ceil returns a
    # float, which Matplotlib rejects.
    cols = int(np.ceil(num_images / float(rows)))
    fig = plt.figure(figsize=figsize)
    for n, (image, title) in enumerate(zip(imgs, titles)):
        ax = fig.add_subplot(rows, cols, n + 1)
        plt.imshow(image, **kwargs)
        ax.set_title(title)
        plt.axis('off')