forked from baidu-research/NCRF
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlayers.py
59 lines (50 loc) · 2.27 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
from torch import nn
class CRF(nn.Module):
    """Fully connected CRF refined with a fixed number of mean-field steps.

    The pairwise potential between two nodes is the cosine similarity of
    their feature vectors scaled by a learned, symmetrised weight matrix.
    """

    def __init__(self, num_nodes, iteration=10, eps=1e-8):
        """Initialize the CRF module

        Args:
            num_nodes: int, number of nodes/patches within the fully CRF
            iteration: int, number of mean field iterations, e.g. 10
            eps: float, lower bound applied to the product of feature norms
                before dividing, so that an all-zero feature vector yields
                a similarity of 0 instead of NaN
        """
        super(CRF, self).__init__()
        self.num_nodes = num_nodes
        self.iteration = iteration
        self.eps = eps
        # Leading dimension of 1 lets W broadcast over the batch dimension.
        self.W = nn.Parameter(torch.zeros(1, num_nodes, num_nodes))

    def forward(self, feats, logits):
        """Performing the CRF.

        Args:
            feats: 3D tensor with the shape of
                [batch_size, num_nodes, embedding_size], where num_nodes is
                the number of patches within a grid, e.g. 9 for a 3x3 grid;
                embedding_size is the size of extracted feature
                representation for each patch from ResNet, e.g. 512
            logits: 3D tensor with shape of [batch_size, num_nodes, 1], the
                logit of each patch within the grid being tumor before CRF

        Returns:
            logits: 3D tensor with shape of [batch_size, num_nodes, 1], the
                logit of each patch within the grid being tumor after CRF
        """
        # [B, N, 1] L2 norm per node, then [B, N, N] product of norm pairs.
        feats_norm = torch.norm(feats, p=2, dim=2, keepdim=True)
        pairwise_norm = torch.bmm(feats_norm,
                                  torch.transpose(feats_norm, 1, 2))
        # A zero-norm feature vector would make the division below 0 / 0
        # (NaN) and poison every output logit; bound the denominator.
        pairwise_norm = pairwise_norm.clamp(min=self.eps)
        pairwise_dot = torch.bmm(feats, torch.transpose(feats, 1, 2))
        # cosine similarity between feats
        pairwise_sim = pairwise_dot / pairwise_norm
        # symmetric constraint for CRF weights
        W_sym = (self.W + torch.transpose(self.W, 1, 2)) / 2
        pairwise_potential = pairwise_sim * W_sym
        # The unary term stays fixed; only the pairwise expectation is
        # recomputed on each iteration.
        unary_potential = logits.clone()

        for _ in range(self.iteration):
            # current Q, laid out as [B, 1, N] so it broadcasts over rows
            probs = torch.transpose(logits.sigmoid(), 1, 2)
            # taking expectation of pairwise_potential using current Q:
            # neighbour j adds +potential weighted by probs[j] and
            # -potential weighted by (1 - probs[j])
            pairwise_potential_E = torch.sum(
                probs * pairwise_potential - (1 - probs) * pairwise_potential,
                dim=2, keepdim=True)
            logits = unary_potential + pairwise_potential_E

        return logits

    def __repr__(self):
        """Return a short description listing num_nodes and iteration."""
        return 'CRF(num_nodes={}, iteration={})'.format(
            self.num_nodes, self.iteration
        )