From 4a8719cf10aff78e120a86a268e6f8c26b52b8b6 Mon Sep 17 00:00:00 2001 From: ancestor-mithril Date: Mon, 29 Apr 2024 15:18:45 +0300 Subject: [PATCH] Fixed early stopping --- utils/trainer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/utils/trainer.py b/utils/trainer.py index 1d4b2a2..8a86ffa 100644 --- a/utils/trainer.py +++ b/utils/trainer.py @@ -40,7 +40,7 @@ def __init__(self, args): self.optimizer = init_optimizer(args, self.model.parameters()) self.scheduler = init_scheduler(args, self.optimizer, self.train_loader) - self.early_stopping = init_early_stopping(args) + self.early_stopper = init_early_stopping(args) self.logdir = self.init_logdir() self.writer = SummaryWriter(log_dir=f'runs/{self.logdir}') @@ -93,6 +93,8 @@ def run(self): self.post_epoch(metrics) self.write_metrics(epoch, metrics, total_training_time) tbar.set_description(self.epoch_description(metrics)) + if self.early_stopping(metrics): + break except KeyboardInterrupt: pass with open("results.txt", "a") as f: @@ -190,7 +192,6 @@ def post_epoch(self, metrics: dict): self.empty_cache() self.save_checkpoint(metrics) self.scheduler_step(metrics) - self.early_stopping(metrics) def epoch_description(self, metrics): train_acc = round(metrics["Train/Accuracy"], 2) @@ -200,4 +201,4 @@ def epoch_description(self, metrics): def early_stopping(self, metrics): es_metric = metrics[self.es_metric] - self.early_stopping.step(es_metric) + return self.early_stopper.step(es_metric)