General stateful metrics fixes #9446

Merged
merged 10 commits on Mar 22, 2018
45 changes: 31 additions & 14 deletions keras/engine/training.py
@@ -927,7 +927,7 @@ def handle_metrics(metrics, weights=None):

# Keep track of state updates created by
# stateful metrics (i.e. metrics layers).
if isinstance(metric_fn, Layer):
if isinstance(metric_fn, Layer) and metric_fn.stateful:
self.stateful_metric_names.append(metric_name)
self.metrics_updates += metric_fn.updates

@@ -1175,7 +1175,7 @@ def _fit_loop(self, f, ins, out_labels=None, batch_size=None,
for epoch in range(initial_epoch, epochs):
# Reset stateful metrics
for m in self.metrics:
if isinstance(m, Layer):
if isinstance(m, Layer) and m.stateful:
m.reset_states()
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
@@ -1364,7 +1364,7 @@ def _test_loop(self, f, ins, batch_size=None, verbose=0, steps=None):

if hasattr(self, 'metrics'):
for m in self.metrics:
if isinstance(m, Layer):
if isinstance(m, Layer) and m.stateful:
m.reset_states()
stateful_metric_indices = [
i for i, name in enumerate(self.metrics_names)
@@ -1398,7 +1398,7 @@ def _test_loop(self, f, ins, batch_size=None, verbose=0, steps=None):
outs.append(0.)
for i, batch_out in enumerate(batch_outs):
if i in stateful_metric_indices:
outs[i] = batch_out
outs[i] = float(batch_out)
else:
outs[i] += batch_out
else:
@@ -2198,6 +2198,9 @@ def generate_arrays_from_file(path):
# Construct epoch logs.
epoch_logs = {}
while epoch < epochs:
for m in self.metrics:
if isinstance(m, Layer) and m.stateful:
m.reset_states()
callbacks.on_epoch_begin(epoch)
steps_done = 0
batch_index = 0
@@ -2331,9 +2334,20 @@ def evaluate_generator(self, generator, steps=None,
"""
self._make_test_function()

stateful_metric_indices = []
Contributor:

This reset is needed to make Stateful Metrics work for generators.

How do you feel about spinning this bug fix out into a separate PR? Should be a quick approval.

Some of the other changes, for instance m.stateful, will likely need some discussion. I really want this bug fix to make it into the next release :)

Contributor:

Do you mind if I raise the reset in evaluate_generator as a separate PR to get it through?

Don't want to steal your thunder.

Contributor (Author):

No worries! I'm going to take better care of the PR from now on and will update it in a couple of minutes. Sorry about the delays!

Contributor:

Looks like you've got it; I just wanted to make sure this PR wasn't going to get stuck for another couple of releases.

if hasattr(self, 'metrics'):
for i, m in enumerate(self.metrics):
if isinstance(m, Layer) and m.stateful:
m.reset_states()
stateful_metric_indices = [
i for i, name in enumerate(self.metrics_names)
if str(name) in self.stateful_metric_names]
else:
stateful_metric_indices = []

steps_done = 0
wait_time = 0.01
all_outs = []
outs_per_batch = []
batch_sizes = []
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
@@ -2384,6 +2398,9 @@ def evaluate_generator(self, generator, steps=None,
'or (x, y). Found: ' +
str(generator_output))
outs = self.test_on_batch(x, y, sample_weight=sample_weight)
if not isinstance(outs, list):
outs = [outs]
outs_per_batch.append(outs)

if isinstance(x, list):
batch_size = x[0].shape[0]
@@ -2394,7 +2411,6 @@
if batch_size == 0:
raise ValueError('Received an empty batch. '
'Batches should at least contain one item.')
all_outs.append(outs)

steps_done += 1
batch_sizes.append(batch_size)
@@ -2403,15 +2419,16 @@
if enqueuer is not None:
enqueuer.stop()

-        if not isinstance(outs, list):
-            return np.average(np.asarray(all_outs),
-                              weights=batch_sizes)
-        else:
-            averages = []
-            for i in range(len(outs)):
-                averages.append(np.average([out[i] for out in all_outs],
-                                           weights=batch_sizes))
-            return averages
+        averages = []
+        for i in range(len(outs)):
+            if i not in stateful_metric_indices:
+                averages.append(np.average([out[i] for out in outs_per_batch],
+                                           weights=batch_sizes))
+            else:
+                averages.append(float(outs_per_batch[-1][i]))
+        if len(averages) == 1:
+            return averages[0]
+        return averages
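
Note on the aggregation above: ordinary metrics are averaged across batches, weighted by batch size, while a stateful metric's epoch-level result is simply the value reported for the last batch, since the metric layer has already accumulated over the whole evaluation run. A small standalone sketch of that rule, with made-up numbers:

```python
import numpy as np

# Hypothetical per-batch outputs: [loss, acc, true_positives] for 3 batches.
outs_per_batch = [[0.70, 0.50, 4.0],
                  [0.60, 0.60, 9.0],
                  [0.50, 0.70, 11.0]]
batch_sizes = [10, 10, 5]
stateful_metric_indices = [2]  # true_positives accumulates across batches

averages = []
for i in range(3):
    if i not in stateful_metric_indices:
        # Ordinary metrics: weighted average of the per-batch values.
        averages.append(np.average([out[i] for out in outs_per_batch],
                                   weights=batch_sizes))
    else:
        # Stateful metrics: the last batch already holds the epoch total.
        averages.append(float(outs_per_batch[-1][i]))

print(averages)  # [0.62, 0.58, 11.0]
```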

@interfaces.legacy_generator_methods_support
def predict_generator(self, generator, steps=None,
5 changes: 4 additions & 1 deletion keras/utils/generic_utils.py
@@ -337,7 +337,10 @@ def update(self, current, values=None):
self._values[k][0] += v * (current - self._seen_so_far)
self._values[k][1] += (current - self._seen_so_far)
else:
self._values[k] = v
# Stateful metrics output a numeric value. This representation
# means "take an average from a single value" but keeps the
# numeric formatting.
self._values[k] = [v, 1]
self._seen_so_far = current

now = time.time()
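
For context, Progbar keeps a [weighted_sum, weight] pair per displayed value and prints their ratio; storing a stateful metric's scalar as [v, 1] routes it through that same formatting path while always showing the latest reported value. A rough standalone sketch of the bookkeeping (simplified, not the actual Progbar code; the stateful_names parameter is illustrative):

```python
values_store = {}   # name -> [weighted_sum, total_weight]
seen_so_far = 0

def update(current, values, stateful_names=()):
    """current: number of samples seen; values: list of (name, value) pairs."""
    global seen_so_far
    step = current - seen_so_far
    for name, v in values:
        if name not in stateful_names:
            # Ordinary metrics: accumulate a weighted running sum for averaging.
            if name not in values_store:
                values_store[name] = [v * step, step]
            else:
                values_store[name][0] += v * step
                values_store[name][1] += step
        else:
            # Stateful metrics already report an epoch-level value; store it as
            # "an average over a single value" so the display format stays uniform.
            values_store[name] = [v, 1]
    seen_so_far = current

update(10, [('loss', 0.7), ('true_positives', 4)], stateful_names=('true_positives',))
update(20, [('loss', 0.5), ('true_positives', 9)], stateful_names=('true_positives',))
# Displayed value is weighted_sum / total_weight:
print({k: s / max(1, w) for k, (s, w) in values_store.items()})
# -> {'loss': 0.6, 'true_positives': 9.0}
```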
35 changes: 31 additions & 4 deletions tests/keras/metrics_test.py
@@ -116,13 +116,12 @@ class BinaryTruePositives(keras.layers.Layer):
Assumes predictions and targets of shape `(samples, 1)`.

# Arguments
threshold: Float, lower limit on prediction value that counts as a
positive class prediction.
name: String, name for the metric.
"""

def __init__(self, name='true_positives', **kwargs):
super(BinaryTruePositives, self).__init__(name=name, **kwargs)
self.stateful = True
Contributor:

Would we have to specify this attribute for each metric?

Contributor (Author):

That's the idea, indeed - as long as your metric behaves as a stateful layer.
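
To make that contract concrete, a minimal stateful metric along these lines might look like the sketch below (illustrative only, not part of this diff; the class and variable names are made up):

```python
import keras
import keras.backend as K


class PositiveCount(keras.layers.Layer):
    """Sketch: count positive predictions across all batches of an epoch."""

    def __init__(self, name='positive_count', **kwargs):
        super(PositiveCount, self).__init__(name=name, **kwargs)
        self.stateful = True  # opt this metric into the stateful code paths
        self.count = K.variable(value=0, dtype='int32')

    def reset_states(self):
        # Called by the fit/evaluate loops at the start of each epoch.
        K.set_value(self.count, 0)

    def __call__(self, y_true, y_pred):
        # Accumulate into the backend variable and return the running total.
        batch_positives = K.cast(K.sum(K.round(y_pred)), 'int32')
        running_total = self.count * 1  # snapshot before the update
        self.add_update(K.update_add(self.count, batch_positives),
                        inputs=[y_true, y_pred])
        return running_total + batch_positives
```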

self.true_positives = K.variable(value=0, dtype='int32')

def reset_states(self):
@@ -162,11 +161,16 @@ def __call__(self, y_true, y_pred):
loss='binary_crossentropy',
metrics=['acc', metric_fn])

# Test fit, evaluate
samples = 1000
x = np.random.random((samples, 2))
y = np.random.randint(2, size=(samples, 1))
model.fit(x, y, epochs=1, batch_size=10)

val_samples = 10
val_x = np.random.random((val_samples, 2))
val_y = np.random.randint(2, size=(val_samples, 1))

# Test fit and evaluate
history = model.fit(x, y, validation_data=(val_x, val_y), epochs=1, batch_size=10)
outs = model.evaluate(x, y, batch_size=10)
preds = model.predict(x)

@@ -176,6 +180,29 @@ def ref_true_pos(y_true, y_pred):
# Test correctness (e.g. updates should have been run)
np.testing.assert_allclose(outs[2], ref_true_pos(y, preds), atol=1e-5)

# Test correctness of the validation metric computation
val_preds = model.predict(val_x)
val_outs = model.evaluate(val_x, val_y, batch_size=10)
np.testing.assert_allclose(val_outs[2], ref_true_pos(val_y, val_preds), atol=1e-5)
np.testing.assert_allclose(val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)

# Test with generators
gen = [(np.array([x0]), np.array([y0])) for x0, y0 in zip(x, y)]
val_gen = [(np.array([x0]), np.array([y0])) for x0, y0 in zip(val_x, val_y)]
history = model.fit_generator(iter(gen), epochs=1, steps_per_epoch=samples,
validation_data=iter(val_gen), validation_steps=val_samples)
outs = model.evaluate_generator(iter(gen), steps=samples)
preds = model.predict_generator(iter(gen), steps=samples)

# Test correctness of the metric re ref_true_pos()
np.testing.assert_allclose(outs[2], ref_true_pos(y, preds), atol=1e-5)

# Test correctness of the validation metric computation
val_preds = model.predict_generator(iter(val_gen), steps=val_samples)
val_outs = model.evaluate_generator(iter(val_gen), steps=val_samples)
np.testing.assert_allclose(val_outs[2], ref_true_pos(val_y, val_preds), atol=1e-5)
np.testing.assert_allclose(val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)


if __name__ == '__main__':
pytest.main([__file__])