forked from akarasman/yolo-heatmaps
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinverter_util.py
195 lines (131 loc) · 4.76 KB
/
inverter_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
from numpy import iterable
import torch
import torch.nn.functional as F
from lrp.utils import pprint, flexible_prop
def winner_takes_all(relevance_in : torch.Tensor, in_shape : iterable, indices : torch.Tensor ) -> torch.Tensor :
    """
    Implements winner-takes-all scheme for re-distribution of relevance
    for a max pooling layer: every pooled output passes its whole relevance
    to the input position that was selected as the maximum.

    Arguments
    ---------
    relevance_in : torch.Tensor
        Incoming relevance from upper layers, 4-dimensional (B, C, H_out, W_out).

    in_shape : list or tuple
        Shape of module input; its last two entries are the input plane
        height and width.

    indices : torch.Tensor
        Indexes of selected (max) features, as returned by max pooling with
        ``return_indices=True`` (positions inside each flattened input plane).
        They must describe one relevance sample (C * H_out * W_out entries);
        the same indices are reused for every sample in ``relevance_in``.

    Returns
    -------
    relevance_out : torch.Tensor
        Relevance redistributed to lower layer. One block of shape
        ``in_shape`` per sample, concatenated along dimension 0. Allocated on
        the device and with the dtype of ``relevance_in``.

    Raises
    ------
    ValueError
        If the number of indices does not match the size of one relevance sample.
    """

    in_shape = tuple(in_shape)
    batch_size, channels, out_h, out_w = relevance_in.size()

    out_plane = out_h * out_w
    in_plane = in_shape[-2] * in_shape[-1]
    in_numel = torch.Size(in_shape).numel()
    per_sample = channels * out_plane

    flat_indices = indices.flatten().to(device=relevance_in.device, dtype=torch.long)
    if flat_indices.numel() != per_sample:
        raise ValueError(
            "indices must hold one entry per element of a relevance sample "
            f"(expected {per_sample}, got {flat_indices.numel()})"
        )

    # Pooling indices are relative to their own input plane, so each plane is
    # shifted by the size of an *input* plane (not of the pooled plane) to get
    # a position inside the flattened input.
    plane_starts = torch.arange(channels, device=relevance_in.device).unsqueeze(1) * in_plane
    targets = (flat_indices.reshape(channels, out_plane) + plane_starts).flatten()

    # index_add_ accumulates when several pooling windows selected the same
    # input position (overlapping windows), exactly like the former `+=` loop.
    relevance_out = relevance_in.new_zeros((batch_size, in_numel))
    relevance_out.index_add_(1, targets, relevance_in.reshape(batch_size, per_sample))

    return relevance_out.view(batch_size * in_shape[0], *in_shape[1:])
def conv_nd_fwd_hook(m, in_tensor, out_tensor):
    """ Forward hook for n-dimensional convolutions.

    Keeps the first element of the hook's input and the hook's output on
    the module as ``m.in_tensor`` and ``m.out_tensor``.
    """

    first_input = in_tensor[0]
    m.in_tensor = first_input
    m.out_tensor = out_tensor
def max_pool_nd_fwd_hook(m, in_tensor, out_tensor):
    """ Default n-dimensional max pool forward hook

    Stores on the module the values read later when relevance is propagated
    back through the layer:

    * ``m.indices``   : indices of the selected (max) features,
    * ``m.out_shape`` : size of the pooled output,
    * ``m.in_shape``  : size of the first hook input.

    NOTE: when the indices have to be recomputed this is done with
    ``F.max_pool2d``, so despite the name only 2-d pooling is covered there.
    """

    if m.return_indices:
        # The module already returned (output, indices): reuse them instead of
        # pooling a second time, and take the size from the output tensor
        # rather than from the tuple.
        pooled, indices = out_tensor
    else:
        pooled = out_tensor
        _, indices = F.max_pool2d(in_tensor[0], kernel_size=m.kernel_size, stride=m.stride, padding=m.padding,
                                  dilation=m.dilation, return_indices=True, ceil_mode=m.ceil_mode)

    setattr(m, "indices", indices)
    setattr(m, 'out_shape', pooled.size())
    setattr(m, 'in_shape', in_tensor[0].size())
def upsample_fwd_hook(m, in_tensor, out_tensor):
    """ Forward hook for up-sampling layers.

    Records the number of dimensions of the first hook input as ``m.in_dim``
    and the shape of the output as ``m.out_shape``.
    """

    source = in_tensor[0]
    m.in_dim = len(source.shape)
    m.out_shape = out_tensor.shape
def linear_fwd_hook(m, in_tensor, out_tensor):
    """ Forward hook for Linear layers.

    Keeps the first element of the hook's input as ``m.in_tensor`` and the
    output size, converted to a plain list, as ``m.out_shape``.
    """

    m.in_tensor = in_tensor[0]
    m.out_shape = [*out_tensor.size()]
