# Get Python six functionality:
from __future__ import\
absolute_import, print_function, division, unicode_literals
from builtins import zip
import six
###############################################################################
###############################################################################
###############################################################################
import keras.backend as K
import keras.layers
import keras.models
import numpy as np
import warnings
from .. import layers as ilayers
from .. import utils as iutils
from ..utils.keras import checks as kchecks
from ..utils.keras import graph as kgraph
__all__ = [
"NotAnalyzeableModelException",
"AnalyzerBase",
"TrainerMixin",
"OneEpochTrainerMixin",
"AnalyzerNetworkBase",
"ReverseAnalyzerBase"
]
###############################################################################
###############################################################################
###############################################################################
class NotAnalyzeableModelException(Exception):
"""Indicates that the model cannot be analyzed by an analyzer."""
pass
class AnalyzerBase(object):
""" The basic interface of an iNNvestigate analyzer.
This class defines the basic interface for analyzers:
>>> model = create_keras_model()
>>> a = Analyzer(model)
>>> a.fit(X_train) # If analyzer needs training.
>>> analysis = a.analyze(X_test)
>>>
>>> state = a.save()
    >>> a_new = Analyzer.load(*state)
>>> analysis = a_new.analyze(X_test)
:param model: A Keras model.
:param disable_model_checks: Do not execute model checks that enforce
compatibility of analyzer and model.
.. note:: To develop a new analyzer derive from
:class:`AnalyzerNetworkBase`.
"""
def __init__(self, model, disable_model_checks=False):
self._model = model
self._disable_model_checks = disable_model_checks
self._do_model_checks()
def _add_model_check(self, check, message, check_type="exception"):
if getattr(self, "_model_check_done", False):
raise Exception("Cannot add model check anymore."
" Check was already performed.")
if not hasattr(self, "_model_checks"):
self._model_checks = []
check_instance = {
"check": check,
"message": message,
"type": check_type,
}
self._model_checks.append(check_instance)
def _do_model_checks(self):
model_checks = getattr(self, "_model_checks", [])
if not self._disable_model_checks and len(model_checks) > 0:
check = [x["check"] for x in model_checks]
types = [x["type"] for x in model_checks]
messages = [x["message"] for x in model_checks]
checked = kgraph.model_contains(self._model, check)
tmp = zip(iutils.to_list(checked), messages, types)
for checked_layers, message, check_type in tmp:
if len(checked_layers) > 0:
                    tmp_message = ("%s\nCheck triggered by layers: %s" %
                                   (message, checked_layers))
if check_type == "exception":
raise NotAnalyzeableModelException(tmp_message)
elif check_type == "warning":
# TODO(albermax) only the first warning will be shown
warnings.warn(tmp_message)
else:
raise NotImplementedError()
self._model_check_done = True
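    # Usage sketch (hypothetical subclass, not part of this module): model
    # checks must be registered *before* calling super().__init__, because
    # _do_model_checks runs inside AnalyzerBase.__init__.
    #
    #   class MyAnalyzer(AnalyzerBase):
    #       def __init__(self, model, **kwargs):
    #           self._add_model_check(
    #               lambda layer: isinstance(layer, keras.layers.Dropout),
    #               "This hypothetical analyzer ignores Dropout layers.",
    #               check_type="warning",
    #           )
    #           super(MyAnalyzer, self).__init__(model, **kwargs)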
def fit(self, *args, **kwargs):
"""
        Stub that accepts and ignores all arguments. If an analyzer needs
        training, include :class:`TrainerMixin`.
        :param disable_no_training_warning: Do not warn if this function is
          called although no training is needed.
"""
disable_no_training_warning = kwargs.pop("disable_no_training_warning",
False)
if not disable_no_training_warning:
            # Issue a warning if no training is needed,
            # but fit() is still called.
warnings.warn("This analyzer does not need to be trained."
" Still fit() is called.", RuntimeWarning)
def fit_generator(self, *args, **kwargs):
"""
        Stub that accepts and ignores all arguments. If an analyzer needs
        training, include :class:`TrainerMixin`.
        :param disable_no_training_warning: Do not warn if this function is
          called although no training is needed.
"""
disable_no_training_warning = kwargs.pop("disable_no_training_warning",
False)
if not disable_no_training_warning:
            # Issue a warning if no training is needed,
            # but fit_generator() is still called.
warnings.warn("This analyzer does not need to be trained."
" Still fit_generator() is called.", RuntimeWarning)
def analyze(self, X):
"""
        Analyze the behavior of the model on the input `X`.
        :param X: Input as expected by the model.
"""
raise NotImplementedError()
def _get_state(self):
state = {
"model_json": self._model.to_json(),
"model_weights": self._model.get_weights(),
"disable_model_checks": self._disable_model_checks,
}
return state
def save(self):
"""
        Save the state of the analyzer; it can be passed to
        :func:`Analyzer.load` to restore the analyzer.
:return: The class name and the state.
"""
state = self._get_state()
class_name = self.__class__.__name__
return class_name, state
def save_npz(self, fname):
"""
        Save the state of the analyzer to a file; the file name can be
        passed to :func:`Analyzer.load_npz` to restore the analyzer.
:param fname: The file's name.
"""
class_name, state = self.save()
np.savez(fname, **{"class_name": class_name,
"state": state})
@classmethod
def _state_to_kwargs(clazz, state):
model_json = state.pop("model_json")
model_weights = state.pop("model_weights")
disable_model_checks = state.pop("disable_model_checks")
assert len(state) == 0
model = keras.models.model_from_json(model_json)
model.set_weights(model_weights)
return {"model": model,
"disable_model_checks": disable_model_checks}
@staticmethod
def load(class_name, state):
"""
        Restores an analyzer from the state created by
        :func:`analyzer.save()`.
:param class_name: The analyzer's class name.
:param state: The analyzer's state.
"""
        # TODO: do this in a smarter way!
import innvestigate.analyzer
clazz = getattr(innvestigate.analyzer, class_name)
kwargs = clazz._state_to_kwargs(state)
return clazz(**kwargs)
@staticmethod
def load_npz(fname):
"""
        Restores an analyzer from the file created by
        :func:`analyzer.save_npz()`.
:param fname: The file's name.
"""
f = np.load(fname)
class_name = f["class_name"].item()
state = f["state"].item()
return AnalyzerBase.load(class_name, state)
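    # Usage sketch (hypothetical analyzer and file name): persisting and
    # restoring an analyzer with the helpers above.
    #
    #   analyzer = SomeAnalyzer(model)        # any concrete AnalyzerBase subclass
    #   class_name, state = analyzer.save()
    #   restored = AnalyzerBase.load(class_name, state)
    #
    #   analyzer.save_npz("analyzer_state.npz")
    #   restored = AnalyzerBase.load_npz("analyzer_state.npz")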
###############################################################################
###############################################################################
###############################################################################
class TrainerMixin(object):
"""Mixin for analyzer that adapt to data.
This convenience interface exposes a Keras like training routing
to the user.
"""
# todo: extend with Y
def fit(self,
X=None,
batch_size=32,
**kwargs):
"""
Takes the same parameters as Keras's :func:`model.fit` function.
"""
generator = iutils.BatchSequence(X, batch_size)
return self._fit_generator(generator,
**kwargs)
def fit_generator(self, *args, **kwargs):
"""
Takes the same parameters as Keras's :func:`model.fit_generator`
function.
"""
return self._fit_generator(*args, **kwargs)
def _fit_generator(self,
generator,
steps_per_epoch=None,
epochs=1,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
verbose=0,
disable_no_training_warning=None):
raise NotImplementedError()
class OneEpochTrainerMixin(TrainerMixin):
"""Exposes the same interface and functionality as :class:`TrainerMixin`
except that the training is limited to one epoch.
"""
def fit(self, *args, **kwargs):
"""
        Same interface as :func:`fit` of :class:`TrainerMixin` except that
        the parameter `epochs` is fixed to 1.
"""
return super(OneEpochTrainerMixin, self).fit(*args, epochs=1, **kwargs)
def fit_generator(self, *args, **kwargs):
"""
        Same interface as :func:`fit_generator` of :class:`TrainerMixin`
        except that the parameter `epochs` is fixed to 1.
"""
steps = kwargs.pop("steps", None)
return super(OneEpochTrainerMixin, self).fit_generator(
*args,
steps_per_epoch=steps,
epochs=1,
**kwargs)
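# Usage sketch (hypothetical class, not part of this module): an analyzer that
# adapts to data combines a trainer mixin with a network base class and
# implements :func:`_fit_generator`.
#
#   class MyTrainedAnalyzer(OneEpochTrainerMixin, AnalyzerNetworkBase):
#       def _fit_generator(self, generator, **kwargs):
#           ...  # adapt internal parameters to the batches yielded by generator
#
#   analyzer = MyTrainedAnalyzer(model)
#   analyzer.fit(X_train)  # wraps X_train in a BatchSequence, trains one epoch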
###############################################################################
###############################################################################
###############################################################################
class AnalyzerNetworkBase(AnalyzerBase):
"""Convenience interface for analyzers.
    This class provides helpful functionality to create analyzers.
Basically it:
* takes the input model and adds a layer that selects
the desired output neuron to analyze.
* passes the new model to :func:`_create_analysis` which should
return the analysis as Keras tensors.
* compiles the function and serves the output to :func:`analyze` calls.
    * allows :func:`_create_analysis` to return tensors
      that are intercepted for debugging purposes.
:param neuron_selection_mode: How to select the neuron to analyze.
      Possible values are 'max_activation', 'index' for the neuron
      (expects indices at :func:`analyze` calls), and 'all' to take all neurons.
:param allow_lambda_layers: Allow the model to contain lambda layers.
"""
def __init__(self, model,
neuron_selection_mode="max_activation",
allow_lambda_layers=False,
**kwargs):
if neuron_selection_mode not in ["max_activation", "index", "all"]:
            raise ValueError("neuron_selection_mode parameter is not valid.")
self._neuron_selection_mode = neuron_selection_mode
self._allow_lambda_layers = allow_lambda_layers
self._add_model_check(
lambda layer: (not self._allow_lambda_layers and
isinstance(layer, keras.layers.core.Lambda)),
("Lamda layers are not allowed. "
"To force use set allow_lambda_layers parameter."),
check_type="exception",
)
self._special_helper_layers = []
super(AnalyzerNetworkBase, self).__init__(model, **kwargs)
def _add_model_softmax_check(self):
"""
        Adds a check that prevents models from containing a softmax.
"""
self._add_model_check(
lambda layer: kchecks.contains_activation(
layer, activation="softmax"),
"This analysis method does not support softmax layers.",
check_type="exception",
)
def _prepare_model(self, model):
"""
        Prepares the model for analysis before it is actually analyzed.
        This method adds the code to select a specific output neuron.
"""
neuron_selection_mode = self._neuron_selection_mode
model_inputs = model.inputs
model_output = model.outputs
if len(model_output) > 1:
raise ValueError("Only models with one output tensor are allowed.")
analysis_inputs = []
stop_analysis_at_tensors = []
# Flatten to form (batch_size, other_dimensions):
if K.ndim(model_output[0]) > 2:
model_output = keras.layers.Flatten()(model_output)
if neuron_selection_mode == "max_activation":
l = ilayers.Max(name="iNNvestigate_max")
model_output = l(model_output)
self._special_helper_layers.append(l)
elif neuron_selection_mode == "index":
neuron_indexing = keras.layers.Input(
batch_shape=[None, None], dtype=np.int32,
name='iNNvestigate_neuron_indexing')
self._special_helper_layers.append(
neuron_indexing._keras_history[0])
analysis_inputs.append(neuron_indexing)
# The indexing tensor should not be analyzed.
stop_analysis_at_tensors.append(neuron_indexing)
l = ilayers.GatherND(name="iNNvestigate_gather_nd")
model_output = l(model_output+[neuron_indexing])
self._special_helper_layers.append(l)
elif neuron_selection_mode == "all":
pass
else:
raise NotImplementedError()
model = keras.models.Model(inputs=model_inputs+analysis_inputs,
outputs=model_output)
return model, analysis_inputs, stop_analysis_at_tensors
def create_analyzer_model(self):
"""
        Creates the analyzer model. If not called beforehand,
        it will be called by :func:`analyze`.
"""
model_inputs = self._model.inputs
tmp = self._prepare_model(self._model)
model, analysis_inputs, stop_analysis_at_tensors = tmp
self._analysis_inputs = analysis_inputs
self._prepared_model = model
tmp = self._create_analysis(
model, stop_analysis_at_tensors=stop_analysis_at_tensors)
if isinstance(tmp, tuple):
if len(tmp) == 3:
analysis_outputs, debug_outputs, constant_inputs = tmp
elif len(tmp) == 2:
analysis_outputs, debug_outputs = tmp
constant_inputs = list()
elif len(tmp) == 1:
analysis_outputs = iutils.to_list(tmp[0])
constant_inputs, debug_outputs = list(), list()
else:
raise Exception("Unexpected output from _create_analysis.")
else:
analysis_outputs = tmp
constant_inputs, debug_outputs = list(), list()
analysis_outputs = iutils.to_list(analysis_outputs)
debug_outputs = iutils.to_list(debug_outputs)
constant_inputs = iutils.to_list(constant_inputs)
self._n_data_input = len(model_inputs)
self._n_constant_input = len(constant_inputs)
self._n_data_output = len(analysis_outputs)
self._n_debug_output = len(debug_outputs)
self._analyzer_model = keras.models.Model(
inputs=model_inputs+analysis_inputs+constant_inputs,
outputs=analysis_outputs+debug_outputs)
def _create_analysis(self, model, stop_analysis_at_tensors=[]):
"""
Interface that needs to be implemented by a derived class.
This function is expected to create a Keras graph that creates
a custom analysis for the model inputs given the model outputs.
:param model: Target of analysis.
        :param stop_analysis_at_tensors: A list of tensors at which to stop
          the analysis. Similar to the stop_gradient argument when computing
          the gradient of a graph.
        :return: Either a one-, two- or three-tuple of lists of tensors.
* The first list of tensors represents the analysis for each
model input tensor. Tensors present in stop_analysis_at_tensors
should be omitted.
* The second list, if present, is a list of debug tensors that will
be passed to :func:`_handle_debug_output` after the analysis
is executed.
* The third list, if present, is a list of constant input tensors
added to the analysis model.
"""
raise NotImplementedError()
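    # Sketch of a minimal override (hypothetical subclass that uses plain
    # Keras gradients instead of the graph-reversal machinery defined below):
    #
    #   class PlainGradient(AnalyzerNetworkBase):
    #       def _create_analysis(self, model, stop_analysis_at_tensors=[]):
    #           inputs = [x for x in model.inputs
    #                     if x not in stop_analysis_at_tensors]
    #           # One analysis tensor per remaining model input.
    #           return K.gradients(model.outputs[0], inputs)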
def _handle_debug_output(self, debug_values):
raise NotImplementedError()
def analyze(self, X, neuron_selection=None):
"""
        Same interface as :func:`AnalyzerBase.analyze`, with one additional parameter:
:param neuron_selection: If neuron_selection_mode is 'index' this
should be an integer with the index for the chosen neuron.
"""
if not hasattr(self, "_analyzer_model"):
self.create_analyzer_model()
X = iutils.to_list(X)
if(neuron_selection is not None and
self._neuron_selection_mode != "index"):
raise ValueError("Only neuron_selection_mode 'index' expects "
"the neuron_selection parameter.")
if(neuron_selection is None and
self._neuron_selection_mode == "index"):
raise ValueError("neuron_selection_mode 'index' expects "
"the neuron_selection parameter.")
if self._neuron_selection_mode == "index":
neuron_selection = np.asarray(neuron_selection).flatten()
if neuron_selection.size == 1:
neuron_selection = np.repeat(neuron_selection, len(X[0]))
# Add first axis indices for gather_nd
neuron_selection = np.hstack(
(np.arange(len(neuron_selection)).reshape((-1, 1)),
neuron_selection.reshape((-1, 1)))
)
ret = self._analyzer_model.predict_on_batch(X+[neuron_selection])
else:
ret = self._analyzer_model(X)#.predict_on_batch(X)
if self._n_debug_output > 0:
self._handle_debug_output(ret[-self._n_debug_output:])
ret = ret[:-self._n_debug_output]
if isinstance(ret, list) and len(ret) == 1:
ret = ret[0]
return ret
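    # Usage sketch (hypothetical analyzer subclass): analyzing with the
    # different neuron selection modes.
    #
    #   a = SomeAnalyzer(model)                      # default: "max_activation"
    #   analysis = a.analyze(x)
    #
    #   a = SomeAnalyzer(model, neuron_selection_mode="index")
    #   analysis = a.analyze(x, neuron_selection=7)  # explain output neuron 7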
def _get_state(self):
state = super(AnalyzerNetworkBase, self)._get_state()
state.update({"neuron_selection_mode": self._neuron_selection_mode})
state.update({"allow_lambda_layers": self._allow_lambda_layers})
return state
@classmethod
def _state_to_kwargs(clazz, state):
neuron_selection_mode = state.pop("neuron_selection_mode")
allow_lambda_layers = state.pop("allow_lambda_layers")
kwargs = super(AnalyzerNetworkBase, clazz)._state_to_kwargs(state)
kwargs.update({
"neuron_selection_mode": neuron_selection_mode,
"allow_lambda_layers": allow_lambda_layers
})
return kwargs
class ReverseAnalyzerBase(AnalyzerNetworkBase):
"""Convenience class for analyzers that revert the model's structure.
This class contains many helper functions around the graph
reverse function :func:`innvestigate.utils.keras.graph.reverse_model`.
The deriving classes should specify how the graph should be reverted
by implementing the following functions:
* :func:`_reverse_mapping(layer)` given a layer this function
returns a reverse mapping for the layer as specified in
:func:`innvestigate.utils.keras.graph.reverse_model` or None.
This function can be implemented, but it is encouraged to
implement a default mapping and add additional changes with
the function :func:`_add_conditional_reverse_mapping` (see below).
The default behavior is finding a conditional mapping (see below),
if none is found, :func:`_default_reverse_mapping` is applied.
* :func:`_default_reverse_mapping` defines the default
reverse mapping.
* :func:`_head_mapping` defines how the outputs of the model
      should be instantiated before they are passed to the reversed
      network.
    Furthermore, other parameters of the function
    :func:`innvestigate.utils.keras.graph.reverse_model` can
    be changed by setting the corresponding parameters of the
    init function:
:param reverse_verbose: Print information on the reverse process.
:param reverse_clip_values: Clip the values that are passed along
the reverted network. Expects tuple (min, max).
:param reverse_project_bottleneck_layers: Project the value range
of bottleneck tensors in the reverse network into another range.
:param reverse_check_min_max_values: Print the min/max values
observed in each tensor along the reverse network whenever
:func:`analyze` is called.
:param reverse_check_finite: Check if values passed along the
reverse network are finite.
:param reverse_keep_tensors: Keeps the tensors created in the
backward pass and stores them in the attribute
:attr:`_reversed_tensors`.
:param reverse_reapply_on_copied_layers: See
:func:`innvestigate.utils.keras.graph.reverse_model`.
"""
def __init__(self,
model,
reverse_verbose=False,
reverse_clip_values=False,
reverse_project_bottleneck_layers=False,
reverse_check_min_max_values=False,
reverse_check_finite=False,
reverse_keep_tensors=False,
reverse_reapply_on_copied_layers=False,
**kwargs):
self._reverse_verbose = reverse_verbose
self._reverse_clip_values = reverse_clip_values
self._reverse_project_bottleneck_layers = (
reverse_project_bottleneck_layers)
self._reverse_check_min_max_values = reverse_check_min_max_values
self._reverse_check_finite = reverse_check_finite
self._reverse_keep_tensors = reverse_keep_tensors
self._reverse_reapply_on_copied_layers = (
reverse_reapply_on_copied_layers)
super(ReverseAnalyzerBase, self).__init__(model, **kwargs)
def _gradient_reverse_mapping(self, Xs, Ys, reversed_Ys, reverse_state):
mask = [x not in reverse_state["stop_mapping_at_tensors"] for x in Xs]
return ilayers.GradientWRT(len(Xs), mask=mask)(Xs+Ys+reversed_Ys)
def _reverse_mapping(self, layer):
"""
This function should return a reverse mapping for the passed layer.
If this function returns None, :func:`_default_reverse_mapping`
is applied.
:param layer: The layer for which a mapping should be returned.
:return: The mapping can be of the following forms:
* A function of form (A) f(Xs, Ys, reversed_Ys, reverse_state)
that maps reversed_Ys to reversed_Xs (which should contain
tensors of the same shape and type).
          * A function of form (B) f(layer, reverse_state) that returns
a function of form (A).
* A :class:`ReverseMappingBase` subclass.
"""
if layer in self._special_helper_layers:
# Special layers added by AnalyzerNetworkBase
# that should not be exposed to user.
return self._gradient_reverse_mapping
return self._apply_conditional_reverse_mappings(layer)
def _add_conditional_reverse_mapping(
self, condition, mapping, priority=-1, name=None):
"""
        Adds a conditional reverse mapping. The mapping is used for a layer
        if `condition` evaluates to True for that layer and no mapping with
        a higher priority matches first.
:param condition: Condition when this mapping should be applied.
Form: f(layer) -> bool
:param mapping: The mapping can be of the following forms:
* A function of form (A) f(Xs, Ys, reversed_Ys, reverse_state)
that maps reversed_Ys to reversed_Xs (which should contain
tensors of the same shape and type).
          * A function of form (B) f(layer, reverse_state) that returns
a function of form (A).
* A :class:`ReverseMappingBase` subclass.
        :param priority: The higher the priority, the earlier the
          condition gets evaluated.
:param name: An identifying name.
"""
if getattr(self, "_reverse_mapping_applied", False):
raise Exception("Cannot add conditional mapping "
"after first application.")
if not hasattr(self, "_conditional_reverse_mappings"):
self._conditional_reverse_mappings = {}
if priority not in self._conditional_reverse_mappings:
self._conditional_reverse_mappings[priority] = []
tmp = {"condition": condition, "mapping": mapping, "name": name}
self._conditional_reverse_mappings[priority].append(tmp)
def _apply_conditional_reverse_mappings(self, layer):
mappings = getattr(self, "_conditional_reverse_mappings", {})
self._reverse_mapping_applied = True
        # Search for a mapping: consider the highest priorities first and,
        # within the same priority, mappings in the order they were added.
sorted_keys = sorted(mappings.keys())[::-1]
for key in sorted_keys:
for mapping in mappings[key]:
if mapping["condition"](layer):
return mapping["mapping"]
return None
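    # Usage sketch (hypothetical mapping function): subclasses typically
    # register conditional mappings in their constructor; higher priorities
    # are checked first, and the gradient-based default mapping below acts
    # as the fallback.
    #
    #   self._add_conditional_reverse_mapping(
    #       condition=lambda layer: isinstance(layer, keras.layers.Dense),
    #       mapping=my_dense_mapping,   # form (A), form (B) or ReverseMappingBase
    #       priority=10,
    #       name="dense_mapping",
    #   )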
def _default_reverse_mapping(self, Xs, Ys, reversed_Ys, reverse_state):
"""
Fallback function to map reversed_Ys to reversed_Xs
(which should contain tensors of the same shape and type).
"""
return self._gradient_reverse_mapping(
Xs, Ys, reversed_Ys, reverse_state)
def _head_mapping(self, X):
"""
Map output tensors to new values before passing
them into the reverted network.
"""
return X
def _postprocess_analysis(self, X):
return X
def _reverse_model(self,
model,
stop_analysis_at_tensors=[],
return_all_reversed_tensors=False):
return kgraph.reverse_model(
model,
reverse_mappings=self._reverse_mapping,
default_reverse_mapping=self._default_reverse_mapping,
head_mapping=self._head_mapping,
stop_mapping_at_tensors=stop_analysis_at_tensors,
verbose=self._reverse_verbose,
clip_all_reversed_tensors=self._reverse_clip_values,
project_bottleneck_tensors=self._reverse_project_bottleneck_layers,
return_all_reversed_tensors=return_all_reversed_tensors)
def _create_analysis(self, model, stop_analysis_at_tensors=[]):
return_all_reversed_tensors = (
self._reverse_check_min_max_values or
self._reverse_check_finite or
self._reverse_keep_tensors
)
ret = self._reverse_model(
model,
stop_analysis_at_tensors=stop_analysis_at_tensors,
return_all_reversed_tensors=return_all_reversed_tensors)
if return_all_reversed_tensors:
ret = (self._postprocess_analysis(ret[0]), ret[1])
else:
ret = self._postprocess_analysis(ret)
if return_all_reversed_tensors:
debug_tensors = []
self._debug_tensors_indices = {}
values = list(six.itervalues(ret[1]))
mapping = {i: v["id"] for i, v in enumerate(values)}
tensors = [v["final_tensor"] for v in values]
self._reverse_tensors_mapping = mapping
if self._reverse_check_min_max_values:
tmp = [ilayers.Min(None)(x) for x in tensors]
self._debug_tensors_indices["min"] = (
len(debug_tensors),
len(debug_tensors)+len(tmp))
debug_tensors += tmp
tmp = [ilayers.Max(None)(x) for x in tensors]
self._debug_tensors_indices["max"] = (
len(debug_tensors),
len(debug_tensors)+len(tmp))
debug_tensors += tmp
if self._reverse_check_finite:
tmp = iutils.to_list(ilayers.FiniteCheck()(tensors))
self._debug_tensors_indices["finite"] = (
len(debug_tensors),
len(debug_tensors)+len(tmp))
debug_tensors += tmp
if self._reverse_keep_tensors:
self._debug_tensors_indices["keep"] = (
len(debug_tensors),
len(debug_tensors)+len(tensors))
debug_tensors += tensors
ret = (ret[0], debug_tensors)
return ret
def _handle_debug_output(self, debug_values):
if self._reverse_check_min_max_values:
indices = self._debug_tensors_indices["min"]
tmp = debug_values[indices[0]:indices[1]]
tmp = sorted([(self._reverse_tensors_mapping[i], v)
for i, v in enumerate(tmp)])
print("Minimum values in tensors: "
"((NodeID, TensorID), Value) - {}".format(tmp))
indices = self._debug_tensors_indices["max"]
tmp = debug_values[indices[0]:indices[1]]
tmp = sorted([(self._reverse_tensors_mapping[i], v)
for i, v in enumerate(tmp)])
print("Maximum values in tensors: "
"((NodeID, TensorID), Value) - {}".format(tmp))
if self._reverse_check_finite:
indices = self._debug_tensors_indices["finite"]
tmp = debug_values[indices[0]:indices[1]]
nfinite_tensors = np.flatnonzero(np.asarray(tmp) > 0)
if len(nfinite_tensors) > 0:
nfinite_tensors = sorted([self._reverse_tensors_mapping[i]
for i in nfinite_tensors])
print("Not finite values found in following nodes: "
"(NodeID, TensorID) - {}".format(nfinite_tensors))
if self._reverse_keep_tensors:
indices = self._debug_tensors_indices["keep"]
tmp = debug_values[indices[0]:indices[1]]
tmp = sorted([(self._reverse_tensors_mapping[i], v)
for i, v in enumerate(tmp)])
self._reversed_tensors = tmp
def _get_state(self):
state = super(ReverseAnalyzerBase, self)._get_state()
state.update({"reverse_verbose": self._reverse_verbose})
state.update({"reverse_clip_values": self._reverse_clip_values})
state.update({"reverse_project_bottleneck_layers":
self._reverse_project_bottleneck_layers})
state.update({"reverse_check_min_max_values":
self._reverse_check_min_max_values})
state.update({"reverse_check_finite": self._reverse_check_finite})
state.update({"reverse_keep_tensors": self._reverse_keep_tensors})
state.update({"reverse_reapply_on_copied_layers":
self._reverse_reapply_on_copied_layers})
return state
@classmethod
def _state_to_kwargs(clazz, state):
reverse_verbose = state.pop("reverse_verbose")
reverse_clip_values = state.pop("reverse_clip_values")
reverse_project_bottleneck_layers = (
state.pop("reverse_project_bottleneck_layers"))
reverse_check_min_max_values = (
state.pop("reverse_check_min_max_values"))
reverse_check_finite = state.pop("reverse_check_finite")
reverse_keep_tensors = state.pop("reverse_keep_tensors")
reverse_reapply_on_copied_layers = (
state.pop("reverse_reapply_on_copied_layers"))
kwargs = super(ReverseAnalyzerBase, clazz)._state_to_kwargs(state)
kwargs.update({"reverse_verbose": reverse_verbose,
"reverse_clip_values": reverse_clip_values,
"reverse_project_bottleneck_layers":
reverse_project_bottleneck_layers,
"reverse_check_min_max_values":
reverse_check_min_max_values,
"reverse_check_finite": reverse_check_finite,
"reverse_keep_tensors": reverse_keep_tensors,
"reverse_reapply_on_copied_layers":
reverse_reapply_on_copied_layers})
return kwargs