From 0e42928e2c690074f1e142ab49f3c49c0092550f Mon Sep 17 00:00:00 2001 From: Gopichand Madala Date: Fri, 23 Aug 2024 23:09:11 +0530 Subject: [PATCH] Bug fixes & added more callbacks --- nbs/04_learner.ipynb | 98 ++++++++++++++++++++++++++++++++++++++++++-- rapidai/_modidx.py | 59 +++++++++++++++++--------- rapidai/learner.py | 98 ++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 229 insertions(+), 26 deletions(-) diff --git a/nbs/04_learner.ipynb b/nbs/04_learner.ipynb index 4734044..a142b07 100755 --- a/nbs/04_learner.ipynb +++ b/nbs/04_learner.ipynb @@ -43,7 +43,8 @@ "from torch import optim\n", "import torch.nn.functional as F\n", "from rapidai.conv import *\n", - "from fastprogress import progress_bar,master_bar" + "from fastprogress import progress_bar,master_bar\n", + "from torch.utils.tensorboard import SummaryWriter" ] }, { @@ -1402,7 +1403,7 @@ "outputs": [], "source": [ "#|export\n", - "class WandbCallback(Callback):\n", + "class WandbCB(Callback):\n", " def __init__(\n", " self,\n", " project_name: str,\n", @@ -1506,7 +1507,7 @@ "\n", "\n", "\n", - "class MLflowCallback(Callback):\n", + "class MLflowCB(Callback):\n", " def __init__(\n", " self,\n", " experiment_name: str,\n", @@ -1597,7 +1598,96 @@ "\n", " def after_fit(self, learn):\n", " # Finish MLflow run\n", - " mlflow.end_run()" + " mlflow.end_run()\n", + "\n", + "\n", + "\n", + "class EarlyStoppingCB(Callback):\n", + " def __init__(self, monitor='val_loss', min_delta=0, patience=3, mode='min'):\n", + " fc.store_attr()\n", + " self.best = None\n", + " self.num_bad_epochs = 0\n", + " self.operator = torch.lt if mode == 'min' else torch.gt\n", + "\n", + " def after_epoch(self, learn):\n", + " current = learn.metrics.metrics[self.monitor].compute().item()\n", + " if self.best is None or self.operator(current, self.best - self.min_delta):\n", + " self.best = current\n", + " self.num_bad_epochs = 0\n", + " else:\n", + " self.num_bad_epochs += 1\n", + " if self.num_bad_epochs >= self.patience:\n", + " print(\"Stopping early!\")\n", + " raise CancelFitException()\n", + "\n", + "\n", + "class TensorBoardCB(Callback):\n", + " def __init__(self, log_dir='./runs', log_graph=True):\n", + " self.writer = SummaryWriter(log_dir=log_dir)\n", + " self.log_graph = log_graph\n", + "\n", + " def before_fit(self, learn):\n", + " if self.log_graph:\n", + " dummy_input = next(iter(learn.dls.train))[0].to(learn.model.device)\n", + " self.writer.add_graph(learn.model, dummy_input)\n", + "\n", + " def after_batch(self, learn):\n", + " if learn.training:\n", + " self.writer.add_scalar('Loss/train', learn.loss.item(), learn.iter_total)\n", + " else:\n", + " self.writer.add_scalar('Loss/val', learn.loss.item(), learn.iter_total)\n", + "\n", + " def after_epoch(self, learn):\n", + " for name, metric in learn.metrics.metrics.items():\n", + " self.writer.add_scalar(f'Metrics/{name}', metric.compute().item(), learn.epoch)\n", + "\n", + " def after_fit(self, learn):\n", + " self.writer.close()\n", + "\n", + "\n", + "class ReduceLROnPlateauCB(Callback):\n", + " def __init__(self, monitor='val_loss', patience=3, factor=0.1, min_lr=1e-6, mode='min'):\n", + " self.monitor, self.patience, self.factor, self.min_lr, self.mode = monitor, patience, factor, min_lr, mode\n", + " self.best = None\n", + " self.counter = 0\n", + "\n", + " def before_fit(self, learn):\n", + " self.best = float('inf') if self.mode == 'min' else -float('inf')\n", + "\n", + " def after_epoch(self, learn):\n", + " current = learn.metrics.all_metrics[self.monitor].compute().item()\n", + " if (self.mode == 'min' and current < self.best) or \\\n", + " (self.mode == 'max' and current > self.best):\n", + " self.best = current\n", + " self.counter = 0\n", + " else:\n", + " self.counter += 1\n", + " if self.counter >= self.patience:\n", + " new_lr = max(learn.opt.param_groups[0]['lr'] * self.factor, self.min_lr)\n", + " for g in learn.opt.param_groups:\n", + " g['lr'] = new_lr\n", + " self.counter = 0\n", + " print(f\"Learning rate reduced to {new_lr:.1e}\")\n", + "\n", + "\n", + "\n", + "class ModelCheckpointCB(Callback):\n", + " def __init__(self, monitor='val_loss', mode='min', save_best_only=True, filepath='./best_model.pth'):\n", + " self.monitor, self.mode, self.save_best_only, self.filepath = monitor, mode, save_best_only, filepath\n", + " self.best = None\n", + "\n", + " def before_fit(self, learn):\n", + " self.best = float('inf') if self.mode == 'min' else -float('inf')\n", + "\n", + " def after_epoch(self, learn):\n", + " current = learn.metrics.all_metrics[self.monitor].compute().item()\n", + " if (self.mode == 'min' and current < self.best) or \\\n", + " (self.mode == 'max' and current > self.best):\n", + " self.best = current\n", + " torch.save(learn.model.state_dict(), self.filepath)\n", + " print(f\"Saved model checkpoint to {self.filepath}\")\n", + "\n", + "\n" ] }, { diff --git a/rapidai/_modidx.py b/rapidai/_modidx.py index e39618b..46479a8 100644 --- a/rapidai/_modidx.py +++ b/rapidai/_modidx.py @@ -165,6 +165,11 @@ 'rapidai.learner.DeviceCB.__init__': ('learner.html#devicecb.__init__', 'rapidai/learner.py'), 'rapidai.learner.DeviceCB.before_batch': ('learner.html#devicecb.before_batch', 'rapidai/learner.py'), 'rapidai.learner.DeviceCB.before_fit': ('learner.html#devicecb.before_fit', 'rapidai/learner.py'), + 'rapidai.learner.EarlyStoppingCB': ('learner.html#earlystoppingcb', 'rapidai/learner.py'), + 'rapidai.learner.EarlyStoppingCB.__init__': ( 'learner.html#earlystoppingcb.__init__', + 'rapidai/learner.py'), + 'rapidai.learner.EarlyStoppingCB.after_epoch': ( 'learner.html#earlystoppingcb.after_epoch', + 'rapidai/learner.py'), 'rapidai.learner.LRFinderCB': ('learner.html#lrfindercb', 'rapidai/learner.py'), 'rapidai.learner.LRFinderCB.__init__': ('learner.html#lrfindercb.__init__', 'rapidai/learner.py'), 'rapidai.learner.LRFinderCB.after_batch': ('learner.html#lrfindercb.after_batch', 'rapidai/learner.py'), @@ -180,16 +185,12 @@ 'rapidai.learner.Learner.fit': ('learner.html#learner.fit', 'rapidai/learner.py'), 'rapidai.learner.Learner.one_epoch': ('learner.html#learner.one_epoch', 'rapidai/learner.py'), 'rapidai.learner.Learner.training': ('learner.html#learner.training', 'rapidai/learner.py'), - 'rapidai.learner.MLflowCallback': ('learner.html#mlflowcallback', 'rapidai/learner.py'), - 'rapidai.learner.MLflowCallback.__init__': ('learner.html#mlflowcallback.__init__', 'rapidai/learner.py'), - 'rapidai.learner.MLflowCallback.after_batch': ( 'learner.html#mlflowcallback.after_batch', - 'rapidai/learner.py'), - 'rapidai.learner.MLflowCallback.after_epoch': ( 'learner.html#mlflowcallback.after_epoch', - 'rapidai/learner.py'), - 'rapidai.learner.MLflowCallback.after_fit': ( 'learner.html#mlflowcallback.after_fit', - 'rapidai/learner.py'), - 'rapidai.learner.MLflowCallback.before_fit': ( 'learner.html#mlflowcallback.before_fit', - 'rapidai/learner.py'), + 'rapidai.learner.MLflowCB': ('learner.html#mlflowcb', 'rapidai/learner.py'), + 'rapidai.learner.MLflowCB.__init__': ('learner.html#mlflowcb.__init__', 'rapidai/learner.py'), + 'rapidai.learner.MLflowCB.after_batch': ('learner.html#mlflowcb.after_batch', 'rapidai/learner.py'), + 'rapidai.learner.MLflowCB.after_epoch': ('learner.html#mlflowcb.after_epoch', 'rapidai/learner.py'), + 'rapidai.learner.MLflowCB.after_fit': ('learner.html#mlflowcb.after_fit', 'rapidai/learner.py'), + 'rapidai.learner.MLflowCB.before_fit': ('learner.html#mlflowcb.before_fit', 'rapidai/learner.py'), 'rapidai.learner.MetricsCB': ('learner.html#metricscb', 'rapidai/learner.py'), 'rapidai.learner.MetricsCB.__init__': ('learner.html#metricscb.__init__', 'rapidai/learner.py'), 'rapidai.learner.MetricsCB._log': ('learner.html#metricscb._log', 'rapidai/learner.py'), @@ -197,6 +198,13 @@ 'rapidai.learner.MetricsCB.after_epoch': ('learner.html#metricscb.after_epoch', 'rapidai/learner.py'), 'rapidai.learner.MetricsCB.before_epoch': ('learner.html#metricscb.before_epoch', 'rapidai/learner.py'), 'rapidai.learner.MetricsCB.before_fit': ('learner.html#metricscb.before_fit', 'rapidai/learner.py'), + 'rapidai.learner.ModelCheckpointCB': ('learner.html#modelcheckpointcb', 'rapidai/learner.py'), + 'rapidai.learner.ModelCheckpointCB.__init__': ( 'learner.html#modelcheckpointcb.__init__', + 'rapidai/learner.py'), + 'rapidai.learner.ModelCheckpointCB.after_epoch': ( 'learner.html#modelcheckpointcb.after_epoch', + 'rapidai/learner.py'), + 'rapidai.learner.ModelCheckpointCB.before_fit': ( 'learner.html#modelcheckpointcb.before_fit', + 'rapidai/learner.py'), 'rapidai.learner.MomentumLearner': ('learner.html#momentumlearner', 'rapidai/learner.py'), 'rapidai.learner.MomentumLearner.__init__': ( 'learner.html#momentumlearner.__init__', 'rapidai/learner.py'), @@ -209,9 +217,25 @@ 'rapidai.learner.ProgressCB.after_epoch': ('learner.html#progresscb.after_epoch', 'rapidai/learner.py'), 'rapidai.learner.ProgressCB.before_epoch': ('learner.html#progresscb.before_epoch', 'rapidai/learner.py'), 'rapidai.learner.ProgressCB.before_fit': ('learner.html#progresscb.before_fit', 'rapidai/learner.py'), + 'rapidai.learner.ReduceLROnPlateauCB': ('learner.html#reducelronplateaucb', 'rapidai/learner.py'), + 'rapidai.learner.ReduceLROnPlateauCB.__init__': ( 'learner.html#reducelronplateaucb.__init__', + 'rapidai/learner.py'), + 'rapidai.learner.ReduceLROnPlateauCB.after_epoch': ( 'learner.html#reducelronplateaucb.after_epoch', + 'rapidai/learner.py'), + 'rapidai.learner.ReduceLROnPlateauCB.before_fit': ( 'learner.html#reducelronplateaucb.before_fit', + 'rapidai/learner.py'), 'rapidai.learner.SingleBatchCB': ('learner.html#singlebatchcb', 'rapidai/learner.py'), 'rapidai.learner.SingleBatchCB.after_batch': ( 'learner.html#singlebatchcb.after_batch', 'rapidai/learner.py'), + 'rapidai.learner.TensorBoardCB': ('learner.html#tensorboardcb', 'rapidai/learner.py'), + 'rapidai.learner.TensorBoardCB.__init__': ('learner.html#tensorboardcb.__init__', 'rapidai/learner.py'), + 'rapidai.learner.TensorBoardCB.after_batch': ( 'learner.html#tensorboardcb.after_batch', + 'rapidai/learner.py'), + 'rapidai.learner.TensorBoardCB.after_epoch': ( 'learner.html#tensorboardcb.after_epoch', + 'rapidai/learner.py'), + 'rapidai.learner.TensorBoardCB.after_fit': ('learner.html#tensorboardcb.after_fit', 'rapidai/learner.py'), + 'rapidai.learner.TensorBoardCB.before_fit': ( 'learner.html#tensorboardcb.before_fit', + 'rapidai/learner.py'), 'rapidai.learner.TrainCB': ('learner.html#traincb', 'rapidai/learner.py'), 'rapidai.learner.TrainCB.__init__': ('learner.html#traincb.__init__', 'rapidai/learner.py'), 'rapidai.learner.TrainCB.backward': ('learner.html#traincb.backward', 'rapidai/learner.py'), @@ -225,15 +249,12 @@ 'rapidai.learner.TrainLearner.predict': ('learner.html#trainlearner.predict', 'rapidai/learner.py'), 'rapidai.learner.TrainLearner.step': ('learner.html#trainlearner.step', 'rapidai/learner.py'), 'rapidai.learner.TrainLearner.zero_grad': ('learner.html#trainlearner.zero_grad', 'rapidai/learner.py'), - 'rapidai.learner.WandbCallback': ('learner.html#wandbcallback', 'rapidai/learner.py'), - 'rapidai.learner.WandbCallback.__init__': ('learner.html#wandbcallback.__init__', 'rapidai/learner.py'), - 'rapidai.learner.WandbCallback.after_batch': ( 'learner.html#wandbcallback.after_batch', - 'rapidai/learner.py'), - 'rapidai.learner.WandbCallback.after_epoch': ( 'learner.html#wandbcallback.after_epoch', - 'rapidai/learner.py'), - 'rapidai.learner.WandbCallback.after_fit': ('learner.html#wandbcallback.after_fit', 'rapidai/learner.py'), - 'rapidai.learner.WandbCallback.before_fit': ( 'learner.html#wandbcallback.before_fit', - 'rapidai/learner.py'), + 'rapidai.learner.WandbCB': ('learner.html#wandbcb', 'rapidai/learner.py'), + 'rapidai.learner.WandbCB.__init__': ('learner.html#wandbcb.__init__', 'rapidai/learner.py'), + 'rapidai.learner.WandbCB.after_batch': ('learner.html#wandbcb.after_batch', 'rapidai/learner.py'), + 'rapidai.learner.WandbCB.after_epoch': ('learner.html#wandbcb.after_epoch', 'rapidai/learner.py'), + 'rapidai.learner.WandbCB.after_fit': ('learner.html#wandbcb.after_fit', 'rapidai/learner.py'), + 'rapidai.learner.WandbCB.before_fit': ('learner.html#wandbcb.before_fit', 'rapidai/learner.py'), 'rapidai.learner.lr_find': ('learner.html#lr_find', 'rapidai/learner.py'), 'rapidai.learner.run_cbs': ('learner.html#run_cbs', 'rapidai/learner.py'), 'rapidai.learner.to_cpu': ('learner.html#to_cpu', 'rapidai/learner.py'), diff --git a/rapidai/learner.py b/rapidai/learner.py index a114db2..9a9247e 100755 --- a/rapidai/learner.py +++ b/rapidai/learner.py @@ -3,7 +3,8 @@ # %% auto 0 __all__ = ['CancelFitException', 'CancelBatchException', 'CancelEpochException', 'Callback', 'run_cbs', 'SingleBatchCB', 'to_cpu', 'MetricsCB', 'DeviceCB', 'TrainCB', 'ProgressCB', 'with_cbs', 'Learner', 'TrainLearner', 'MomentumLearner', - 'LRFinderCB', 'lr_find', 'WandbCallback', 'MLflowCallback'] + 'LRFinderCB', 'lr_find', 'WandbCB', 'MLflowCB', 'EarlyStoppingCB', 'TensorBoardCB', 'ReduceLROnPlateauCB', + 'ModelCheckpointCB'] # %% ../nbs/04_learner.ipynb 1 import math,torch,matplotlib.pyplot as plt @@ -20,6 +21,7 @@ import torch.nn.functional as F from .conv import * from fastprogress import progress_bar,master_bar +from torch.utils.tensorboard import SummaryWriter # %% ../nbs/04_learner.ipynb 14 class CancelFitException(Exception): pass @@ -239,7 +241,7 @@ def lr_find(self:Learner, gamma=1.3, max_mult=3, start_lr=1e-5, max_epochs=10): self.fit(max_epochs, lr=start_lr, cbs=LRFinderCB(gamma=gamma, max_mult=max_mult)) # %% ../nbs/04_learner.ipynb 65 -class WandbCallback(Callback): +class WandbCB(Callback): def __init__( self, project_name: str, @@ -343,7 +345,7 @@ def after_fit(self, learn): -class MLflowCallback(Callback): +class MLflowCB(Callback): def __init__( self, experiment_name: str, @@ -435,3 +437,93 @@ def after_epoch(self, learn): def after_fit(self, learn): # Finish MLflow run mlflow.end_run() + + + +class EarlyStoppingCB(Callback): + def __init__(self, monitor='val_loss', min_delta=0, patience=3, mode='min'): + fc.store_attr() + self.best = None + self.num_bad_epochs = 0 + self.operator = torch.lt if mode == 'min' else torch.gt + + def after_epoch(self, learn): + current = learn.metrics.metrics[self.monitor].compute().item() + if self.best is None or self.operator(current, self.best - self.min_delta): + self.best = current + self.num_bad_epochs = 0 + else: + self.num_bad_epochs += 1 + if self.num_bad_epochs >= self.patience: + print("Stopping early!") + raise CancelFitException() + + +class TensorBoardCB(Callback): + def __init__(self, log_dir='./runs', log_graph=True): + self.writer = SummaryWriter(log_dir=log_dir) + self.log_graph = log_graph + + def before_fit(self, learn): + if self.log_graph: + dummy_input = next(iter(learn.dls.train))[0].to(learn.model.device) + self.writer.add_graph(learn.model, dummy_input) + + def after_batch(self, learn): + if learn.training: + self.writer.add_scalar('Loss/train', learn.loss.item(), learn.iter_total) + else: + self.writer.add_scalar('Loss/val', learn.loss.item(), learn.iter_total) + + def after_epoch(self, learn): + for name, metric in learn.metrics.metrics.items(): + self.writer.add_scalar(f'Metrics/{name}', metric.compute().item(), learn.epoch) + + def after_fit(self, learn): + self.writer.close() + + +class ReduceLROnPlateauCB(Callback): + def __init__(self, monitor='val_loss', patience=3, factor=0.1, min_lr=1e-6, mode='min'): + self.monitor, self.patience, self.factor, self.min_lr, self.mode = monitor, patience, factor, min_lr, mode + self.best = None + self.counter = 0 + + def before_fit(self, learn): + self.best = float('inf') if self.mode == 'min' else -float('inf') + + def after_epoch(self, learn): + current = learn.metrics.all_metrics[self.monitor].compute().item() + if (self.mode == 'min' and current < self.best) or \ + (self.mode == 'max' and current > self.best): + self.best = current + self.counter = 0 + else: + self.counter += 1 + if self.counter >= self.patience: + new_lr = max(learn.opt.param_groups[0]['lr'] * self.factor, self.min_lr) + for g in learn.opt.param_groups: + g['lr'] = new_lr + self.counter = 0 + print(f"Learning rate reduced to {new_lr:.1e}") + + + +class ModelCheckpointCB(Callback): + def __init__(self, monitor='val_loss', mode='min', save_best_only=True, filepath='./best_model.pth'): + self.monitor, self.mode, self.save_best_only, self.filepath = monitor, mode, save_best_only, filepath + self.best = None + + def before_fit(self, learn): + self.best = float('inf') if self.mode == 'min' else -float('inf') + + def after_epoch(self, learn): + current = learn.metrics.all_metrics[self.monitor].compute().item() + if (self.mode == 'min' and current < self.best) or \ + (self.mode == 'max' and current > self.best): + self.best = current + torch.save(learn.model.state_dict(), self.filepath) + print(f"Saved model checkpoint to {self.filepath}") + + +