diff --git a/nbs/04_learner.ipynb b/nbs/04_learner.ipynb index dea9cdb..4fdc5a8 100755 --- a/nbs/04_learner.ipynb +++ b/nbs/04_learner.ipynb @@ -1452,7 +1452,30 @@ " if self.log_model:\n", " torch.save(learn.model.state_dict(), self.model_name)\n", " wandb.save(self.model_name)\n", - " self.run.finish()\n" + " self.run.finish()\n", + "\n", + "\n", + "class EarlyStoppingCallback(Callback):\n", + " order = MetricsCB.order+1\n", + " def __init__(self, patience=3, min_delta=0.001):\n", + " self.patience = patience\n", + " self.min_delta = min_delta\n", + " self.best_loss = float('inf')\n", + " self.wait = 0\n", + " \n", + " def after_epoch(self, learn):\n", + " if not learn.model.training:\n", + " current_loss = learn.metrics.all_metrics['loss'].compute()\n", + " if current_loss < self.best_loss - self.min_delta:\n", + " self.best_loss = current_loss\n", + " self.wait = 0\n", + " else:\n", + " self.wait += 1\n", + " \n", + " if self.wait >= self.patience:\n", + " print(\"Early stopping triggered.\")\n", + " raise CancelFitException\n", + "\n" ] }, { diff --git a/rapidai/_modidx.py b/rapidai/_modidx.py index f229b37..a8c3ed2 100644 --- a/rapidai/_modidx.py +++ b/rapidai/_modidx.py @@ -165,6 +165,11 @@ 'rapidai.learner.DeviceCB.__init__': ('learner.html#devicecb.__init__', 'rapidai/learner.py'), 'rapidai.learner.DeviceCB.before_batch': ('learner.html#devicecb.before_batch', 'rapidai/learner.py'), 'rapidai.learner.DeviceCB.before_fit': ('learner.html#devicecb.before_fit', 'rapidai/learner.py'), + 'rapidai.learner.EarlyStoppingCallback': ('learner.html#earlystoppingcallback', 'rapidai/learner.py'), + 'rapidai.learner.EarlyStoppingCallback.__init__': ( 'learner.html#earlystoppingcallback.__init__', + 'rapidai/learner.py'), + 'rapidai.learner.EarlyStoppingCallback.after_epoch': ( 'learner.html#earlystoppingcallback.after_epoch', + 'rapidai/learner.py'), 'rapidai.learner.LRFinderCB': ('learner.html#lrfindercb', 'rapidai/learner.py'), 'rapidai.learner.LRFinderCB.__init__': ('learner.html#lrfindercb.__init__', 'rapidai/learner.py'), 'rapidai.learner.LRFinderCB.after_batch': ('learner.html#lrfindercb.after_batch', 'rapidai/learner.py'), diff --git a/rapidai/learner.py b/rapidai/learner.py index 486539c..7f657bb 100755 --- a/rapidai/learner.py +++ b/rapidai/learner.py @@ -3,7 +3,7 @@ # %% auto 0 __all__ = ['CancelFitException', 'CancelBatchException', 'CancelEpochException', 'Callback', 'run_cbs', 'SingleBatchCB', 'to_cpu', 'MetricsCB', 'DeviceCB', 'TrainCB', 'ProgressCB', 'with_cbs', 'Learner', 'TrainLearner', 'MomentumLearner', - 'LRFinderCB', 'lr_find', 'WandbCB'] + 'LRFinderCB', 'lr_find', 'WandbCB', 'EarlyStoppingCallback'] # %% ../nbs/04_learner.ipynb 1 import math,torch,matplotlib.pyplot as plt @@ -279,3 +279,26 @@ def cleanup_fit(self, learn): wandb.save(self.model_name) self.run.finish() + +class EarlyStoppingCallback(Callback): + order = MetricsCB.order+1 + def __init__(self, patience=3, min_delta=0.001): + self.patience = patience + self.min_delta = min_delta + self.best_loss = float('inf') + self.wait = 0 + + def after_epoch(self, learn): + if not learn.model.training: + current_loss = learn.metrics.all_metrics['loss'].compute() + if current_loss < self.best_loss - self.min_delta: + self.best_loss = current_loss + self.wait = 0 + else: + self.wait += 1 + + if self.wait >= self.patience: + print("Early stopping triggered.") + raise CancelFitException + +