-
Notifications
You must be signed in to change notification settings - Fork 0
/
mindssc.py
49 lines (38 loc) · 2.07 KB
/
mindssc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import *
def mindssc(img, delta=1, sigma=0.8):
    """Compute the MIND-SSC (self-similarity context) descriptor of a 3-D image.

    See http://mpheinrich.de/pub/miccai2013_943_mheinrich.pdf for details on
    the MIND-SSC descriptor.

    Args:
        img: floating-point image tensor. The comparison kernels built below
            have a single input channel, so ``conv3d`` expects exactly one
            channel -- presumably shaped (N, 1, D, H, W); confirm with callers.
        delta: dilation of the 3x3x3 one-hot comparison kernels, so the
            compared voxels sit at offsets that are multiples of ``delta``
            from the centre. The replication padding uses the same amount,
            which keeps the spatial size unchanged through the convolutions.
        sigma: forwarded to ``smooth`` (from ``utils``) when aggregating the
            squared differences; its exact meaning is defined by that helper
            (presumably a Gaussian width -- verify).

    Returns:
        Tensor with 12 channels, one per self-similarity pair, reordered to
        match the channel ordering of the reference C++ implementation.
    """
    device = img.device
    # Build kernels in the image's dtype so conv3d does not reject e.g.
    # float64 input (float32 input behaves exactly as before).
    dtype = img.dtype

    # Start/end locations for the self-similarity pattern: the six
    # face-neighbours of the centre (1, 1, 1) of a 3x3x3 neighbourhood.
    six_neighbourhood = torch.tensor([[0, 1, 1],
                                      [1, 1, 0],
                                      [1, 0, 1],
                                      [1, 1, 2],
                                      [2, 1, 1],
                                      [1, 2, 1]], dtype=torch.float, device=device)

    # squared distances between every pair of neighbours, shape (6, 6)
    dist = pdist(six_neighbourhood.unsqueeze(0)).squeeze(0)

    # Comparison mask: keep each unordered pair once (row > column) and only
    # pairs at squared distance 2, i.e. non-opposite neighbours -> 12 pairs.
    # The broadcast comparison equals torch.meshgrid(..., indexing="ij") with
    # x > y, without meshgrid's missing-`indexing` deprecation warning.
    ids = torch.arange(6, device=device)
    mask = ((ids.unsqueeze(1) > ids.unsqueeze(0)) & (dist == 2)).view(-1)

    # Kernel positions of the first and second member of each pair, (12, 3).
    idx_shift1 = six_neighbourhood.unsqueeze(1).repeat(1, 6, 1).view(-1, 3)[mask, :].long()
    idx_shift2 = six_neighbourhood.unsqueeze(0).repeat(6, 1, 1).view(-1, 3)[mask, :].long()

    # One-hot 3x3x3 kernels: output channel k samples the voxel at
    # idx_shiftX[k] (scaled by the dilation).
    channels = torch.arange(12, device=device)
    mshift1 = torch.zeros((12, 1, 3, 3, 3), dtype=dtype, device=device)
    mshift1[channels, 0, idx_shift1[:, 0], idx_shift1[:, 1], idx_shift1[:, 2]] = 1
    mshift2 = torch.zeros((12, 1, 3, 3, 3), dtype=dtype, device=device)
    mshift2[channels, 0, idx_shift2[:, 0], idx_shift2[:, 1], idx_shift2[:, 2]] = 1

    # Pad once and reuse the padded volume for both convolutions.
    padded = nn.ReplicationPad3d(delta)(img)

    # compute patch-ssd
    diff = F.conv3d(padded, mshift1, dilation=delta) - F.conv3d(padded, mshift2, dilation=delta)
    ssd = smooth(diff ** 2, sigma)

    # MIND equation
    mind = ssd - torch.min(ssd, 1, keepdim=True)[0]  # >= 0 by construction
    mind_var = torch.mean(mind, 1, keepdim=True)

    # NOTE(review): .mean() reduces over every dimension, including the batch
    # dimension, so with N > 1 the normalisation couples the samples.
    var_mean = mind_var.mean()
    # Degenerate input (e.g. a constant image) makes mind == 0 everywhere and
    # var_mean == 0; the clamp range would collapse to [0, 0] and 0 / 0 would
    # produce NaN. Substituting 1 yields exp(0) == 1 instead, and is a no-op
    # whenever var_mean > 0 (torch.where avoids a GPU->CPU sync).
    var_mean = torch.where(var_mean > 0, var_mean, torch.ones_like(var_mean))
    mind_var = torch.clamp(mind_var, var_mean * 0.001, var_mean * 1000)
    mind = torch.exp(-(mind / mind_var))

    # permute to have same ordering as C++ code
    order = torch.tensor([6, 8, 1, 11, 2, 10, 0, 7, 9, 4, 5, 3], dtype=torch.long, device=device)
    return mind[:, order, :, :, :]