Skip to content

Commit

Permalink
Bug Fixes and Wandb callback
Browse files Browse the repository at this point in the history
  • Loading branch information
m-gopichand committed Aug 24, 2024
1 parent a67b00c commit 652a410
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 9 deletions.
39 changes: 37 additions & 2 deletions nbs/04_learner.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1415,7 +1415,7 @@
],
"source": [
"#|export\n",
"class EarlyStoppingCallback(Callback):\n",
"class EarlyStoppingCB(Callback):\n",
" order = MetricsCB.order+1\n",
" def __init__(self, patience=3, min_delta=0.001):\n",
" self.patience = patience\n",
Expand All @@ -1434,7 +1434,42 @@
" \n",
" if self.wait >= self.patience:\n",
" print(\"Early stopping triggered.\")\n",
" raise CancelFitException"
" raise CancelFitException\n",
" \n",
"class WandbCB(Callback):\n",
" order = MetricsCB.order+1\n",
"\n",
" def __init__(self, project_name=\"default-project\", run_name=None, log_model=True, model_name=\"best_model.pth\"):\n",
" self.project_name = project_name\n",
" self.run_name = run_name\n",
" self.log_model = log_model\n",
" self.model_name = model_name\n",
" self.best_loss = float('inf')\n",
"\n",
" def before_fit(self, learn):\n",
" self.run = wandb.init(project=self.project_name, name=self.run_name)\n",
" self.run.watch(learn.model, log=\"all\")\n",
"\n",
" def after_epoch(self, learn):\n",
" metrics = {k: v.compute().item() if hasattr(v, 'compute') else v for k, v in learn.metrics.all_metrics.items()}\n",
" metrics[\"epoch\"] = learn.epoch\n",
" self.run.log(metrics)\n",
"\n",
" # Save the model only if it has the best validation loss so far\n",
" if not learn.model.training:\n",
" current_loss = learn.metrics.all_metrics['loss'].compute()\n",
" if current_loss < self.best_loss:\n",
" self.best_loss = current_loss\n",
" torch.save(learn.model.state_dict(), self.model_name)\n",
" wandb.save(self.model_name)\n",
"\n",
" def after_fit(self, learn):\n",
" self.run.finish()\n",
"\n",
" def cleanup_fit(self, learn):\n",
" if self.log_model:\n",
" # Log the best model\n",
" wandb.log_artifact(self.model_name, type=\"model\")\n"
]
},
{
Expand Down
16 changes: 11 additions & 5 deletions rapidai/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,11 @@
'rapidai.learner.DeviceCB.__init__': ('learner.html#devicecb.__init__', 'rapidai/learner.py'),
'rapidai.learner.DeviceCB.before_batch': ('learner.html#devicecb.before_batch', 'rapidai/learner.py'),
'rapidai.learner.DeviceCB.before_fit': ('learner.html#devicecb.before_fit', 'rapidai/learner.py'),
'rapidai.learner.EarlyStoppingCallback': ('learner.html#earlystoppingcallback', 'rapidai/learner.py'),
'rapidai.learner.EarlyStoppingCallback.__init__': ( 'learner.html#earlystoppingcallback.__init__',
'rapidai/learner.py'),
'rapidai.learner.EarlyStoppingCallback.after_epoch': ( 'learner.html#earlystoppingcallback.after_epoch',
'rapidai/learner.py'),
'rapidai.learner.EarlyStoppingCB': ('learner.html#earlystoppingcb', 'rapidai/learner.py'),
'rapidai.learner.EarlyStoppingCB.__init__': ( 'learner.html#earlystoppingcb.__init__',
'rapidai/learner.py'),
'rapidai.learner.EarlyStoppingCB.after_epoch': ( 'learner.html#earlystoppingcb.after_epoch',
'rapidai/learner.py'),
'rapidai.learner.LRFinderCB': ('learner.html#lrfindercb', 'rapidai/learner.py'),
'rapidai.learner.LRFinderCB.__init__': ('learner.html#lrfindercb.__init__', 'rapidai/learner.py'),
'rapidai.learner.LRFinderCB.after_batch': ('learner.html#lrfindercb.after_batch', 'rapidai/learner.py'),
Expand Down Expand Up @@ -220,6 +220,12 @@
'rapidai.learner.TrainLearner.predict': ('learner.html#trainlearner.predict', 'rapidai/learner.py'),
'rapidai.learner.TrainLearner.step': ('learner.html#trainlearner.step', 'rapidai/learner.py'),
'rapidai.learner.TrainLearner.zero_grad': ('learner.html#trainlearner.zero_grad', 'rapidai/learner.py'),
'rapidai.learner.WandbCB': ('learner.html#wandbcb', 'rapidai/learner.py'),
'rapidai.learner.WandbCB.__init__': ('learner.html#wandbcb.__init__', 'rapidai/learner.py'),
'rapidai.learner.WandbCB.after_epoch': ('learner.html#wandbcb.after_epoch', 'rapidai/learner.py'),
'rapidai.learner.WandbCB.after_fit': ('learner.html#wandbcb.after_fit', 'rapidai/learner.py'),
'rapidai.learner.WandbCB.before_fit': ('learner.html#wandbcb.before_fit', 'rapidai/learner.py'),
'rapidai.learner.WandbCB.cleanup_fit': ('learner.html#wandbcb.cleanup_fit', 'rapidai/learner.py'),
'rapidai.learner.lr_find': ('learner.html#lr_find', 'rapidai/learner.py'),
'rapidai.learner.run_cbs': ('learner.html#run_cbs', 'rapidai/learner.py'),
'rapidai.learner.to_cpu': ('learner.html#to_cpu', 'rapidai/learner.py'),
Expand Down
40 changes: 38 additions & 2 deletions rapidai/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# %% auto 0
__all__ = ['CancelFitException', 'CancelBatchException', 'CancelEpochException', 'Callback', 'run_cbs', 'SingleBatchCB', 'to_cpu',
'MetricsCB', 'DeviceCB', 'TrainCB', 'ProgressCB', 'with_cbs', 'Learner', 'TrainLearner', 'MomentumLearner',
'LRFinderCB', 'lr_find', 'EarlyStoppingCallback']
'LRFinderCB', 'lr_find', 'EarlyStoppingCB', 'WandbCB']

