diff --git a/keras/engine/training.py b/keras/engine/training.py index 78be752b009..e78477e4341 100644 --- a/keras/engine/training.py +++ b/keras/engine/training.py @@ -853,6 +853,7 @@ def compile(self, optimizer, loss=None, metrics=None, loss_weights=None, nested_weighted_metrics = _collect_metrics(weighted_metrics, self.output_names) self.metrics_updates = [] self.stateful_metric_names = [] + self.stateful_metric_functions = [] with K.name_scope('metrics'): for i in range(len(self.outputs)): if i in skip_target_indices: @@ -929,6 +930,7 @@ def handle_metrics(metrics, weights=None): # stateful metrics (i.e. metrics layers). if isinstance(metric_fn, Layer) and metric_fn.stateful: self.stateful_metric_names.append(metric_name) + self.stateful_metric_functions.append(metric_fn) self.metrics_updates += metric_fn.updates handle_metrics(output_metrics) @@ -1174,9 +1176,8 @@ def _fit_loop(self, f, ins, out_labels=None, batch_size=None, for epoch in range(initial_epoch, epochs): # Reset stateful metrics - for m in self.metrics: - if isinstance(m, Layer) and m.stateful: - m.reset_states() + for m in self.stateful_metric_functions: + m.reset_states() callbacks.on_epoch_begin(epoch) epoch_logs = {} if steps_per_epoch is not None: @@ -1363,9 +1364,8 @@ def _test_loop(self, f, ins, batch_size=None, verbose=0, steps=None): """ if hasattr(self, 'metrics'): - for m in self.metrics: - if isinstance(m, Layer) and m.stateful: - m.reset_states() + for m in self.stateful_metric_functions: + m.reset_states() stateful_metric_indices = [ i for i, name in enumerate(self.metrics_names) if str(name) in self.stateful_metric_names] @@ -2185,9 +2185,8 @@ def generate_arrays_from_file(path): # Construct epoch logs. epoch_logs = {} while epoch < epochs: - for m in self.metrics: - if isinstance(m, Layer) and m.stateful: - m.reset_states() + for m in self.stateful_metric_functions: + m.reset_states() callbacks.on_epoch_begin(epoch) steps_done = 0 batch_index = 0 @@ -2331,9 +2330,8 @@ def evaluate_generator(self, generator, steps=None, stateful_metric_indices = [] if hasattr(self, 'metrics'): - for i, m in enumerate(self.metrics): - if isinstance(m, Layer) and m.stateful: - m.reset_states() + for m in self.stateful_metric_functions: + m.reset_states() stateful_metric_indices = [ i for i, name in enumerate(self.metrics_names) if str(name) in self.stateful_metric_names] diff --git a/tests/keras/metrics_test.py b/tests/keras/metrics_test.py index 2e40daeda3d..39ad1cdbc5a 100644 --- a/tests/keras/metrics_test.py +++ b/tests/keras/metrics_test.py @@ -107,7 +107,8 @@ def test_sparse_top_k_categorical_accuracy(): @keras_test -def test_stateful_metrics(): +@pytest.mark.parametrize('metrics_mode', ['list', 'dict']) +def test_stateful_metrics(metrics_mode): np.random.seed(1334) class BinaryTruePositives(keras.layers.Layer): @@ -155,11 +156,17 @@ def __call__(self, y_true, y_pred): # Test on simple model inputs = keras.Input(shape=(2,)) - outputs = keras.layers.Dense(1, activation='sigmoid')(inputs) + outputs = keras.layers.Dense(1, activation='sigmoid', name='out')(inputs) model = keras.Model(inputs, outputs) - model.compile(optimizer='sgd', - loss='binary_crossentropy', - metrics=['acc', metric_fn]) + + if metrics_mode == 'list': + model.compile(optimizer='sgd', + loss='binary_crossentropy', + metrics=['acc', metric_fn]) + elif metrics_mode == 'dict': + model.compile(optimizer='sgd', + loss='binary_crossentropy', + metrics={'out': ['acc', metric_fn]}) samples = 1000 x = np.random.random((samples, 2))