Merge pull request #1 from m-gopichand/dev
Bug fixes; added WandB and EarlyStopping callbacks to the learner
m-gopichand authored Aug 24, 2024
2 parents ef837d6 + 7bbf9eb commit 11dd8a6
Showing 7 changed files with 182 additions and 24 deletions.
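A minimal usage sketch of the two new callbacks, assuming the torcheval-based `MetricsCB` convention used elsewhere in these notebooks; `get_model()` and `dls` are the notebook's own, and the project name is a placeholder:

    import torch.nn.functional as F
    from torcheval.metrics import MulticlassAccuracy
    from rapidai.learner import *

    # get_model() and dls as defined in 04_learner.ipynb
    cbs = [DeviceCB(), MetricsCB(accuracy=MulticlassAccuracy()), ProgressCB(),
           WandbCB(project_name="rapidai-demo", save_model=True),  # log metrics + per-epoch checkpoints
           EarlyStoppingCB(patience=3, min_delta=0.001)]           # stop once valid loss stalls
    learn = MomentumLearner(get_model(), dls, F.cross_entropy, cbs=cbs)
    learn.fit(20)  # may finish early via CancelFitException

Both new callbacks declare `order = MetricsCB.order + 1`, so they run only after `MetricsCB` has computed the epoch's metrics.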
3 changes: 2 additions & 1 deletion nbs/03_convolutions.ipynb
@@ -1526,7 +1526,8 @@
"def_device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"def device():\n",
" return 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'\n",
" return 'Accelerator: Apple-GPU' if torch.backends.mps.is_available() else 'Accelerator: GPU' if torch.cuda.is_available() else 'No accelerator found using cpu'\n",
" \n",
" \n",
"def to_device(x, device=def_device):\n",
" if isinstance(x, torch.Tensor): return x.to(device)\n",
107 changes: 93 additions & 14 deletions nbs/04_learner.ipynb
@@ -36,13 +36,15 @@
"from operator import attrgetter\n",
"from functools import partial\n",
"from copy import copy\n",
"\n",
"import wandb\n",
"import mlflow\n",
"import mlflow.pytorch\n",
"from pathlib import Path\n",
"from torch import optim\n",
"import torch.nn.functional as F\n",
"\n",
"from rapidai.conv import *\n",
"\n",
"from fastprogress import progress_bar,master_bar"
"from fastprogress import progress_bar,master_bar\n",
"from torch.utils.tensorboard import SummaryWriter"
]
},
{
@@ -670,9 +672,10 @@
"outputs": [],
"source": [
"class Learner():\n",
" def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.SGD):\n",
" def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.Adam):\n",
" cbs = fc.L(cbs)\n",
" fc.store_attr()\n",
" \n",
"\n",
" @contextmanager\n",
" def cb_ctx(self, nm):\n",
@@ -695,6 +698,7 @@
" self.backward()\n",
" self.step()\n",
" self.zero_grad()\n",
" \n",
" \n",
" def fit(self, n_epochs=1, train=True, valid=True, cbs=None, lr=None):\n",
" cbs = fc.L(cbs)\n",
@@ -946,7 +950,7 @@
"source": [
"#|export\n",
"class Learner():\n",
" def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.SGD):\n",
" def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.Adam):\n",
" cbs = fc.L(cbs)\n",
" fc.store_attr()\n",
"\n",
@@ -1391,6 +1395,89 @@
"MomentumLearner(get_model(), dls, F.cross_entropy, cbs=cbs).lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "42db89ae",
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'Callback' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m#\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mWandbCB\u001b[39;00m(\u001b[43mCallback\u001b[49m):\n\u001b[1;32m 3\u001b[0m order \u001b[38;5;241m=\u001b[39m MetricsCB\u001b[38;5;241m.\u001b[39morder \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, project_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdefault-project\u001b[39m\u001b[38;5;124m\"\u001b[39m, run_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, log_model\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, save_model\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, model_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfinal_model.pth\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n",
"\u001b[0;31mNameError\u001b[0m: name 'Callback' is not defined"
]
}
],
"source": [
"#|export\n",
"class WandbCB(Callback):\n",
" order = MetricsCB.order + 1\n",
" \n",
" def __init__(self, project_name=\"default-project\", run_name=None, log_model=True, save_model=False, model_name=\"final_model.pth\"):\n",
" self.project_name = project_name\n",
" self.run_name = run_name\n",
" self.log_model = log_model\n",
" self.save_model = save_model\n",
" self.model_name = model_name\n",
"\n",
" def before_fit(self, learn):\n",
" self.run = wandb.init(project=self.project_name, name=self.run_name)\n",
" self.run.watch(learn.model, log=\"all\")\n",
"\n",
" def after_epoch(self, learn):\n",
" # Log all metrics for the epoch\n",
" metrics = {f'{k}_train' if learn.training else f'{k}_valid': v.compute().item() if hasattr(v, 'compute') else v \n",
" for k, v in learn.metrics.all_metrics.items()}\n",
" metrics[\"epoch\"] = learn.epoch\n",
" self.run.log(metrics)\n",
"\n",
" # Save model checkpoint if save_model is enabled\n",
" if self.save_model and not learn.training:\n",
" model_path = f\"{self.model_name}_epoch_{learn.epoch}.pth\"\n",
" torch.save(learn.model.state_dict(), model_path)\n",
" wandb.save(model_path)\n",
"\n",
" def after_fit(self, learn):\n",
" if self.log_model:\n",
" torch.save(learn.model.state_dict(), self.model_name)\n",
" wandb.log_artifact(self.model_name, type=\"model\")\n",
" \n",
"\n",
" def cleanup_fit(self, learn):\n",
" if self.log_model:\n",
" torch.save(learn.model.state_dict(), self.model_name)\n",
" wandb.save(self.model_name)\n",
" self.run.finish()\n",
"\n",
"\n",
"class EarlyStoppingCB(Callback):\n",
" order = MetricsCB.order+1\n",
" def __init__(self, patience=3, min_delta=0.001):\n",
" self.patience = patience\n",
" self.min_delta = min_delta\n",
" self.best_loss = float('inf')\n",
" self.wait = 0\n",
"\n",
" def after_epoch(self, learn):\n",
" if not learn.model.training:\n",
" current_loss = learn.metrics.all_metrics['loss'].compute()\n",
" if current_loss < self.best_loss - self.min_delta:\n",
" self.best_loss = current_loss\n",
" self.wait = 0\n",
" else:\n",
" self.wait += 1\n",
" \n",
" if self.wait >= self.patience:\n",
" print(\"Early stopping triggered.\")\n",
" raise CancelFitException\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -1419,14 +1506,6 @@
"source": [
"import nbdev; nbdev.nbdev_export()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0fc774ac",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
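To make `WandbCB.after_epoch` concrete, here is a standalone sketch of the payload it builds; the metric object and values are made up, and stand in for the mix `MetricsCB` keeps in `all_metrics` (torcheval-style objects with `.compute()` alongside plain numbers):

    import torch

    class FakeMetric:
        def compute(self): return torch.tensor(0.91)  # stands in for e.g. MulticlassAccuracy

    all_metrics = {'accuracy': FakeMetric(), 'loss': 0.31}
    training = False  # i.e. this epoch's validation pass
    payload = {f'{k}_train' if training else f'{k}_valid':
               v.compute().item() if hasattr(v, 'compute') else v
               for k, v in all_metrics.items()}
    payload['epoch'] = 4
    print(payload)  # {'accuracy_valid': 0.91, 'loss_valid': 0.31, 'epoch': 4}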
2 changes: 1 addition & 1 deletion rapidai/__init__.py
@@ -1 +1 @@
__version__ = "0.0.4"
__version__ = "0.0.7"
11 changes: 11 additions & 0 deletions rapidai/_modidx.py
@@ -165,6 +165,11 @@
'rapidai.learner.DeviceCB.__init__': ('learner.html#devicecb.__init__', 'rapidai/learner.py'),
'rapidai.learner.DeviceCB.before_batch': ('learner.html#devicecb.before_batch', 'rapidai/learner.py'),
'rapidai.learner.DeviceCB.before_fit': ('learner.html#devicecb.before_fit', 'rapidai/learner.py'),
'rapidai.learner.EarlyStoppingCB': ('learner.html#earlystoppingcb', 'rapidai/learner.py'),
'rapidai.learner.EarlyStoppingCB.__init__': ( 'learner.html#earlystoppingcb.__init__',
'rapidai/learner.py'),
'rapidai.learner.EarlyStoppingCB.after_epoch': ( 'learner.html#earlystoppingcb.after_epoch',
'rapidai/learner.py'),
'rapidai.learner.LRFinderCB': ('learner.html#lrfindercb', 'rapidai/learner.py'),
'rapidai.learner.LRFinderCB.__init__': ('learner.html#lrfindercb.__init__', 'rapidai/learner.py'),
'rapidai.learner.LRFinderCB.after_batch': ('learner.html#lrfindercb.after_batch', 'rapidai/learner.py'),
@@ -215,6 +220,12 @@
'rapidai.learner.TrainLearner.predict': ('learner.html#trainlearner.predict', 'rapidai/learner.py'),
'rapidai.learner.TrainLearner.step': ('learner.html#trainlearner.step', 'rapidai/learner.py'),
'rapidai.learner.TrainLearner.zero_grad': ('learner.html#trainlearner.zero_grad', 'rapidai/learner.py'),
'rapidai.learner.WandbCB': ('learner.html#wandbcb', 'rapidai/learner.py'),
'rapidai.learner.WandbCB.__init__': ('learner.html#wandbcb.__init__', 'rapidai/learner.py'),
'rapidai.learner.WandbCB.after_epoch': ('learner.html#wandbcb.after_epoch', 'rapidai/learner.py'),
'rapidai.learner.WandbCB.after_fit': ('learner.html#wandbcb.after_fit', 'rapidai/learner.py'),
'rapidai.learner.WandbCB.before_fit': ('learner.html#wandbcb.before_fit', 'rapidai/learner.py'),
'rapidai.learner.WandbCB.cleanup_fit': ('learner.html#wandbcb.cleanup_fit', 'rapidai/learner.py'),
'rapidai.learner.lr_find': ('learner.html#lr_find', 'rapidai/learner.py'),
'rapidai.learner.run_cbs': ('learner.html#run_cbs', 'rapidai/learner.py'),
'rapidai.learner.to_cpu': ('learner.html#to_cpu', 'rapidai/learner.py'),
3 changes: 2 additions & 1 deletion rapidai/conv.py
@@ -23,7 +23,8 @@ def conv(ni, nf, ks=3, stride=2, act=True):
def_device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'

def device():
return 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
return 'Accelerator: Apple-GPU' if torch.backends.mps.is_available() else 'Accelerator: GPU' if torch.cuda.is_available() else 'No accelerator found, using CPU'


def to_device(x, device=def_device):
if isinstance(x, torch.Tensor): return x.to(device)
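One behavioral note worth showing: after this change `device()` is a display helper only, while `def_device` remains the string that `.to()` needs. A short sketch (the printed output assumes a CUDA machine):

    import torch
    from rapidai.conv import def_device, device

    x = torch.zeros(3).to(def_device)  # def_device is still 'mps', 'cuda', or 'cpu'
    print(x.device)                    # e.g. cuda:0
    print(device())                    # e.g. 'Accelerator: GPU' -- a label, not a torch device string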
76 changes: 71 additions & 5 deletions rapidai/learner.py
@@ -3,7 +3,7 @@
# %% auto 0
__all__ = ['CancelFitException', 'CancelBatchException', 'CancelEpochException', 'Callback', 'run_cbs', 'SingleBatchCB', 'to_cpu',
'MetricsCB', 'DeviceCB', 'TrainCB', 'ProgressCB', 'with_cbs', 'Learner', 'TrainLearner', 'MomentumLearner',
'LRFinderCB', 'lr_find']
'LRFinderCB', 'lr_find', 'WandbCB', 'EarlyStoppingCB']

# %% ../nbs/04_learner.ipynb 1
import math,torch,matplotlib.pyplot as plt
@@ -12,13 +12,15 @@
from operator import attrgetter
from functools import partial
from copy import copy

import wandb
import mlflow
import mlflow.pytorch
from pathlib import Path
from torch import optim
import torch.nn.functional as F

from .conv import *

from fastprogress import progress_bar,master_bar
from torch.utils.tensorboard import SummaryWriter

# %% ../nbs/04_learner.ipynb 14
class CancelFitException(Exception): pass
@@ -134,7 +136,7 @@ def _f(o, *args, **kwargs):

# %% ../nbs/04_learner.ipynb 49
class Learner():
def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.SGD):
def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.Adam):
cbs = fc.L(cbs)
fc.store_attr()

@@ -236,3 +238,67 @@ def cleanup_fit(self, learn):
@fc.patch
def lr_find(self:Learner, gamma=1.3, max_mult=3, start_lr=1e-5, max_epochs=10):
self.fit(max_epochs, lr=start_lr, cbs=LRFinderCB(gamma=gamma, max_mult=max_mult))

# %% ../nbs/04_learner.ipynb 65
class WandbCB(Callback):
order = MetricsCB.order + 1

def __init__(self, project_name="default-project", run_name=None, log_model=True, save_model=False, model_name="final_model.pth"):
self.project_name = project_name
self.run_name = run_name
self.log_model = log_model
self.save_model = save_model
self.model_name = model_name

def before_fit(self, learn):
self.run = wandb.init(project=self.project_name, name=self.run_name)
self.run.watch(learn.model, log="all")

def after_epoch(self, learn):
# Log all metrics for the epoch
metrics = {f'{k}_train' if learn.training else f'{k}_valid': v.compute().item() if hasattr(v, 'compute') else v
for k, v in learn.metrics.all_metrics.items()}
metrics["epoch"] = learn.epoch
self.run.log(metrics)

# Save model checkpoint if save_model is enabled
if self.save_model and not learn.training:
model_path = f"{self.model_name}_epoch_{learn.epoch}.pth"
torch.save(learn.model.state_dict(), model_path)
wandb.save(model_path)

def after_fit(self, learn):
if self.log_model:
torch.save(learn.model.state_dict(), self.model_name)
wandb.log_artifact(self.model_name, type="model")


def cleanup_fit(self, learn):
if self.log_model:
torch.save(learn.model.state_dict(), self.model_name)
wandb.save(self.model_name)
self.run.finish()


class EarlyStoppingCB(Callback):
order = MetricsCB.order+1
def __init__(self, patience=3, min_delta=0.001):
self.patience = patience
self.min_delta = min_delta
self.best_loss = float('inf')
self.wait = 0

def after_epoch(self, learn):
if not learn.model.training:
current_loss = learn.metrics.all_metrics['loss'].compute()
if current_loss < self.best_loss - self.min_delta:
self.best_loss = current_loss
self.wait = 0
else:
self.wait += 1

if self.wait >= self.patience:
print("Early stopping triggered.")
raise CancelFitException
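
To see how `patience` and `min_delta` interact, here is a standalone trace of the counter in `EarlyStoppingCB.after_epoch`, run over made-up validation losses; note that the epoch-3 "improvement" is smaller than `min_delta`, so it still counts as a stall:

    losses = [0.90, 0.80, 0.795, 0.7945, 0.805, 0.81]  # hypothetical valid losses
    best, wait, patience, min_delta = float('inf'), 0, 3, 0.001
    for epoch, loss in enumerate(losses):
        if loss < best - min_delta: best, wait = loss, 0  # real improvement: reset the counter
        else: wait += 1                                   # within min_delta: no progress
        print(f"epoch {epoch}: loss={loss:.4f} best={best:.4f} wait={wait}")
        if wait >= patience:
            print("Early stopping triggered.")            # fires at epoch 5
            break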


4 changes: 2 additions & 2 deletions settings.ini
@@ -1,7 +1,7 @@
[DEFAULT]
repo = rapidai
lib_name = rapidai
version = 0.0.4
version = 0.0.7
min_python = 3.9
license = apache2
black_formatting = False
@@ -34,7 +34,7 @@ clean_ids = True
clear_all = False
cell_number = True
requirements = matplotlib datasets fastprogress fastcore
pip_requirements = torch>=1.7,<=2.4 torcheval diffusers einops timm
pip_requirements = torch>=1.7,<=2.4 torcheval diffusers einops timm wandb mlflow accelerate tensorboard
conda_requirements = pytorch>=1.7,<1.14
dev_requirements = nbdev
