diff --git a/nbs/04_learner.ipynb b/nbs/04_learner.ipynb
index afd021a..46845d8 100755
--- a/nbs/04_learner.ipynb
+++ b/nbs/04_learner.ipynb
@@ -1415,7 +1415,7 @@
    ],
    "source": [
     "#|export\n",
-    "class EarlyStoppingCallback(Callback):\n",
+    "class EarlyStoppingCB(Callback):\n",
     "    order = MetricsCB.order+1\n",
     "    def __init__(self, patience=3, min_delta=0.001):\n",
     "        self.patience = patience\n",
@@ -1434,7 +1434,45 @@
     "        \n",
     "        if self.wait >= self.patience:\n",
     "            print(\"Early stopping triggered.\")\n",
-    "            raise CancelFitException"
+    "            raise CancelFitException\n",
+    "\n",
+    "import wandb\n",
+    "\n",
+    "class WandbCB(Callback):\n",
+    "    order = MetricsCB.order+1\n",
+    "\n",
+    "    def __init__(self, project_name=\"default-project\", run_name=None, log_model=True, model_name=\"best_model.pth\"):\n",
+    "        self.project_name = project_name\n",
+    "        self.run_name = run_name\n",
+    "        self.log_model = log_model\n",
+    "        self.model_name = model_name\n",
+    "        self.best_loss = float('inf')\n",
+    "\n",
+    "    def before_fit(self, learn):\n",
+    "        self.run = wandb.init(project=self.project_name, name=self.run_name)\n",
+    "        self.run.watch(learn.model, log=\"all\")\n",
+    "\n",
+    "    def after_epoch(self, learn):\n",
+    "        metrics = {k: v.compute().item() if hasattr(v, 'compute') else v for k, v in learn.metrics.all_metrics.items()}\n",
+    "        metrics[\"epoch\"] = learn.epoch\n",
+    "        self.run.log(metrics)\n",
+    "\n",
+    "        # Save the model only if it has the best validation loss so far\n",
+    "        if not learn.model.training:\n",
+    "            current_loss = learn.metrics.all_metrics['loss'].compute()\n",
+    "            if current_loss < self.best_loss:\n",
+    "                self.best_loss = current_loss\n",
+    "                torch.save(learn.model.state_dict(), self.model_name)\n",
+    "                wandb.save(self.model_name)\n",
+    "\n",
+    "    def after_fit(self, learn):\n",
+    "        if self.log_model:\n",
+    "            # Log the best checkpoint while the run is still open\n",
+    "            wandb.log_artifact(self.model_name, type=\"model\")\n",
+    "\n",
+    "    def cleanup_fit(self, learn):\n",
+    "        # cleanup_fit runs even when fit is cancelled, so the run always closes\n",
+    "        self.run.finish()\n"
    ]
   },
   {
diff --git a/rapidai/_modidx.py b/rapidai/_modidx.py
index a37789b..dbc8e7e 100644
--- a/rapidai/_modidx.py
+++ b/rapidai/_modidx.py
@@ -165,11 +165,11 @@
                     'rapidai.learner.DeviceCB.__init__': ('learner.html#devicecb.__init__', 'rapidai/learner.py'),
                     'rapidai.learner.DeviceCB.before_batch': ('learner.html#devicecb.before_batch', 'rapidai/learner.py'),
                     'rapidai.learner.DeviceCB.before_fit': ('learner.html#devicecb.before_fit', 'rapidai/learner.py'),
-                    'rapidai.learner.EarlyStoppingCallback': ('learner.html#earlystoppingcallback', 'rapidai/learner.py'),
-                    'rapidai.learner.EarlyStoppingCallback.__init__': ( 'learner.html#earlystoppingcallback.__init__',
-                                                                        'rapidai/learner.py'),
-                    'rapidai.learner.EarlyStoppingCallback.after_epoch': ( 'learner.html#earlystoppingcallback.after_epoch',
-                                                                           'rapidai/learner.py'),
+                    'rapidai.learner.EarlyStoppingCB': ('learner.html#earlystoppingcb', 'rapidai/learner.py'),
+                    'rapidai.learner.EarlyStoppingCB.__init__': ( 'learner.html#earlystoppingcb.__init__',
+                                                                  'rapidai/learner.py'),
+                    'rapidai.learner.EarlyStoppingCB.after_epoch': ( 'learner.html#earlystoppingcb.after_epoch',
+                                                                     'rapidai/learner.py'),
                     'rapidai.learner.LRFinderCB': ('learner.html#lrfindercb', 'rapidai/learner.py'),
                     'rapidai.learner.LRFinderCB.__init__': ('learner.html#lrfindercb.__init__', 'rapidai/learner.py'),
                     'rapidai.learner.LRFinderCB.after_batch': ('learner.html#lrfindercb.after_batch', 'rapidai/learner.py'),
@@ -220,6 +220,12 @@
                     'rapidai.learner.TrainLearner.predict': ('learner.html#trainlearner.predict', 'rapidai/learner.py'),
                     'rapidai.learner.TrainLearner.step': ('learner.html#trainlearner.step', 'rapidai/learner.py'),
                     'rapidai.learner.TrainLearner.zero_grad': ('learner.html#trainlearner.zero_grad', 'rapidai/learner.py'),
+                    'rapidai.learner.WandbCB': ('learner.html#wandbcb', 'rapidai/learner.py'),
+                    'rapidai.learner.WandbCB.__init__': ('learner.html#wandbcb.__init__', 'rapidai/learner.py'),
+                    'rapidai.learner.WandbCB.after_epoch': ('learner.html#wandbcb.after_epoch', 'rapidai/learner.py'),
+                    'rapidai.learner.WandbCB.after_fit': ('learner.html#wandbcb.after_fit', 'rapidai/learner.py'),
+                    'rapidai.learner.WandbCB.before_fit': ('learner.html#wandbcb.before_fit', 'rapidai/learner.py'),
+                    'rapidai.learner.WandbCB.cleanup_fit': ('learner.html#wandbcb.cleanup_fit', 'rapidai/learner.py'),
                     'rapidai.learner.lr_find': ('learner.html#lr_find', 'rapidai/learner.py'),
                     'rapidai.learner.run_cbs': ('learner.html#run_cbs', 'rapidai/learner.py'),
                     'rapidai.learner.to_cpu': ('learner.html#to_cpu', 'rapidai/learner.py'),
diff --git a/rapidai/learner.py b/rapidai/learner.py
index bcfbe87..2bc3bd1 100755
--- a/rapidai/learner.py
+++ b/rapidai/learner.py
@@ -3,7 +3,7 @@
 # %% auto 0
 __all__ = ['CancelFitException', 'CancelBatchException', 'CancelEpochException', 'Callback', 'run_cbs', 'SingleBatchCB',
            'to_cpu', 'MetricsCB', 'DeviceCB', 'TrainCB', 'ProgressCB', 'with_cbs', 'Learner', 'TrainLearner', 'MomentumLearner',
-           'LRFinderCB', 'lr_find', 'EarlyStoppingCallback']
+           'LRFinderCB', 'lr_find', 'EarlyStoppingCB', 'WandbCB']
 
 # %% ../nbs/04_learner.ipynb 1
 import math,torch,matplotlib.pyplot as plt
@@ -240,7 +240,7 @@ def lr_find(self:Learner, gamma=1.3, max_mult=3, start_lr=1e-5, max_epochs=10):
     self.fit(max_epochs, lr=start_lr, cbs=LRFinderCB(gamma=gamma, max_mult=max_mult))
 
 # %% ../nbs/04_learner.ipynb 65
-class EarlyStoppingCallback(Callback):
+class EarlyStoppingCB(Callback):
     order = MetricsCB.order+1
     def __init__(self, patience=3, min_delta=0.001):
         self.patience = patience
@@ -260,3 +260,42 @@ def after_epoch(self, learn):
         if self.wait >= self.patience:
             print("Early stopping triggered.")
             raise CancelFitException
+
+import wandb
+
+class WandbCB(Callback):
+    order = MetricsCB.order+1
+
+    def __init__(self, project_name="default-project", run_name=None, log_model=True, model_name="best_model.pth"):
+        self.project_name = project_name
+        self.run_name = run_name
+        self.log_model = log_model
+        self.model_name = model_name
+        self.best_loss = float('inf')
+
+    def before_fit(self, learn):
+        self.run = wandb.init(project=self.project_name, name=self.run_name)
+        self.run.watch(learn.model, log="all")
+
+    def after_epoch(self, learn):
+        metrics = {k: v.compute().item() if hasattr(v, 'compute') else v for k, v in learn.metrics.all_metrics.items()}
+        metrics["epoch"] = learn.epoch
+        self.run.log(metrics)
+
+        # Save the model only if it has the best validation loss so far
+        if not learn.model.training:
+            current_loss = learn.metrics.all_metrics['loss'].compute()
+            if current_loss < self.best_loss:
+                self.best_loss = current_loss
+                torch.save(learn.model.state_dict(), self.model_name)
+                wandb.save(self.model_name)
+
+    def after_fit(self, learn):
+        if self.log_model:
+            # Log the best checkpoint while the run is still open
+            wandb.log_artifact(self.model_name, type="model")
+
+    def cleanup_fit(self, learn):
+        # cleanup_fit runs even when fit is cancelled, so the run always closes
+        self.run.finish()
+
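---

Usage sketch for the renamed EarlyStoppingCB and the new WandbCB. This is a
minimal sketch under assumptions: rapidai's Learner keeps the miniai-style
signature Learner(model, dls, loss_func, lr, cbs); `model` and `dls` are an
nn.Module and DataLoaders pair defined as in the notebooks; MulticlassAccuracy
comes from torcheval; `wandb login` has already been run; the project and run
names are placeholders. MetricsCB must be in the callback list, since WandbCB
reads learn.metrics.all_metrics:

    import torch.nn.functional as F
    from torcheval.metrics import MulticlassAccuracy
    from rapidai.learner import (Learner, MetricsCB, DeviceCB, ProgressCB,
                                 EarlyStoppingCB, WandbCB)

    # DeviceCB moves model and batches to the default device; MetricsCB tracks
    # loss and accuracy, which WandbCB then logs to Weights & Biases each epoch.
    cbs = [DeviceCB(), MetricsCB(accuracy=MulticlassAccuracy()), ProgressCB(),
           EarlyStoppingCB(patience=3, min_delta=0.001),
           WandbCB(project_name="rapidai-demo", run_name="baseline")]
    learn = Learner(model, dls, loss_func=F.cross_entropy, lr=1e-2, cbs=cbs)
    learn.fit(20)  # may stop early; the best checkpoint is logged as an artifact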