Skip to content

Commit

Permalink
Merge pull request #693 from Red54:master
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 586729162
  • Loading branch information
tensorflower-gardener committed Nov 30, 2023
2 parents 88653d6 + 037919a commit f081455
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
17 changes: 9 additions & 8 deletions tf_keras/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2107,6 +2107,7 @@ def on_epoch_end(self, epoch, logs=None):
if self.restore_best_weights and self.best_weights is None:
# Restore the weights after first epoch if no progress is ever made.
self.best_weights = self.model.get_weights()
self.best_epoch = epoch

self.wait += 1
if self._is_improvement(current, self.best):
Expand All @@ -2126,20 +2127,20 @@ def on_epoch_end(self, epoch, logs=None):
if self.wait >= self.patience and epoch > 0:
self.stopped_epoch = epoch
self.model.stop_training = True
if self.restore_best_weights and self.best_weights is not None:
if self.verbose > 0:
io_utils.print_msg(
"Restoring model weights from "
"the end of the best epoch: "
f"{self.best_epoch + 1}."
)
self.model.set_weights(self.best_weights)

def on_train_end(self, logs=None):
if self.stopped_epoch > 0 and self.verbose > 0:
io_utils.print_msg(
f"Epoch {self.stopped_epoch + 1}: early stopping"
)
if self.restore_best_weights and self.best_weights is not None:
if self.verbose > 0:
io_utils.print_msg(
"Restoring model weights from "
"the end of the best epoch: "
f"{self.best_epoch + 1}."
)
self.model.set_weights(self.best_weights)

def get_monitor_value(self, logs):
logs = logs or {}
Expand Down
2 changes: 2 additions & 0 deletions tf_keras/callbacks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2076,6 +2076,7 @@ def set_weight_to_epoch(self, epoch):
early_stop.on_epoch_end(epoch, logs={"val_loss": losses[epoch]})
if early_stop.model.stop_training:
break
early_stop.on_train_end()
# The best configuration is in epoch 2 (loss = 0.1000),
# and while patience = 2, we're restoring the best weights,
# so we end up at the epoch with the best weights, i.e. epoch 2
Expand All @@ -2099,6 +2100,7 @@ def set_weight_to_epoch(self, epoch):
early_stop.on_epoch_end(epoch, logs={"val_loss": losses[epoch]})
if early_stop.model.stop_training:
break
early_stop.on_train_end()
# No epoch improves on the baseline, so we should train for only 5
# epochs, and restore the second model.
self.assertEqual(epochs_trained, 5)
Expand Down

0 comments on commit f081455

Please sign in to comment.