Commit acfc649: Added Wandb & Mlflow callbacks

m-gopichand committed Aug 23, 2024
1 parent ef837d6 commit acfc649
Showing 6 changed files with 142 additions and 6 deletions.
3 changes: 2 additions & 1 deletion nbs/03_convolutions.ipynb
@@ -1526,7 +1526,8 @@
"def_device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"def device():\n",
" return 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'\n",
" return 'Accelerator: Apple-GPU' if torch.backends.mps.is_available() else 'Accelerator: GPU' if torch.cuda.is_available() else 'No accelerator found using cpu'\n",
" \n",
" \n",
"def to_device(x, device=def_device):\n",
" if isinstance(x, torch.Tensor): return x.to(device)\n",
63 changes: 62 additions & 1 deletion nbs/04_learner.ipynb
@@ -36,6 +36,8 @@
"from operator import attrgetter\n",
"from functools import partial\n",
"from copy import copy\n",
"import wandb\n",
"import mlflow\n",
"\n",
"from torch import optim\n",
"import torch.nn.functional as F\n",
@@ -946,7 +948,7 @@
"source": [
"#|export\n",
"class Learner():\n",
" def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.SGD):\n",
" def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.Adam):\n",
" cbs = fc.L(cbs)\n",
" fc.store_attr()\n",
"\n",
@@ -1391,6 +1393,65 @@
"MomentumLearner(get_model(), dls, F.cross_entropy, cbs=cbs).lr_find()"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "42db89ae",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#|export\n",
+ "class WandbCallback(Callback):\n",
+ "    def __init__(self, project_name, run_name=None, config=None):\n",
+ "        self.project_name = project_name\n",
+ "        self.run_name = run_name\n",
+ "        self.config = config\n",
+ "\n",
+ "    def before_fit(self, learn):\n",
+ "        wandb.init(project=self.project_name, name=self.run_name, config=self.config)\n",
+ "        learn.wandb_run = wandb.run\n",
+ "\n",
+ "    def after_batch(self, learn):\n",
+ "        if learn.training:  # log only training batches\n",
+ "            wandb.log({\"train/loss\": learn.loss.item(), \"train/epoch\": learn.epoch, \"train/iter\": learn.iter})\n",
+ "\n",
+ "    def after_epoch(self, learn):\n",
+ "        # after_epoch runs once per training/validation phase, so label metrics by the current phase\n",
+ "        phase = \"train\" if learn.training else \"val\"\n",
+ "        metrics = {f\"{phase}/{k}\": v.compute().item() for k, v in learn.metrics.metrics.items()}\n",
+ "        wandb.log(metrics)\n",
+ "\n",
+ "    def after_fit(self, learn):\n",
+ "        wandb.finish()\n",
+ "\n",
+ "\n",
+ "class MLflowCallback(Callback):\n",
+ "    def __init__(self, experiment_name=None, run_name=None, tracking_uri=None, config=None):\n",
+ "        self.experiment_name = experiment_name\n",
+ "        self.run_name = run_name\n",
+ "        self.tracking_uri = tracking_uri\n",
+ "        self.config = config\n",
+ "\n",
+ "    def before_fit(self, learn):\n",
+ "        if self.tracking_uri:\n",
+ "            mlflow.set_tracking_uri(self.tracking_uri)\n",
+ "        if self.experiment_name:\n",
+ "            mlflow.set_experiment(self.experiment_name)\n",
+ "        self.run = mlflow.start_run(run_name=self.run_name)\n",
+ "        if self.config:\n",
+ "            mlflow.log_params(self.config)\n",
+ "\n",
+ "    def after_batch(self, learn):\n",
+ "        if learn.training: mlflow.log_metric(\"train/loss\", learn.loss.item(), step=learn.iter)\n",
+ "\n",
+ "    def after_epoch(self, learn):\n",
+ "        phase = \"train\" if learn.training else \"val\"\n",
+ "        metrics = {f\"{phase}/{k}\": v.compute().item() for k, v in learn.metrics.metrics.items()}\n",
+ "        mlflow.log_metrics(metrics, step=learn.epoch)\n",
+ "\n",
+ "    def after_fit(self, learn):\n",
+ "        mlflow.end_run()\n",
+ "\n"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
Expand Down
19 changes: 19 additions & 0 deletions rapidai/_modidx.py
@@ -180,6 +180,16 @@
'rapidai.learner.Learner.fit': ('learner.html#learner.fit', 'rapidai/learner.py'),
'rapidai.learner.Learner.one_epoch': ('learner.html#learner.one_epoch', 'rapidai/learner.py'),
'rapidai.learner.Learner.training': ('learner.html#learner.training', 'rapidai/learner.py'),
+ 'rapidai.learner.MLflowCallback': ('learner.html#mlflowcallback', 'rapidai/learner.py'),
+ 'rapidai.learner.MLflowCallback.__init__': ('learner.html#mlflowcallback.__init__', 'rapidai/learner.py'),
+ 'rapidai.learner.MLflowCallback.after_batch': ( 'learner.html#mlflowcallback.after_batch',
+ 'rapidai/learner.py'),
+ 'rapidai.learner.MLflowCallback.after_epoch': ( 'learner.html#mlflowcallback.after_epoch',
+ 'rapidai/learner.py'),
+ 'rapidai.learner.MLflowCallback.after_fit': ( 'learner.html#mlflowcallback.after_fit',
+ 'rapidai/learner.py'),
+ 'rapidai.learner.MLflowCallback.before_fit': ( 'learner.html#mlflowcallback.before_fit',
+ 'rapidai/learner.py'),
'rapidai.learner.MetricsCB': ('learner.html#metricscb', 'rapidai/learner.py'),
'rapidai.learner.MetricsCB.__init__': ('learner.html#metricscb.__init__', 'rapidai/learner.py'),
'rapidai.learner.MetricsCB._log': ('learner.html#metricscb._log', 'rapidai/learner.py'),
@@ -215,6 +225,15 @@
'rapidai.learner.TrainLearner.predict': ('learner.html#trainlearner.predict', 'rapidai/learner.py'),
'rapidai.learner.TrainLearner.step': ('learner.html#trainlearner.step', 'rapidai/learner.py'),
'rapidai.learner.TrainLearner.zero_grad': ('learner.html#trainlearner.zero_grad', 'rapidai/learner.py'),
+ 'rapidai.learner.WandbCallback': ('learner.html#wandbcallback', 'rapidai/learner.py'),
+ 'rapidai.learner.WandbCallback.__init__': ('learner.html#wandbcallback.__init__', 'rapidai/learner.py'),
+ 'rapidai.learner.WandbCallback.after_batch': ( 'learner.html#wandbcallback.after_batch',
+ 'rapidai/learner.py'),
+ 'rapidai.learner.WandbCallback.after_epoch': ( 'learner.html#wandbcallback.after_epoch',
+ 'rapidai/learner.py'),
+ 'rapidai.learner.WandbCallback.after_fit': ('learner.html#wandbcallback.after_fit', 'rapidai/learner.py'),
+ 'rapidai.learner.WandbCallback.before_fit': ( 'learner.html#wandbcallback.before_fit',
+ 'rapidai/learner.py'),
'rapidai.learner.lr_find': ('learner.html#lr_find', 'rapidai/learner.py'),
'rapidai.learner.run_cbs': ('learner.html#run_cbs', 'rapidai/learner.py'),
'rapidai.learner.to_cpu': ('learner.html#to_cpu', 'rapidai/learner.py'),
3 changes: 2 additions & 1 deletion rapidai/conv.py
@@ -23,7 +23,8 @@ def conv(ni, nf, ks=3, stride=2, act=True):
def_device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'

def device():
-     return 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
+     return 'Accelerator: Apple-GPU' if torch.backends.mps.is_available() else 'Accelerator: GPU' if torch.cuda.is_available() else 'No accelerator found, using CPU'
+ 

def to_device(x, device=def_device):
    if isinstance(x, torch.Tensor): return x.to(device)
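A quick sketch of the distinction this change introduces (a minimal illustration, not part of the commit; the printed values are assumptions based on the strings above): def_device still returns a plain device string for tensor placement, while device() now returns a human-readable report.

import torch
from rapidai.conv import def_device, device, to_device

print(def_device)   # 'mps', 'cuda', or 'cpu' -- still usable as a torch device
print(device())     # e.g. 'Accelerator: GPU' -- a report string, not a device name
x = to_device(torch.randn(2, 2))   # tensors still go to def_device by default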
58 changes: 56 additions & 2 deletions rapidai/learner.py
@@ -3,7 +3,7 @@
# %% auto 0
__all__ = ['CancelFitException', 'CancelBatchException', 'CancelEpochException', 'Callback', 'run_cbs', 'SingleBatchCB', 'to_cpu',
           'MetricsCB', 'DeviceCB', 'TrainCB', 'ProgressCB', 'with_cbs', 'Learner', 'TrainLearner', 'MomentumLearner',
-            'LRFinderCB', 'lr_find']
+            'LRFinderCB', 'lr_find', 'WandbCallback', 'MLflowCallback']

# %% ../nbs/04_learner.ipynb 1
import math,torch,matplotlib.pyplot as plt
@@ -12,6 +12,8 @@
from operator import attrgetter
from functools import partial
from copy import copy
+ import wandb
+ import mlflow

from torch import optim
import torch.nn.functional as F
@@ -134,7 +136,7 @@ def _f(o, *args, **kwargs):

# %% ../nbs/04_learner.ipynb 49
class Learner():
-     def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.SGD):
+     def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.Adam):
        cbs = fc.L(cbs)
        fc.store_attr()

@@ -236,3 +238,55 @@ def cleanup_fit(self, learn):
@fc.patch
def lr_find(self:Learner, gamma=1.3, max_mult=3, start_lr=1e-5, max_epochs=10):
    self.fit(max_epochs, lr=start_lr, cbs=LRFinderCB(gamma=gamma, max_mult=max_mult))

+ # %% ../nbs/04_learner.ipynb 65
+ class WandbCallback(Callback):
+     def __init__(self, project_name, run_name=None, config=None):
+         self.project_name = project_name
+         self.run_name = run_name
+         self.config = config
+ 
+     def before_fit(self, learn):
+         wandb.init(project=self.project_name, name=self.run_name, config=self.config)
+         learn.wandb_run = wandb.run
+ 
+     def after_batch(self, learn):
+         if learn.training:  # log only training batches
+             wandb.log({"train/loss": learn.loss.item(), "train/epoch": learn.epoch, "train/iter": learn.iter})
+ 
+     def after_epoch(self, learn):
+         # after_epoch runs once per training/validation phase, so label metrics by the current phase
+         phase = "train" if learn.training else "val"
+         metrics = {f"{phase}/{k}": v.compute().item() for k, v in learn.metrics.metrics.items()}
+         wandb.log(metrics)
+ 
+     def after_fit(self, learn):
+         wandb.finish()
+ 
+ 
+ class MLflowCallback(Callback):
+     def __init__(self, experiment_name=None, run_name=None, tracking_uri=None, config=None):
+         self.experiment_name = experiment_name
+         self.run_name = run_name
+         self.tracking_uri = tracking_uri
+         self.config = config
+ 
+     def before_fit(self, learn):
+         if self.tracking_uri:
+             mlflow.set_tracking_uri(self.tracking_uri)
+         if self.experiment_name:
+             mlflow.set_experiment(self.experiment_name)
+         self.run = mlflow.start_run(run_name=self.run_name)
+         if self.config:
+             mlflow.log_params(self.config)
+ 
+     def after_batch(self, learn):
+         if learn.training: mlflow.log_metric("train/loss", learn.loss.item(), step=learn.iter)
+ 
+     def after_epoch(self, learn):
+         phase = "train" if learn.training else "val"
+         metrics = {f"{phase}/{k}": v.compute().item() for k, v in learn.metrics.metrics.items()}
+         mlflow.log_metrics(metrics, step=learn.epoch)
+ 
+     def after_fit(self, learn):
+         mlflow.end_run()
2 changes: 1 addition & 1 deletion settings.ini
@@ -34,7 +34,7 @@ clean_ids = True
clear_all = False
cell_number = True
requirements = matplotlib datasets fastprogress fastcore
- pip_requirements = torch>=1.7,<=2.4 torcheval diffusers einops timm
+ pip_requirements = torch>=1.7,<=2.4 torcheval diffusers einops timm wandb mlflow
conda_requirements = pytorch>=1.7,<1.14
dev_requirements = nbdev
