Skip to content

Commit

Permalink
Bug fixes & added more callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
m-gopichand committed Aug 23, 2024
1 parent c40cb8d commit 0e42928
Show file tree
Hide file tree
Showing 3 changed files with 229 additions and 26 deletions.
98 changes: 94 additions & 4 deletions nbs/04_learner.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
"from torch import optim\n",
"import torch.nn.functional as F\n",
"from rapidai.conv import *\n",
"from fastprogress import progress_bar,master_bar"
"from fastprogress import progress_bar,master_bar\n",
"from torch.utils.tensorboard import SummaryWriter"
]
},
{
Expand Down Expand Up @@ -1402,7 +1403,7 @@
"outputs": [],
"source": [
"#|export\n",
"class WandbCallback(Callback):\n",
"class WandbCB(Callback):\n",
" def __init__(\n",
" self,\n",
" project_name: str,\n",
Expand Down Expand Up @@ -1506,7 +1507,7 @@
"\n",
"\n",
"\n",
"class MLflowCallback(Callback):\n",
"class MLflowCB(Callback):\n",
" def __init__(\n",
" self,\n",
" experiment_name: str,\n",
Expand Down Expand Up @@ -1597,7 +1598,96 @@
"\n",
" def after_fit(self, learn):\n",
" # Finish MLflow run\n",
" mlflow.end_run()"
" mlflow.end_run()\n",
"\n",
"\n",
"\n",
"class EarlyStoppingCB(Callback):\n",
"    \"\"\"Stop training once the monitored metric stops improving.\n",
"\n",
"    Args:\n",
"        monitor: key looked up in `learn.metrics.metrics` after each epoch.\n",
"        min_delta: margin by which the metric must beat the best value so far\n",
"            to count as an improvement.\n",
"        patience: number of consecutive non-improving `after_epoch` calls\n",
"            tolerated before the fit is cancelled.\n",
"        mode: 'min' when lower is better; any other value means higher is better.\n",
"\n",
"    Raises:\n",
"        CancelFitException: from `after_epoch`, once `patience` is exhausted.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, monitor='val_loss', min_delta=0, patience=3, mode='min'):\n",
"        fc.store_attr()\n",
"        self.best = None\n",
"        self.num_bad_epochs = 0\n",
"        # `.item()` yields plain Python scalars, which torch.lt/torch.gt reject\n",
"        # (they require a Tensor input), so compare with the builtin operators.\n",
"        import operator\n",
"        self.operator = operator.lt if mode == 'min' else operator.gt\n",
"\n",
"    def before_fit(self, learn):\n",
"        # Start every fit from a clean slate so a counter left over from an\n",
"        # earlier run cannot trigger an immediate stop.\n",
"        self.best = None\n",
"        self.num_bad_epochs = 0\n",
"\n",
"    def after_epoch(self, learn):\n",
"        # NOTE(review): assumes `learn.metrics.metrics` contains `self.monitor`;\n",
"        # other callbacks in this cell read `learn.metrics.all_metrics` - confirm.\n",
"        current = learn.metrics.metrics[self.monitor].compute().item()\n",
"        # The margin has to point in the improving direction: below the best\n",
"        # value for 'min', above it for every other mode.\n",
"        margin = -self.min_delta if self.mode == 'min' else self.min_delta\n",
"        if self.best is None or self.operator(current, self.best + margin):\n",
"            self.best = current\n",
"            self.num_bad_epochs = 0\n",
"            return\n",
"        self.num_bad_epochs += 1\n",
"        if self.num_bad_epochs >= self.patience:\n",
"            print(\"Stopping early!\")\n",
"            raise CancelFitException()\n",
"\n",
"\n",
"class TensorBoardCB(Callback):\n",
"    \"\"\"Log the model graph, per-batch losses and per-epoch metrics to TensorBoard.\n",
"\n",
"    Args:\n",
"        log_dir: directory handed to `SummaryWriter`.\n",
"        log_graph: when true, trace the model on one training batch in `before_fit`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, log_dir='./runs', log_graph=True):\n",
"        self.writer = SummaryWriter(log_dir=log_dir)\n",
"        self.log_graph = log_graph\n",
"\n",
"    def before_fit(self, learn):\n",
"        if not self.log_graph:\n",
"            return\n",
"        dummy_input = next(iter(learn.dls.train))[0]\n",
"        # torch.nn.Module has no built-in `device` attribute, so fall back to the\n",
"        # device of the first parameter when the model does not define one.\n",
"        device = getattr(learn.model, 'device', None)\n",
"        if device is None:\n",
"            first_param = next(learn.model.parameters(), None)\n",
"            device = None if first_param is None else first_param.device\n",
"        if device is not None:\n",
"            dummy_input = dummy_input.to(device)\n",
"        self.writer.add_graph(learn.model, dummy_input)\n",
"\n",
"    def after_batch(self, learn):\n",
"        # NOTE(review): relies on `learn.iter_total` as the global step - confirm\n",
"        # the learner exposes it.\n",
"        tag = 'Loss/train' if learn.training else 'Loss/val'\n",
"        self.writer.add_scalar(tag, learn.loss.item(), learn.iter_total)\n",
"\n",
"    def after_epoch(self, learn):\n",
"        # One scalar per metric, indexed by the epoch number.\n",
"        for name, metric in learn.metrics.metrics.items():\n",
"            self.writer.add_scalar(f'Metrics/{name}', metric.compute().item(), learn.epoch)\n",
"\n",
"    def after_fit(self, learn):\n",
"        self.writer.close()\n",
"\n",
"\n",
"class ReduceLROnPlateauCB(Callback):\n",
"    \"\"\"Scale the optimizer's learning rates down when a metric stops improving.\n",
"\n",
"    Args:\n",
"        monitor: key looked up in `learn.metrics.all_metrics` after each epoch.\n",
"        patience: number of consecutive non-improving `after_epoch` calls before\n",
"            the learning rates are reduced.\n",
"        factor: multiplier applied to each param group's learning rate.\n",
"        min_lr: lower bound; a reduction never takes a learning rate below it.\n",
"        mode: 'min' when lower is better; any other value means higher is better.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, monitor='val_loss', patience=3, factor=0.1, min_lr=1e-6, mode='min'):\n",
"        self.monitor, self.patience, self.factor, self.min_lr, self.mode = monitor, patience, factor, min_lr, mode\n",
"        self.best = None\n",
"        self.counter = 0\n",
"\n",
"    def before_fit(self, learn):\n",
"        # Reset per fit so state from an earlier run does not leak into this one.\n",
"        self.best = float('inf') if self.mode == 'min' else -float('inf')\n",
"        self.counter = 0\n",
"\n",
"    def after_epoch(self, learn):\n",
"        current = learn.metrics.all_metrics[self.monitor].compute().item()\n",
"        # Same rule as `before_fit`: every mode other than 'min' maximises.\n",
"        improved = current < self.best if self.mode == 'min' else current > self.best\n",
"        if improved:\n",
"            self.best = current\n",
"            self.counter = 0\n",
"            return\n",
"        self.counter += 1\n",
"        if self.counter < self.patience:\n",
"            return\n",
"        self.counter = 0\n",
"        # Scale every param group from its own learning rate; deriving one value\n",
"        # from group 0 would overwrite per-group learning rates.\n",
"        reduced = []\n",
"        for g in learn.opt.param_groups:\n",
"            new_lr = max(g['lr'] * self.factor, self.min_lr)\n",
"            if new_lr < g['lr']:\n",
"                g['lr'] = new_lr\n",
"                reduced.append(new_lr)\n",
"        # Stay silent when every group is already at (or below) `min_lr`.\n",
"        if reduced:\n",
"            shown = ', '.join(f'{lr:.1e}' for lr in reduced)\n",
"            print(f\"Learning rate reduced to {shown}\")\n",
"\n",
"\n",
"\n",
"class ModelCheckpointCB(Callback):\n",
"    \"\"\"Save the model's `state_dict` to `filepath` as training progresses.\n",
"\n",
"    Args:\n",
"        monitor: key looked up in `learn.metrics.all_metrics` after each epoch.\n",
"        mode: 'min' when lower is better; any other value means higher is better.\n",
"        save_best_only: when true, write a checkpoint only if the monitored\n",
"            metric improved; when false, write one on every `after_epoch` call.\n",
"        filepath: destination passed to `torch.save`; each save overwrites it.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, monitor='val_loss', mode='min', save_best_only=True, filepath='./best_model.pth'):\n",
"        self.monitor, self.mode, self.save_best_only, self.filepath = monitor, mode, save_best_only, filepath\n",
"        self.best = None\n",
"\n",
"    def before_fit(self, learn):\n",
"        self.best = float('inf') if self.mode == 'min' else -float('inf')\n",
"\n",
"    def after_epoch(self, learn):\n",
"        current = learn.metrics.all_metrics[self.monitor].compute().item()\n",
"        # Same rule as `before_fit`: every mode other than 'min' maximises.\n",
"        improved = current < self.best if self.mode == 'min' else current > self.best\n",
"        if improved:\n",
"            self.best = current\n",
"        # Honour `save_best_only`: it was stored but never consulted, so passing\n",
"        # False silently behaved like True.\n",
"        if improved or not self.save_best_only:\n",
"            torch.save(learn.model.state_dict(), self.filepath)\n",
"            print(f\"Saved model checkpoint to {self.filepath}\")\n",
"\n",
"\n"
]
},
{
Expand Down
59 changes: 40 additions & 19 deletions rapidai/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,11 @@
'rapidai.learner.DeviceCB.__init__': ('learner.html#devicecb.__init__', 'rapidai/learner.py'),
'rapidai.learner.DeviceCB.before_batch': ('learner.html#devicecb.before_batch', 'rapidai/learner.py'),
'rapidai.learner.DeviceCB.before_fit': ('learner.html#devicecb.before_fit', 'rapidai/learner.py'),
'rapidai.learner.EarlyStoppingCB': ('learner.html#earlystoppingcb', 'rapidai/learner.py'),
'rapidai.learner.EarlyStoppingCB.__init__': ( 'learner.html#earlystoppingcb.__init__',
'rapidai/learner.py'),
'rapidai.learner.EarlyStoppingCB.after_epoch': ( 'learner.html#earlystoppingcb.after_epoch',
'rapidai/learner.py'),
'rapidai.learner.LRFinderCB': ('learner.html#lrfindercb', 'rapidai/learner.py'),
'rapidai.learner.LRFinderCB.__init__': ('learner.html#lrfindercb.__init__', 'rapidai/learner.py'),
'rapidai.learner.LRFinderCB.after_batch': ('learner.html#lrfindercb.after_batch', 'rapidai/learner.py'),
Expand All @@ -180,23 +185,26 @@
'rapidai.learner.Learner.fit': ('learner.html#learner.fit', 'rapidai/learner.py'),
'rapidai.learner.Learner.one_epoch': ('learner.html#learner.one_epoch', 'rapidai/learner.py'),
'rapidai.learner.Learner.training': ('learner.html#learner.training', 'rapidai/learner.py'),
'rapidai.learner.MLflowCallback': ('learner.html#mlflowcallback', 'rapidai/learner.py'),
'rapidai.learner.MLflowCallback.__init__': ('learner.html#mlflowcallback.__init__', 'rapidai/learner.py'),
'rapidai.learner.MLflowCallback.after_batch': ( 'learner.html#mlflowcallback.after_batch',
'rapidai/learner.py'),
'rapidai.learner.MLflowCallback.after_epoch': ( 'learner.html#mlflowcallback.after_epoch',
'rapidai/learner.py'),
'rapidai.learner.MLflowCallback.after_fit': ( 'learner.html#mlflowcallback.after_fit',
'rapidai/learner.py'),
'rapidai.learner.MLflowCallback.before_fit': ( 'learner.html#mlflowcallback.before_fit',
'rapidai/learner.py'),
'rapidai.learner.MLflowCB': ('learner.html#mlflowcb', 'rapidai/learner.py'),
'rapidai.learner.MLflowCB.__init__': ('learner.html#mlflowcb.__init__', 'rapidai/learner.py'),
'rapidai.learner.MLflowCB.after_batch': ('learner.html#mlflowcb.after_batch', 'rapidai/learner.py'),
'rapidai.learner.MLflowCB.after_epoch': ('learner.html#mlflowcb.after_epoch', 'rapidai/learner.py'),
'rapidai.learner.MLflowCB.after_fit': ('learner.html#mlflowcb.after_fit', 'rapidai/learner.py'),
'rapidai.learner.MLflowCB.before_fit': ('learner.html#mlflowcb.before_fit', 'rapidai/learner.py'),
'rapidai.learner.MetricsCB': ('learner.html#metricscb', 'rapidai/learner.py'),
'rapidai.learner.MetricsCB.__init__': ('learner.html#metricscb.__init__', 'rapidai/learner.py'),
'rapidai.learner.MetricsCB._log': ('learner.html#metricscb._log', 'rapidai/learner.py'),
'rapidai.learner.MetricsCB.after_batch': ('learner.html#metricscb.after_batch', 'rapidai/learner.py'),
'rapidai.learner.MetricsCB.after_epoch': ('learner.html#metricscb.after_epoch', 'rapidai/learner.py'),
'rapidai.learner.MetricsCB.before_epoch': ('learner.html#metricscb.before_epoch', 'rapidai/learner.py'),
'rapidai.learner.MetricsCB.before_fit': ('learner.html#metricscb.before_fit', 'rapidai/learner.py'),
'rapidai.learner.ModelCheckpointCB': ('learner.html#modelcheckpointcb', 'rapidai/learner.py'),
'rapidai.learner.ModelCheckpointCB.__init__': ( 'learner.html#modelcheckpointcb.__init__',
'rapidai/learner.py'),
'rapidai.learner.ModelCheckpointCB.after_epoch': ( 'learner.html#modelcheckpointcb.after_epoch',
'rapidai/learner.py'),
'rapidai.learner.ModelCheckpointCB.before_fit': ( 'learner.html#modelcheckpointcb.before_fit',
'rapidai/learner.py'),
'rapidai.learner.MomentumLearner': ('learner.html#momentumlearner', 'rapidai/learner.py'),
'rapidai.learner.MomentumLearner.__init__': ( 'learner.html#momentumlearner.__init__',
'rapidai/learner.py'),
Expand All @@ -209,9 +217,25 @@
'rapidai.learner.ProgressCB.after_epoch': ('learner.html#progresscb.after_epoch', 'rapidai/learner.py'),
'rapidai.learner.ProgressCB.before_epoch': ('learner.html#progresscb.before_epoch', 'rapidai/learner.py'),
'rapidai.learner.ProgressCB.before_fit': ('learner.html#progresscb.before_fit', 'rapidai/learner.py'),
'rapidai.learner.ReduceLROnPlateauCB': ('learner.html#reducelronplateaucb', 'rapidai/learner.py'),
'rapidai.learner.ReduceLROnPlateauCB.__init__': ( 'learner.html#reducelronplateaucb.__init__',
'rapidai/learner.py'),
'rapidai.learner.ReduceLROnPlateauCB.after_epoch': ( 'learner.html#reducelronplateaucb.after_epoch',
'rapidai/learner.py'),
'rapidai.learner.ReduceLROnPlateauCB.before_fit': ( 'learner.html#reducelronplateaucb.before_fit',
'rapidai/learner.py'),
'rapidai.learner.SingleBatchCB': ('learner.html#singlebatchcb', 'rapidai/learner.py'),
'rapidai.learner.SingleBatchCB.after_batch': ( 'learner.html#singlebatchcb.after_batch',
'rapidai/learner.py'),
'rapidai.learner.TensorBoardCB': ('learner.html#tensorboardcb', 'rapidai/learner.py'),
'rapidai.learner.TensorBoardCB.__init__': ('learner.html#tensorboardcb.__init__', 'rapidai/learner.py'),
'rapidai.learner.TensorBoardCB.after_batch': ( 'learner.html#tensorboardcb.after_batch',
'rapidai/learner.py'),
'rapidai.learner.TensorBoardCB.after_epoch': ( 'learner.html#tensorboardcb.after_epoch',
'rapidai/learner.py'),
'rapidai.learner.TensorBoardCB.after_fit': ('learner.html#tensorboardcb.after_fit', 'rapidai/learner.py'),
'rapidai.learner.TensorBoardCB.before_fit': ( 'learner.html#tensorboardcb.before_fit',
'rapidai/learner.py'),
'rapidai.learner.TrainCB': ('learner.html#traincb', 'rapidai/learner.py'),
'rapidai.learner.TrainCB.__init__': ('learner.html#traincb.__init__', 'rapidai/learner.py'),
'rapidai.learner.TrainCB.backward': ('learner.html#traincb.backward', 'rapidai/learner.py'),
Expand All @@ -225,15 +249,12 @@
'rapidai.learner.TrainLearner.predict': ('learner.html#trainlearner.predict', 'rapidai/learner.py'),
'rapidai.learner.TrainLearner.step': ('learner.html#trainlearner.step', 'rapidai/learner.py'),
'rapidai.learner.TrainLearner.zero_grad': ('learner.html#trainlearner.zero_grad', 'rapidai/learner.py'),
'rapidai.learner.WandbCallback': ('learner.html#wandbcallback', 'rapidai/learner.py'),
'rapidai.learner.WandbCallback.__init__': ('learner.html#wandbcallback.__init__', 'rapidai/learner.py'),
'rapidai.learner.WandbCallback.after_batch': ( 'learner.html#wandbcallback.after_batch',
'rapidai/learner.py'),
'rapidai.learner.WandbCallback.after_epoch': ( 'learner.html#wandbcallback.after_epoch',
'rapidai/learner.py'),
'rapidai.learner.WandbCallback.after_fit': ('learner.html#wandbcallback.after_fit', 'rapidai/learner.py'),
'rapidai.learner.WandbCallback.before_fit': ( 'learner.html#wandbcallback.before_fit',
'rapidai/learner.py'),
'rapidai.learner.WandbCB': ('learner.html#wandbcb', 'rapidai/learner.py'),
'rapidai.learner.WandbCB.__init__': ('learner.html#wandbcb.__init__', 'rapidai/learner.py'),
'rapidai.learner.WandbCB.after_batch': ('learner.html#wandbcb.after_batch', 'rapidai/learner.py'),
'rapidai.learner.WandbCB.after_epoch': ('learner.html#wandbcb.after_epoch', 'rapidai/learner.py'),
'rapidai.learner.WandbCB.after_fit': ('learner.html#wandbcb.after_fit', 'rapidai/learner.py'),
'rapidai.learner.WandbCB.before_fit': ('learner.html#wandbcb.before_fit', 'rapidai/learner.py'),
'rapidai.learner.lr_find': ('learner.html#lr_find', 'rapidai/learner.py'),
'rapidai.learner.run_cbs': ('learner.html#run_cbs', 'rapidai/learner.py'),
'rapidai.learner.to_cpu': ('learner.html#to_cpu', 'rapidai/learner.py'),
Expand Down
Loading

0 comments on commit 0e42928

Please sign in to comment.