-
Notifications
You must be signed in to change notification settings - Fork 6
/
Utils.py
88 lines (72 loc) · 3 KB
/
Utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import numpy as np
import nibabel as nib
from Loss import *
import torch
import torch.nn as nn
import torch.nn.functional as F
class SpatialTransformer(nn.Module):
    """
    N-D Spatial Transformer.

    Warps a tensor with a dense displacement field using
    ``torch.nn.functional.grid_sample``. Each output voxel is sampled from
    ``src`` at its own index plus the displacement stored in ``flow``.
    Only 2-D and 3-D inputs get the channel reordering that ``grid_sample``
    needs; other ranks are passed through unchanged and will presumably be
    rejected by ``grid_sample``.

    Args:
        size: spatial size of the volumes to warp, e.g. ``(H, W)`` or
            ``(D, H, W)``; used to build the identity sampling grid.
        mode: interpolation mode forwarded to ``grid_sample``
            (e.g. ``'bilinear'`` or ``'nearest'``).
    """

    def __init__(self, size, mode='bilinear'):
        super().__init__()
        self.mode = mode

        # Identity sampling grid of shape (1, ndim, *size): channel i holds each
        # voxel's integer index along spatial axis i.
        vectors = [torch.arange(0, s) for s in size]
        try:
            grids = torch.meshgrid(*vectors, indexing='ij')
        except TypeError:
            # torch < 1.10 has no `indexing` keyword; its fixed behaviour is 'ij'.
            grids = torch.meshgrid(*vectors)
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.type(torch.FloatTensor)

        # Registering the grid as a buffer moves it along with the module
        # (.to()/.cuda()), but it also puts it in the state dict, so saved
        # weight files carry this grid. torch >= 1.6 offers
        # register_buffer(..., persistent=False), but switching to it would drop
        # 'grid' from state dicts and break strict loading of existing checkpoints.
        # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
        self.register_buffer('grid', grid)

    def forward(self, src, flow, return_phi=False):
        """
        Warp ``src`` by the displacement field ``flow``.

        Args:
            src: tensor to resample, presumably (N, C, *spatial) with the same
                spatial size as ``flow`` (TODO confirm with callers).
            flow: displacement in voxel units, (N, ndim, *spatial); channel i
                is the displacement along spatial axis i (it is added to the
                identity grid, whose channel i is the index along axis i).
            return_phi: if True, also return the normalized sampling grid.

        Returns:
            The warped tensor, or ``(warped, new_locs)`` when ``return_phi`` is
            True, where ``new_locs`` is the channels-last grid fed to
            ``grid_sample``.
        """
        # Absolute sampling positions in voxel units.
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # grid_sample wants coordinates in [-1, 1]; with align_corners=True,
        # -1 and 1 are the centers of the first and last voxels on each axis.
        for i, dim in enumerate(shape):
            # On a singleton axis, (dim - 1) == 0 would yield NaN/inf
            # coordinates. With align_corners=True, any finite value samples
            # index 0 there, so dividing by 1 is safe.
            denom = max(dim - 1, 1)
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / denom - 0.5)

        # Move channels to the last position. The channels are in tensor-axis
        # order (axis 0 first), but grid_sample reads the last grid dimension
        # as (x, y[, z]), where x indexes the LAST (width) axis. That is why the
        # order is reversed.
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        warped = F.grid_sample(src, new_locs, align_corners=True, mode=self.mode)
        if return_phi:
            return warped, new_locs
        return warped
def load_nii(path):
    """
    Read the NIfTI image at ``path`` and return only its voxel data.

    The header and affine are not returned. nibabel's ``get_fdata`` yields a
    floating-point NumPy array (float64 by default).
    """
    image = nib.load(path)
    return image.get_fdata()
def save_nii(img, savename, affine=None):
    """
    Write ``img`` to ``savename`` as a NIfTI-1 image.

    Args:
        img: voxel data accepted by ``nib.nifti1.Nifti1Image``
            (e.g. a NumPy array).
        savename: output path; nibabel chooses the on-disk format from the
            extension (e.g. ``.nii`` or ``.nii.gz``).
        affine: optional 4x4 voxel-to-world matrix. Defaults to the identity
            (the previous, hard-coded behaviour). Pass the source image's
            affine (e.g. ``nib.load(path).affine``) to keep its spatial
            orientation.
    """
    if affine is None:
        # Identity affine: voxel indices are used directly as world coordinates.
        affine = np.diag([1, 1, 1, 1])
    new_img = nib.nifti1.Nifti1Image(img, affine, header=None)
    nib.save(new_img, savename)
def generate_grid3D_tensor(shape):
    """
    Build a normalized 3-D identity sampling grid.

    Args:
        shape: sequence whose first three entries give the grid size along
            axes 0, 1 and 2.

    Returns:
        Float tensor of shape ``(3, shape[0], shape[1], shape[2])`` with
        coordinates evenly spaced in [-1, 1]. Channel 0 varies along the LAST
        axis, channel 1 along the middle axis and channel 2 along the FIRST
        axis. This reversed order matches how ``grid_sample`` reads (x, y, z)
        once channels are moved last.
    """
    x_lin = torch.linspace(-1., 1., shape[0])
    y_lin = torch.linspace(-1., 1., shape[1])
    z_lin = torch.linspace(-1., 1., shape[2])
    try:
        x_grid, y_grid, z_grid = torch.meshgrid(x_lin, y_lin, z_lin, indexing='ij')
    except TypeError:
        # torch < 1.10 has no `indexing` keyword; its fixed behaviour is 'ij'.
        x_grid, y_grid, z_grid = torch.meshgrid(x_lin, y_lin, z_lin)
    # Stack in reversed order (z, y, x) so channel 0 is the last-axis coordinate.
    grid = torch.stack([z_grid, y_grid, x_grid], dim=0)
    return grid
def dice(array1, array2, labels):
    """
    Computes the dice overlap between two arrays for a given set of integer labels.

    For each label l:
        dice(l) = 2 * |A_l & B_l| / (|A_l| + |B_l|)
    where A_l and B_l are the elements equal to l in array1 and array2.

    Args:
        array1, array2: label maps of the same (or broadcast-compatible) shape.
            NumPy arrays are expected; plain Python lists/tuples are accepted
            and converted.
        labels: iterable of label values to score (any iterable, including a
            generator).

    Returns:
        float64 ndarray with one score per label, in the order of ``labels``.
        A label absent from both inputs scores 0, because the denominator is
        clamped to machine epsilon instead of dividing by zero.
    """
    # A Python sequence compares as a whole (`[1, 2] == 1` is False) rather
    # than elementwise, which made every score 0. Arrays and tensors are left
    # untouched.
    if isinstance(array1, (list, tuple)):
        array1 = np.asarray(array1)
    if isinstance(array2, (list, tuple)):
        array2 = np.asarray(array2)

    eps = np.finfo(float).eps
    scores = []
    for label in labels:
        # Build each mask once; it is reused for the intersection and the sums.
        mask1 = array1 == label
        mask2 = array2 == label
        top = 2 * np.sum(np.logical_and(mask1, mask2))
        bottom = np.maximum(np.sum(mask1) + np.sum(mask2), eps)  # add epsilon
        scores.append(top / bottom)
    return np.array(scores, dtype=float)