General stateful metrics fixes #9446
Merged: fchollet merged 10 commits into keras-team:master from rossumai:stateful-metrics-sanity on Mar 22, 2018.
Commits (10):
15d36ff pasky: Require stateful metrics layers to be actually stateful
27e39cc pasky: Prevent stateful metrics to leak np.floats to the History object
2e30a38 pasky: Progbar: Format stateful metrics values as floats alike other metrics
7f4addc pasky: test_stateful_metrics: Also test validation set evaluation
9fc9a2e pasky: Add support for stateful metrics in fit_generator() and evaluate_gene…
de23bd5 pasky: Document stateful metrics
ff2a790 pasky: evaluate_generator(): Do not leak np.float to History here either
9b27400 pasky: Revert stateful metrics documentation until the API stabilizes
2b91891 pasky: Progbar: Explain stateful metrics handling
db4af05 pasky: Model.evaluate_generator(): More consistent stateful metrics handling
@@ -116,13 +116,12 @@ class BinaryTruePositives(keras.layers.Layer):
```
    Assumes predictions and targets of shape `(samples, 1)`.

    # Arguments
        threshold: Float, lower limit on prediction value that counts as a
            positive class prediction.
        name: String, name for the metric.
    """

    def __init__(self, name='true_positives', **kwargs):
        super(BinaryTruePositives, self).__init__(name=name, **kwargs)
        self.stateful = True
        self.true_positives = K.variable(value=0, dtype='int32')

    def reset_states(self):
```

Review comment on the `self.stateful = True` line: "Would we have to specify this attribute for each metric?"

Reply: "That's the idea, indeed - as long as your metric behaves as a stateful layer."
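To make that exchange concrete, here is a minimal sketch of the stateful-metric pattern being discussed: a `keras.layers.Layer` subclass that sets `self.stateful = True`, keeps its running count in a backend variable, exposes `reset_states()`, and registers its accumulator update through `add_update()` in `__call__`. It mirrors the `BinaryTruePositives` class shown in the diff above, but it is an illustration of the pattern under the Keras 2 API of that era, not the exact file contents.

```python
import keras
import keras.backend as K


class BinaryTruePositives(keras.layers.Layer):
    """Running count of true positives (sketch mirroring the test's metric).

    Assumes binary targets and predictions of shape `(samples, 1)`.
    """

    def __init__(self, name='true_positives', **kwargs):
        super(BinaryTruePositives, self).__init__(name=name, **kwargs)
        self.stateful = True  # tell Keras to accumulate rather than average per batch
        self.true_positives = K.variable(value=0, dtype='int32')

    def reset_states(self):
        # Zero the accumulator; Keras calls this between passes over the data.
        K.set_value(self.true_positives, 0)

    def __call__(self, y_true, y_pred):
        y_true = K.cast(y_true, 'int32')
        y_pred = K.cast(K.round(y_pred), 'int32')
        correct_preds = K.cast(K.equal(y_pred, y_true), 'int32')
        true_pos = K.cast(K.sum(correct_preds * y_true), 'int32')
        current_true_pos = self.true_positives * 1
        # Register the state update so it runs with each train/eval batch.
        self.add_update(K.update_add(self.true_positives, true_pos),
                        inputs=[y_true, y_pred])
        return current_true_pos + true_pos


# Usage sketch: pass the layer instance in `metrics` when compiling.
# model.compile(optimizer='sgd',
#               loss='binary_crossentropy',
#               metrics=['acc', BinaryTruePositives()])
```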
@@ -162,11 +161,16 @@ def __call__(self, y_true, y_pred):
```
                  loss='binary_crossentropy',
                  metrics=['acc', metric_fn])

    # Test fit, evaluate
    samples = 1000
    x = np.random.random((samples, 2))
    y = np.random.randint(2, size=(samples, 1))
    model.fit(x, y, epochs=1, batch_size=10)

    val_samples = 10
    val_x = np.random.random((val_samples, 2))
    val_y = np.random.randint(2, size=(val_samples, 1))

    # Test fit and evaluate
    history = model.fit(x, y, validation_data=(val_x, val_y), epochs=1, batch_size=10)
    outs = model.evaluate(x, y, batch_size=10)
    preds = model.predict(x)
```
@@ -176,6 +180,29 @@ def ref_true_pos(y_true, y_pred):
```
    # Test correctness (e.g. updates should have been run)
    np.testing.assert_allclose(outs[2], ref_true_pos(y, preds), atol=1e-5)

    # Test correctness of the validation metric computation
    val_preds = model.predict(val_x)
    val_outs = model.evaluate(val_x, val_y, batch_size=10)
    np.testing.assert_allclose(val_outs[2], ref_true_pos(val_y, val_preds), atol=1e-5)
    np.testing.assert_allclose(val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)

    # Test with generators
    gen = [(np.array([x0]), np.array([y0])) for x0, y0 in zip(x, y)]
    val_gen = [(np.array([x0]), np.array([y0])) for x0, y0 in zip(val_x, val_y)]
    history = model.fit_generator(iter(gen), epochs=1, steps_per_epoch=samples,
                                  validation_data=iter(val_gen), validation_steps=val_samples)
    outs = model.evaluate_generator(iter(gen), steps=samples)
    preds = model.predict_generator(iter(gen), steps=samples)

    # Test correctness of the metric re ref_true_pos()
    np.testing.assert_allclose(outs[2], ref_true_pos(y, preds), atol=1e-5)

    # Test correctness of the validation metric computation
    val_preds = model.predict_generator(iter(val_gen), steps=val_samples)
    val_outs = model.evaluate_generator(iter(val_gen), steps=val_samples)
    np.testing.assert_allclose(val_outs[2], ref_true_pos(val_y, val_preds), atol=1e-5)
    np.testing.assert_allclose(val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)


if __name__ == '__main__':
    pytest.main([__file__])
```
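A note on the indexing in these assertions: `evaluate()` and `evaluate_generator()` return scalars in the order of `model.metrics_names`, which with the compile call above is presumably `['loss', 'acc', 'true_positives']`, so the stateful metric's final count lands at index 2. A small sketch of reading it back by name instead of by position; the ordering and metric name here are assumptions based on the test code, not taken from the PR itself:

```python
# Assuming `model` was compiled as in the test above:
#   model.compile(..., metrics=['acc', metric_fn])
outs = model.evaluate(val_x, val_y, batch_size=10)
results = dict(zip(model.metrics_names, outs))  # assumed: {'loss': ..., 'acc': ..., 'true_positives': ...}
true_positives = results['true_positives']      # same value the test checks as outs[2]
```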
Review comment: This reset is needed to make stateful metrics work for generators. How do you feel about spinning this bug fix out into a separate PR? It should be a quick approval. Some of the other changes, for instance `m.stateful`, will likely need some discussion. I really want this bug fix to make it into the next release :)
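For context, the reset being discussed is the kind of thing sketched below: before an evaluation pass, each stateful metric layer must have its accumulator zeroed, otherwise counts from a previous fit or evaluate call bleed into the next one. This is a hedged sketch of the idea written against the user-visible API (a metric layer with `stateful = True` and `reset_states()`), not a copy of the code this PR changes.

```python
# Sketch: why stateful metrics need a reset before evaluate_generator().
# `metric_fn` is a stateful metric layer such as BinaryTruePositives above.
metric_fn.reset_states()  # zero the running count from any earlier pass
outs = model.evaluate_generator(iter(val_gen), steps=val_samples)

# Without the reset, the accumulator would still hold the previous pass's
# count and the reported true-positive total would be too large.
```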
Reply: Do you mind if I raise the reset in evaluate_generator to get it through? I don't want to steal your thunder.

Reply: No worries about that! But I'm going to take better care of the MR from now on and update it in a couple of minutes; sorry about the delays!

Reply: Looks like you got it. I just wanted to make sure this PR wasn't going to get stuck for another couple of releases.