def silent_pass(m, in_tensor, out_tensor):
    """ No-op forward hook: accepts the hook arguments and stores nothing """

    return None
def LogSoftmax_inverse(relevance : torch.Tensor, warn : bool = True) -> torch.Tensor :
    """
    Inversion of LogSoftmax layer

    When the relevance sums to a negative value it is treated as
    log-probabilities: exact zeros are replaced by -1e6 and the result is
    exponentiated. Otherwise the relevance is returned as is.

    The input tensor is never modified; the conversion works on a copy.

    Arguments
    ---------
    relevance : torch.Tensor
        Input relevance

    warn : bool
        Display warning message when the conversion is applied

    Returns
    -------
    torch.Tensor
        Output relevance (the very same tensor when no conversion is needed)
    """

    if relevance.sum() < 0:
        # Out-of-place fill: an indexed assignment here would silently alter
        # the caller's tensor (and fails on leaf tensors that require grad).
        relevance = relevance.masked_fill(relevance == 0, -1e6).exp()

        if warn :
            pprint("WARNING: LogSoftmax layer was "
                   "turned into probabilities.")

    return relevance
@flexible_prop
def max_pool_nd_inverse(layer, relevance_in : torch.Tensor, indices : torch.Tensor = None,
                        max : bool = False) -> torch.Tensor :
    """
    Inversion of max pooling layer

    Arguments
    ---------
    layer
        Pooling module carrying ``out_shape``, ``in_shape`` and ``indices``
        attributes (presumably set by ``max_pool_nd_fwd_hook`` - confirm
        against the registration code).

    relevance_in : torch.Tensor
        Input relevance; every entry along dimension 0 is viewed as
        ``layer.out_shape``.

    indices : torch.Tensor
        Maximum feature indexes obtained when max pooling. Defaults to
        ``layer.indices`` when not given.

    max : bool
        Implement winner takes all scheme in relevance re-distribution.
        (The name shadows the builtin but is kept because it is part of the
        call interface.)

    Returns
    -------
    torch.Tensor
        Output relevance: redistributed through ``winner_takes_all`` when
        ``max`` is set, otherwise the reshaped incoming relevance unchanged.
    """

    if indices is None :
        indices = layer.indices

    relevance_in = torch.cat([r.view(layer.out_shape) for r in relevance_in], dim=0)

    if not max :
        return relevance_in

    # Forward the resolved indices so an explicit `indices` argument is
    # honoured. They are passed un-tiled: winner_takes_all applies one set of
    # indices to every relevance sample itself.
    return winner_takes_all(relevance_in, layer.in_shape, indices)
@flexible_prop
def upsample_inverse(layer, relevance : torch.Tensor) -> torch.Tensor :
    """
    Inversion of upsample layer

    The relevance is average-pooled with a kernel (and stride) equal to the
    up-sampling factor and then multiplied by the number of elements in one
    pooling window, which turns the window average into a window sum.

    Arguments
    ---------
    layer
        Up-sampling module carrying ``in_dim`` and ``out_shape`` attributes
        next to its own ``mode`` and ``scale_factor``.

    relevance : torch.Tensor
        Input relevance; every entry along dimension 0 is viewed as
        ``layer.out_shape``.

    Returns
    -------
    torch.Tensor
        Output relevance

    Raises
    ------
    KeyError
        If the layer input is not 3-, 4- or 5-dimensional.
    NotImplementedError
        If the layer is not in 'nearest' mode, or its ``scale_factor`` is
        neither a number nor a tuple/list of numbers.

    ATTENTION : Currently only 'nearest' upsampling method is invertible.
    Scale factors are truncated to integers.
    """

    n_spatial = layer.in_dim - 2
    invert_upsample = {
        1 : F.avg_pool1d,
        2 : F.avg_pool2d,
        3 : F.avg_pool3d
    } [n_spatial]

    if layer.mode != 'nearest' :
        raise NotImplementedError("Upsample layer must be in 'nearest' mode ")

    relevance_in = torch.cat([r.view(layer.out_shape) for r in relevance], dim=0)

    scale = layer.scale_factor
    if isinstance(scale, (int, float)):
        ks = int(scale)
        window_sizes = (ks,) * n_spatial
    elif isinstance(scale, (tuple, list)):
        ks = tuple(int(s) for s in scale)
        window_sizes = ks
    else:
        raise NotImplementedError("Upsample layer must define scale_factor "
                                  "as a number or a tuple of numbers")

    inverted = invert_upsample(relevance_in, kernel_size=ks, stride=ks)

    # Normalizing constant: number of elements averaged per window. This is
    # ks**2 for a scalar factor on 2-d data and stays correct for 1-d / 3-d
    # data and for per-dimension (tuple) factors.
    window_numel = 1
    for size in window_sizes:
        window_numel *= size
    inverted *= window_numel

    return inverted