Skip to content

Commit

Permalink
Allow re-use of EarlyStopping callback objects. (#3000)
Browse files Browse the repository at this point in the history
An EarlyStopping callback object has internal state variables to tell it
when it has reached its stopping point.  These were initialized in __init__(),
so attempting to re-use the same object resulted in immediate stopping. This
prevents (for example) performing early stopping during cross-validation with
the scikit-learn wrapper.

This patch initializes the variables in on_train_begin(), so they are reset
at the start of each training fold.  Tests included.
  • Loading branch information
jkleint authored and fchollet committed Jun 18, 2016
1 parent 60e0c96 commit 3513472
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
8 changes: 4 additions & 4 deletions keras/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,17 +327,17 @@ def __init__(self, monitor='val_loss', patience=0, verbose=0, mode='auto'):

if mode == 'min':
self.monitor_op = np.less
self.best = np.Inf
elif mode == 'max':
self.monitor_op = np.greater
self.best = -np.Inf
else:
if 'acc' in self.monitor:
self.monitor_op = np.greater
self.best = -np.Inf
else:
self.monitor_op = np.less
self.best = np.Inf

def on_train_begin(self, logs=None):
    """Reset the callback's internal state so the instance can be re-used.

    Initializing `wait` and `best` here (rather than in `__init__`)
    lets the same EarlyStopping object be passed to several successive
    `fit()` calls, e.g. one per cross-validation fold.

    # Arguments
        logs: dict, currently unused; accepted for callback API
            consistency. Default changed from `{}` to `None` to avoid
            the mutable-default-argument pitfall (body never reads it,
            so behavior is unchanged).
    """
    self.wait = 0  # Epochs seen without improvement; allows re-use.
    # Re-derive the starting "best" from the comparison op chosen in
    # __init__: minimizing metrics start at +inf, maximizing at -inf.
    self.best = np.inf if self.monitor_op == np.less else -np.inf

def on_epoch_end(self, epoch, logs={}):
current = logs.get(self.monitor)
Expand Down
21 changes: 21 additions & 0 deletions tests/keras/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,27 @@ def test_EarlyStopping():
validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=20)


def test_EarlyStopping_reuse():
    """A single EarlyStopping instance must survive multiple fit() calls."""
    patience = 3
    inputs = np.random.random((100, 1))
    targets = np.where(inputs > 0.5, 1, 0)

    net = Sequential((
        Dense(1, input_dim=1, activation='relu'),
        Dense(1, activation='sigmoid'),
    ))
    net.compile(optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy'])

    early_stop = callbacks.EarlyStopping(monitor='acc', patience=patience)
    initial_weights = net.get_weights()

    history = net.fit(inputs, targets, callbacks=[early_stop])
    assert len(history.epoch) >= patience

    # Re-using the same callback object must not stop training
    # immediately: it should run for at least `patience` epochs again.
    net.set_weights(initial_weights)
    history = net.fit(inputs, targets, callbacks=[early_stop])
    assert len(history.epoch) >= patience


def test_LearningRateScheduler():
(X_train, y_train), (X_test, y_test) = get_test_data(nb_train=train_samples,
nb_test=test_samples,
Expand Down

0 comments on commit 3513472

Please sign in to comment.