Skip to content

Commit

Permalink
Fixed early stopping
Browse files Browse the repository at this point in the history
  • Loading branch information
ancestor-mithril committed Apr 29, 2024
1 parent 2dc4c34 commit 4a8719c
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, args):
self.optimizer = init_optimizer(args, self.model.parameters())

self.scheduler = init_scheduler(args, self.optimizer, self.train_loader)
self.early_stopping = init_early_stopping(args)
self.early_stopper = init_early_stopping(args)

self.logdir = self.init_logdir()
self.writer = SummaryWriter(log_dir=f'runs/{self.logdir}')
Expand Down Expand Up @@ -93,6 +93,8 @@ def run(self):
self.post_epoch(metrics)
self.write_metrics(epoch, metrics, total_training_time)
tbar.set_description(self.epoch_description(metrics))
if self.early_stopping(metrics):
break
except KeyboardInterrupt:
pass
with open("results.txt", "a") as f:
Expand Down Expand Up @@ -190,7 +192,6 @@ def post_epoch(self, metrics: dict):
self.empty_cache()
self.save_checkpoint(metrics)
self.scheduler_step(metrics)
self.early_stopping(metrics)

def epoch_description(self, metrics):
train_acc = round(metrics["Train/Accuracy"], 2)
Expand All @@ -200,4 +201,4 @@ def epoch_description(self, metrics):

def early_stopping(self, metrics):
es_metric = metrics[self.es_metric]
self.early_stopping.step(es_metric)
return self.early_stopper.step(es_metric)

0 comments on commit 4a8719c

Please sign in to comment.