Commit

Added more functionality for W&B and MLflow
m-gopichand committed Aug 23, 2024
1 parent acfc649 commit a11b754
Showing 3 changed files with 173 additions and 35 deletions.
204 changes: 171 additions & 33 deletions nbs/04_learner.ipynb
@@ -38,12 +38,11 @@
"from copy import copy\n",
"import wandb\n",
"import mlflow\n",
"\n",
"import mlflow.pytorch\n",
"from pathlib import Path\n",
"from torch import optim\n",
"import torch.nn.functional as F\n",
"\n",
"from rapidai.conv import *\n",
"\n",
"from fastprogress import progress_bar,master_bar"
]
},
@@ -1402,54 +1401,201 @@
"source": [
"#|export\n",
"class WandbCallback(Callback):\n",
" def __init__(self, project_name, run_name=None, config=None):\n",
" self.project_name = project_name\n",
" self.run_name = run_name\n",
" self.config = config\n",
" def __init__(\n",
" self,\n",
" project_name: str,\n",
" run_name: str = None,\n",
" config: dict = None,\n",
" log_model: bool = True,\n",
" log_frequency: int = 100,\n",
" save_best_model: bool = True,\n",
" monitor: str = 'val_loss',\n",
" mode: str = 'min',\n",
" save_model_checkpoint: bool = True,\n",
" checkpoint_dir: str = './models',\n",
" log_gradients: bool = False,\n",
" log_preds: bool = False,\n",
" preds_frequency: int = 500,\n",
" ):\n",
" \"\"\"\n",
" Initializes the WandbCallback.\n",
"\n",
" Args:\n",
" project_name (str): Name of the W&B project.\n",
" run_name (str, optional): Name of the W&B run. Defaults to None.\n",
" config (dict, optional): Hyperparameters and configurations. Defaults to None.\n",
" log_model (bool, optional): Log model architecture to W&B. Defaults to True.\n",
" log_frequency (int, optional): Frequency (in steps) to log training metrics. Defaults to 100.\n",
" save_best_model (bool, optional): Save the best model during training. Defaults to True.\n",
" monitor (str, optional): Metric to monitor for best model saving. Defaults to 'val_loss'.\n",
" mode (str, optional): 'min' or 'max' to minimize or maximize the monitored metric. Defaults to 'min'.\n",
" save_model_checkpoint (bool, optional): Save model checkpoint after each epoch. Defaults to True.\n",
" checkpoint_dir (str, optional): Directory to save model checkpoints. Defaults to './models'.\n",
" log_gradients (bool, optional): Log gradients histograms. Defaults to False.\n",
" log_preds (bool, optional): Log model predictions. Defaults to False.\n",
" preds_frequency (int, optional): Frequency (in steps) to log predictions. Defaults to 500.\n",
" \"\"\"\n",
" fc.store_attr()\n",
"\n",
" def before_fit(self, learn):\n",
" wandb.init(project=self.project_name, name=self.run_name, config=self.config)\n",
" learn.wandb_run = wandb.run\n",
" # Initialize W&B run\n",
" self.run = wandb.init(project=self.project_name, name=self.run_name, config=self.config)\n",
" self.best_metric = None\n",
" self.operator = torch.lt if self.mode == 'min' else torch.gt\n",
" self.checkpoint_path = Path(self.checkpoint_dir)\n",
" self.checkpoint_path.mkdir(parents=True, exist_ok=True)\n",
" \n",
" if self.log_model:\n",
" # Log model architecture\n",
" wandb.watch(learn.model, log='all' if self.log_gradients else 'parameters')\n",
"\n",
" def after_batch(self, learn):\n",
" wandb.log({\"train/loss\": learn.loss.item(), \"train/epoch\": learn.epoch, \"train/iter\": learn.iter})\n",
" if learn.training and (learn.iter % self.log_frequency == 0):\n",
" metrics = {\n",
" 'train/loss': learn.loss.item(),\n",
" 'train/epoch': learn.epoch + (learn.iter / len(learn.dl))\n",
" }\n",
" # Log metrics to W&B\n",
" wandb.log(metrics, step=learn.iter_total)\n",
" \n",
" if self.log_gradients:\n",
" # Log gradients\n",
" for name, param in learn.model.named_parameters():\n",
" if param.grad is not None:\n",
" wandb.log({f'gradients/{name}': wandb.Histogram(param.grad.cpu().numpy())}, step=learn.iter_total)\n",
" \n",
" if self.log_preds and (learn.iter % self.preds_frequency == 0):\n",
" # Log predictions (assuming classification task)\n",
" inputs, targets = learn.batch[:2]\n",
" preds = torch.argmax(learn.preds, dim=1)\n",
" table = wandb.Table(columns=[\"input\", \"prediction\", \"target\"])\n",
" for inp, pred, target in zip(inputs.cpu(), preds.cpu(), targets.cpu()):\n",
" table.add_data(wandb.Image(inp), pred.item(), target.item())\n",
" wandb.log({\"predictions\": table}, step=learn.iter_total)\n",
"\n",
" def after_epoch(self, learn):\n",
" metrics = {f\"train/{k}\": v.compute().item() for k, v in learn.metrics.metrics.items()}\n",
" metrics.update({f\"val/{k}\": v.compute().item() for k, v in learn.metrics.metrics.items()})\n",
" wandb.log(metrics)\n",
" # Compute validation metrics\n",
" val_metrics = {f'val/{k}': v.compute().item() for k, v in learn.metrics.metrics.items()}\n",
" val_metrics['val/loss'] = learn.metrics.loss.compute().item()\n",
" val_metrics['epoch'] = learn.epoch\n",
" # Log validation metrics\n",
" wandb.log(val_metrics, step=learn.iter_total)\n",
" \n",
" # Save model checkpoint\n",
" if self.save_model_checkpoint:\n",
" epoch_checkpoint_path = self.checkpoint_path / f'model_epoch_{learn.epoch}.pth'\n",
" torch.save(learn.model.state_dict(), epoch_checkpoint_path)\n",
" wandb.save(str(epoch_checkpoint_path))\n",
" \n",
" # Save best model\n",
" current_metric = val_metrics.get(f'val/{self.monitor}', val_metrics.get('val/loss'))\n",
" if self.save_best_model and current_metric is not None:\n",
" if self.best_metric is None or self.operator(current_metric, self.best_metric):\n",
" self.best_metric = current_metric\n",
" best_checkpoint_path = self.checkpoint_path / 'best_model.pth'\n",
" torch.save(learn.model.state_dict(), best_checkpoint_path)\n",
" wandb.save(str(best_checkpoint_path))\n",
" wandb.run.summary[f'best_{self.monitor}'] = self.best_metric\n",
"\n",
" def after_fit(self, learn):\n",
" # Finish W&B run\n",
" wandb.finish()\n",
"\n",
"\n",
"\n",
"\n",
"class MLflowCallback(Callback):\n",
" def __init__(self, experiment_name=None, run_name=None, tracking_uri=None, config=None):\n",
" self.experiment_name = experiment_name\n",
" self.run_name = run_name\n",
" self.tracking_uri = tracking_uri\n",
" self.config = config\n",
" def __init__(\n",
" self,\n",
" experiment_name: str,\n",
" run_name: str = None,\n",
" tracking_uri: str = None,\n",
" config: dict = None,\n",
" log_model: bool = True,\n",
" log_frequency: int = 100,\n",
" save_best_model: bool = True,\n",
" monitor: str = 'val_loss',\n",
" mode: str = 'min',\n",
" save_model_checkpoint: bool = True,\n",
" checkpoint_dir: str = './models',\n",
" ):\n",
" \"\"\"\n",
" Initializes the MLflowCallback.\n",
"\n",
" Args:\n",
" experiment_name (str): Name of the MLflow experiment.\n",
" run_name (str, optional): Name of the MLflow run. Defaults to None.\n",
" tracking_uri (str, optional): URI of the tracking server. Defaults to None (local server).\n",
" config (dict, optional): Hyperparameters and configurations. Defaults to None.\n",
" log_model (bool, optional): Log model architecture to MLflow. Defaults to True.\n",
" log_frequency (int, optional): Frequency (in steps) to log training metrics. Defaults to 100.\n",
" save_best_model (bool, optional): Save the best model during training. Defaults to True.\n",
" monitor (str, optional): Metric to monitor for best model saving. Defaults to 'val_loss'.\n",
" mode (str, optional): 'min' or 'max' to minimize or maximize the monitored metric. Defaults to 'min'.\n",
" save_model_checkpoint (bool, optional): Save model checkpoint after each epoch. Defaults to True.\n",
" checkpoint_dir (str, optional): Directory to save model checkpoints. Defaults to './models'.\n",
" \"\"\"\n",
" fc.store_attr()\n",
"\n",
" def before_fit(self, learn):\n",
" # Set MLflow tracking URI if provided\n",
" if self.tracking_uri:\n",
" mlflow.set_tracking_uri(self.tracking_uri)\n",
" if self.experiment_name:\n",
" mlflow.set_experiment(self.experiment_name)\n",
" \n",
" # Set experiment and run\n",
" mlflow.set_experiment(self.experiment_name)\n",
" self.run = mlflow.start_run(run_name=self.run_name)\n",
" \n",
" # Log configuration parameters\n",
" if self.config:\n",
" mlflow.log_params(self.config)\n",
" \n",
" # Prepare for saving checkpoints\n",
" self.best_metric = None\n",
" self.operator = torch.lt if self.mode == 'min' else torch.gt\n",
" self.checkpoint_path = Path(self.checkpoint_dir)\n",
" self.checkpoint_path.mkdir(parents=True, exist_ok=True)\n",
" \n",
" # Log model architecture if needed\n",
" if self.log_model:\n",
" mlflow.pytorch.log_model(learn.model, \"model_architecture\")\n",
"\n",
" def after_batch(self, learn):\n",
" mlflow.log_metric(\"train/loss\", learn.loss.item(), step=learn.iter)\n",
" if learn.training and (learn.iter % self.log_frequency == 0):\n",
" metrics = {\n",
" 'train/loss': learn.loss.item(),\n",
" 'train/epoch': learn.epoch + (learn.iter / len(learn.dl))\n",
" }\n",
" mlflow.log_metrics(metrics, step=learn.iter_total)\n",
"\n",
" def after_epoch(self, learn):\n",
" metrics = {f\"train/{k}\": v.compute().item() for k, v in learn.metrics.metrics.items()}\n",
" mlflow.log_metrics(metrics, step=learn.epoch)\n",
" # Compute validation metrics\n",
" val_metrics = {f'val/{k}': v.compute().item() for k, v in learn.metrics.metrics.items()}\n",
" val_metrics['val/loss'] = learn.metrics.loss.compute().item()\n",
" val_metrics['epoch'] = learn.epoch\n",
" \n",
" # Log validation metrics\n",
" mlflow.log_metrics(val_metrics, step=learn.iter_total)\n",
" \n",
" # Save model checkpoint\n",
" if self.save_model_checkpoint:\n",
" epoch_checkpoint_path = self.checkpoint_path / f'model_epoch_{learn.epoch}.pth'\n",
" torch.save(learn.model.state_dict(), epoch_checkpoint_path)\n",
" mlflow.log_artifact(str(epoch_checkpoint_path))\n",
" \n",
" # Save best model\n",
" current_metric = val_metrics.get(f'val/{self.monitor}', val_metrics.get('val/loss'))\n",
" if self.save_best_model and current_metric is not None:\n",
" if self.best_metric is None or self.operator(current_metric, self.best_metric):\n",
" self.best_metric = current_metric\n",
" best_checkpoint_path = self.checkpoint_path / 'best_model.pth'\n",
" torch.save(learn.model.state_dict(), best_checkpoint_path)\n",
" mlflow.log_artifact(str(best_checkpoint_path))\n",
" mlflow.log_metric(f'best_{self.monitor}', self.best_metric, step=learn.iter_total)\n",
"\n",
" def after_fit(self, learn):\n",
" mlflow.end_run()\n",
"\n"
" # Finish MLflow run\n",
" mlflow.end_run()"
]
},
{
@@ -1480,14 +1626,6 @@
"source": [
"import nbdev; nbdev.nbdev_export()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0fc774ac",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
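For reference, a minimal usage sketch of the two new callbacks follows. This is an assumption-laden illustration, not part of the commit: it presumes the notebook's Learner accepts a cbs list, maintains the learn.iter_total step counter the callbacks log against, and that nbdev exports this notebook to rapidai.learner; the model, dls, and training helpers are placeholders.

import torch.nn.functional as F
from rapidai.learner import *  # assumed export target of nbs/04_learner.ipynb

# Placeholders: substitute a real model and DataLoaders object.
model, dls = build_model(), build_dls()  # hypothetical helpers, not part of the library

cbs = [
    WandbCallback(
        project_name='rapidai-demo',        # placeholder W&B project
        config={'lr': 1e-3, 'epochs': 5},
        monitor='loss', mode='min',         # 'loss' resolves to the logged 'val/loss' key
        checkpoint_dir='./models',
    ),
    MLflowCallback(
        experiment_name='rapidai-demo',     # placeholder MLflow experiment
        tracking_uri=None,                  # None -> default local ./mlruns store
        config={'lr': 1e-3, 'epochs': 5},
    ),
]

learn = Learner(model, dls, loss_func=F.cross_entropy, lr=1e-3, cbs=cbs)  # assumed Learner signature
learn.fit(5)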
2 changes: 1 addition & 1 deletion rapidai/__init__.py
@@ -1 +1 @@
__version__ = "0.0.4"
__version__ = "0.0.5"
2 changes: 1 addition & 1 deletion settings.ini
@@ -1,7 +1,7 @@
[DEFAULT]
repo = rapidai
lib_name = rapidai
version = 0.0.4
version = 0.0.5
min_python = 3.9
license = apache2
black_formatting = False