# %% ../nbs/04_learner.ipynb 1
import math,torch,matplotlib.pyplot as plt
Expand Down Expand Up @@ -240,7 +240,7 @@ def lr_find(self:Learner, gamma=1.3, max_mult=3, start_lr=1e-5, max_epochs=10):
self.fit(max_epochs, lr=start_lr, cbs=LRFinderCB(gamma=gamma, max_mult=max_mult))

# %% ../nbs/04_learner.ipynb 65
class EarlyStoppingCallback(Callback):
class EarlyStoppingCB(Callback):
order = MetricsCB.order+1
def __init__(self, patience=3, min_delta=0.001):
self.patience = patience
Expand All @@ -260,3 +260,39 @@ def after_epoch(self, learn):
if self.wait >= self.patience:
print("Early stopping triggered.")
raise CancelFitException

class WandbCB(Callback):
    """Log metrics to Weights & Biases and checkpoint/upload the best model.

    Runs after MetricsCB so that `learn.metrics.all_metrics` is already
    computed when `after_epoch` fires.

    Parameters:
        project_name: W&B project to log under.
        run_name: optional display name for the run (W&B generates one if None).
        log_model: if True, upload the best checkpoint as a W&B artifact at the
            end of training.
        model_name: local path used for the best-model checkpoint file.
    """
    order = MetricsCB.order+1

    def __init__(self, project_name="default-project", run_name=None, log_model=True, model_name="best_model.pth"):
        self.project_name = project_name
        self.run_name = run_name
        self.log_model = log_model
        self.model_name = model_name
        self.best_loss = float('inf')  # best validation loss seen so far
        self.run = None  # active wandb run; set in before_fit, cleared in cleanup_fit

    def before_fit(self, learn):
        # Start the run and track gradients/parameters of the model.
        self.run = wandb.init(project=self.project_name, name=self.run_name)
        self.run.watch(learn.model, log="all")

    def after_epoch(self, learn):
        # NOTE(review): in this framework after_epoch fires after both the train
        # and the valid phase, so metrics are logged twice per epoch — confirm
        # that is intended.
        metrics = {k: v.compute().item() if hasattr(v, 'compute') else v for k, v in learn.metrics.all_metrics.items()}
        metrics["epoch"] = learn.epoch
        self.run.log(metrics)

        # After the validation phase (model in eval mode), checkpoint the model
        # if the validation loss improved.
        if not learn.model.training:
            current_loss = learn.metrics.all_metrics['loss'].compute().item()
            if current_loss < self.best_loss:
                self.best_loss = current_loss
                torch.save(learn.model.state_dict(), self.model_name)
                wandb.save(self.model_name)

    def after_fit(self, learn):
        # BUGFIX: the original finished the run here and then tried to
        # wandb.log_artifact(...) in cleanup_fit — but logging an artifact
        # requires an active run, so it failed. Log the artifact while the run
        # is still open; the run is finished in cleanup_fit instead.
        # best_loss is only updated when a checkpoint was actually written, so
        # this also guards against uploading a non-existent file.
        if self.log_model and self.best_loss != float('inf'):
            wandb.log_artifact(self.model_name, type="model")

    def cleanup_fit(self, learn):
        # cleanup_fit runs even when training is cancelled (e.g. early
        # stopping raises CancelFitException), so the run is always closed.
        if self.run is not None:
            self.run.finish()
            self.run = None

0 comments on commit 652a410

Please sign in to comment.