diff --git a/keras/callbacks.py b/keras/callbacks.py
index ff4ff222740..6ba4b065bc2 100644
--- a/keras/callbacks.py
+++ b/keras/callbacks.py
@@ -327,17 +327,17 @@ def __init__(self, monitor='val_loss', patience=0, verbose=0, mode='auto'):
 
         if mode == 'min':
             self.monitor_op = np.less
-            self.best = np.Inf
         elif mode == 'max':
             self.monitor_op = np.greater
-            self.best = -np.Inf
         else:
             if 'acc' in self.monitor:
                 self.monitor_op = np.greater
-                self.best = -np.Inf
             else:
                 self.monitor_op = np.less
-                self.best = np.Inf
+
+    def on_train_begin(self, logs={}):
+        self.wait = 0  # Allow instances to be re-used
+        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
 
     def on_epoch_end(self, epoch, logs={}):
         current = logs.get(self.monitor)
diff --git a/tests/keras/test_callbacks.py b/tests/keras/test_callbacks.py
index 069369d308a..f36e1a9b795 100644
--- a/tests/keras/test_callbacks.py
+++ b/tests/keras/test_callbacks.py
@@ -105,6 +105,27 @@ def test_EarlyStopping():
               validation_data=(X_test, y_test), callbacks=cbks, nb_epoch=20)
 
 
+def test_EarlyStopping_reuse():
+    patience = 3
+    data = np.random.random((100, 1))
+    labels = np.where(data > 0.5, 1, 0)
+    model = Sequential((
+        Dense(1, input_dim=1, activation='relu'),
+        Dense(1, activation='sigmoid'),
+    ))
+    model.compile(optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy'])
+    stopper = callbacks.EarlyStopping(monitor='acc', patience=patience)
+    weights = model.get_weights()
+
+    hist = model.fit(data, labels, callbacks=[stopper])
+    assert len(hist.epoch) >= patience
+
+    # This should allow training to go for at least `patience` epochs
+    model.set_weights(weights)
+    hist = model.fit(data, labels, callbacks=[stopper])
+    assert len(hist.epoch) >= patience
+
+
 def test_LearningRateScheduler():
     (X_train, y_train), (X_test, y_test) = get_test_data(nb_train=train_samples,
                                                          nb_test=test_samples,