diff --git a/nbs/04_learner.ipynb b/nbs/04_learner.ipynb
index 08e06be..31a51ca 100755
--- a/nbs/04_learner.ipynb
+++ b/nbs/04_learner.ipynb
@@ -671,7 +671,7 @@
    "outputs": [],
    "source": [
     "class Learner():\n",
-    "    def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.SGD):\n",
+    "    def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.Adam):\n",
     "        cbs = fc.L(cbs)\n",
     "        fc.store_attr()\n",
     "\n",
diff --git a/rapidai/__init__.py b/rapidai/__init__.py
index b1a19e3..034f46c 100755
--- a/rapidai/__init__.py
+++ b/rapidai/__init__.py
@@ -1 +1 @@
-__version__ = "0.0.5"
+__version__ = "0.0.6"
diff --git a/rapidai/learner.py b/rapidai/learner.py
index e58b2eb..a114db2 100755
--- a/rapidai/learner.py
+++ b/rapidai/learner.py
@@ -14,12 +14,11 @@ from copy import copy
 import wandb
 import mlflow
-
+import mlflow.pytorch
+from pathlib import Path
 from torch import optim
 import torch.nn.functional as F
-
 from .conv import *
-
 from fastprogress import progress_bar,master_bar
 
 # %% ../nbs/04_learner.ipynb 14
@@ -241,52 +240,198 @@ def lr_find(self:Learner, gamma=1.3, max_mult=3, start_lr=1e-5, max_epochs=10):
 
 # %% ../nbs/04_learner.ipynb 65
 class WandbCallback(Callback):
-    def __init__(self, project_name, run_name=None, config=None):
-        self.project_name = project_name
-        self.run_name = run_name
-        self.config = config
+    def __init__(
+        self,
+        project_name: str,
+        run_name: str = None,
+        config: dict = None,
+        log_model: bool = True,
+        log_frequency: int = 100,
+        save_best_model: bool = True,
+        monitor: str = 'val_loss',
+        mode: str = 'min',
+        save_model_checkpoint: bool = True,
+        checkpoint_dir: str = './models',
+        log_gradients: bool = False,
+        log_preds: bool = False,
+        preds_frequency: int = 500,
+    ):
+        """
+        Initializes the WandbCallback.
+
+        Args:
+            project_name (str): Name of the W&B project.
+            run_name (str, optional): Name of the W&B run. Defaults to None.
+            config (dict, optional): Hyperparameters and configurations. Defaults to None.
+            log_model (bool, optional): Log model architecture to W&B. Defaults to True.
+            log_frequency (int, optional): Frequency (in steps) to log training metrics. Defaults to 100.
+            save_best_model (bool, optional): Save the best model during training. Defaults to True.
+            monitor (str, optional): Metric to monitor for best model saving. Defaults to 'val_loss'.
+            mode (str, optional): 'min' or 'max' to minimize or maximize the monitored metric. Defaults to 'min'.
+            save_model_checkpoint (bool, optional): Save model checkpoint after each epoch. Defaults to True.
+            checkpoint_dir (str, optional): Directory to save model checkpoints. Defaults to './models'.
+            log_gradients (bool, optional): Log gradient histograms. Defaults to False.
+            log_preds (bool, optional): Log model predictions. Defaults to False.
+            preds_frequency (int, optional): Frequency (in steps) to log predictions. Defaults to 500.
+        """
+        fc.store_attr()
 
     def before_fit(self, learn):
-        wandb.init(project=self.project_name, name=self.run_name, config=self.config)
-        learn.wandb_run = wandb.run
+        # Initialize W&B run
+        self.run = wandb.init(project=self.project_name, name=self.run_name, config=self.config)
+        self.best_metric = None
+        self.operator = (lambda a, b: a < b) if self.mode == 'min' else (lambda a, b: a > b)  # compare plain floats, not tensors
+        self.checkpoint_path = Path(self.checkpoint_dir)
+        self.checkpoint_path.mkdir(parents=True, exist_ok=True)
+
+        if self.log_model:
+            # Log model architecture
+            wandb.watch(learn.model, log='all' if self.log_gradients else 'parameters')
 
     def after_batch(self, learn):
-        wandb.log({"train/loss": learn.loss.item(), "train/epoch": learn.epoch, "train/iter": learn.iter})
+        if learn.training and (learn.iter % self.log_frequency == 0):
+            metrics = {
+                'train/loss': learn.loss.item(),
+                'train/epoch': learn.epoch + (learn.iter / len(learn.dl))
+            }
+            # Log metrics to W&B
+            wandb.log(metrics, step=learn.iter_total)
+
+            if self.log_gradients:
+                # Log gradients
+                for name, param in learn.model.named_parameters():
+                    if param.grad is not None:
+                        wandb.log({f'gradients/{name}': wandb.Histogram(param.grad.cpu().numpy())}, step=learn.iter_total)
+
+        if self.log_preds and (learn.iter % self.preds_frequency == 0):
+            # Log predictions (assuming classification task)
+            inputs, targets = learn.batch[:2]
+            preds = torch.argmax(learn.preds, dim=1)
+            table = wandb.Table(columns=["input", "prediction", "target"])
+            for inp, pred, target in zip(inputs.cpu(), preds.cpu(), targets.cpu()):
+                table.add_data(wandb.Image(inp), pred.item(), target.item())
+            wandb.log({"predictions": table}, step=learn.iter_total)
 
     def after_epoch(self, learn):
-        metrics = {f"train/{k}": v.compute().item() for k, v in learn.metrics.metrics.items()}
-        metrics.update({f"val/{k}": v.compute().item() for k, v in learn.metrics.metrics.items()})
-        wandb.log(metrics)
+        # Compute validation metrics
+        val_metrics = {f'val/{k}': v.compute().item() for k, v in learn.metrics.metrics.items()}
+        val_metrics['val/loss'] = learn.metrics.loss.compute().item()
+        val_metrics['epoch'] = learn.epoch
+        # Log validation metrics
+        wandb.log(val_metrics, step=learn.iter_total)
+
+        # Save model checkpoint
+        if self.save_model_checkpoint:
+            epoch_checkpoint_path = self.checkpoint_path / f'model_epoch_{learn.epoch}.pth'
+            torch.save(learn.model.state_dict(), epoch_checkpoint_path)
+            wandb.save(str(epoch_checkpoint_path))
+
+        # Save best model
+        current_metric = val_metrics.get(f'val/{self.monitor}', val_metrics.get('val/loss'))
+        if self.save_best_model and current_metric is not None:
+            if self.best_metric is None or self.operator(current_metric, self.best_metric):
+                self.best_metric = current_metric
+                best_checkpoint_path = self.checkpoint_path / 'best_model.pth'
+                torch.save(learn.model.state_dict(), best_checkpoint_path)
+                wandb.save(str(best_checkpoint_path))
+                wandb.run.summary[f'best_{self.monitor}'] = self.best_metric
 
     def after_fit(self, learn):
+        # Finish W&B run
         wandb.finish()
 
+
 class MLflowCallback(Callback):
-    def __init__(self, experiment_name=None, run_name=None, tracking_uri=None, config=None):
-        self.experiment_name = experiment_name
-        self.run_name = run_name
-        self.tracking_uri = tracking_uri
-        self.config = config
+    def __init__(
+        self,
+        experiment_name: str,
+        run_name: str = None,
+        tracking_uri: str = None,
+        config: dict = None,
+        log_model: bool = True,
+        log_frequency: int = 100,
+        save_best_model: bool = True,
+        monitor: str = 'val_loss',
+        mode: str = 'min',
+        save_model_checkpoint: bool = True,
+        checkpoint_dir: str = './models',
+    ):
+        """
+        Initializes the MLflowCallback.
+
+        Args:
+            experiment_name (str): Name of the MLflow experiment.
+            run_name (str, optional): Name of the MLflow run. Defaults to None.
+            tracking_uri (str, optional): URI of the tracking server. Defaults to None (local server).
+            config (dict, optional): Hyperparameters and configurations. Defaults to None.
+            log_model (bool, optional): Log model architecture to MLflow. Defaults to True.
+            log_frequency (int, optional): Frequency (in steps) to log training metrics. Defaults to 100.
+            save_best_model (bool, optional): Save the best model during training. Defaults to True.
+            monitor (str, optional): Metric to monitor for best model saving. Defaults to 'val_loss'.
+            mode (str, optional): 'min' or 'max' to minimize or maximize the monitored metric. Defaults to 'min'.
+            save_model_checkpoint (bool, optional): Save model checkpoint after each epoch. Defaults to True.
+            checkpoint_dir (str, optional): Directory to save model checkpoints. Defaults to './models'.
+        """
+        fc.store_attr()
 
     def before_fit(self, learn):
+        # Set MLflow tracking URI if provided
         if self.tracking_uri:
             mlflow.set_tracking_uri(self.tracking_uri)
-        if self.experiment_name:
-            mlflow.set_experiment(self.experiment_name)
+
+        # Set experiment and run
+        mlflow.set_experiment(self.experiment_name)
         self.run = mlflow.start_run(run_name=self.run_name)
+
+        # Log configuration parameters
         if self.config:
             mlflow.log_params(self.config)
+
+        # Prepare for saving checkpoints
+        self.best_metric = None
+        self.operator = (lambda a, b: a < b) if self.mode == 'min' else (lambda a, b: a > b)  # compare plain floats, not tensors
+        self.checkpoint_path = Path(self.checkpoint_dir)
+        self.checkpoint_path.mkdir(parents=True, exist_ok=True)
+
+        # Log model architecture if needed
+        if self.log_model:
+            mlflow.pytorch.log_model(learn.model, "model_architecture")
 
     def after_batch(self, learn):
-        mlflow.log_metric("train/loss", learn.loss.item(), step=learn.iter)
+        if learn.training and (learn.iter % self.log_frequency == 0):
+            metrics = {
+                'train/loss': learn.loss.item(),
+                'train/epoch': learn.epoch + (learn.iter / len(learn.dl))
+            }
+            mlflow.log_metrics(metrics, step=learn.iter_total)
 
     def after_epoch(self, learn):
-        metrics = {f"train/{k}": v.compute().item() for k, v in learn.metrics.metrics.items()}
-        mlflow.log_metrics(metrics, step=learn.epoch)
+        # Compute validation metrics
+        val_metrics = {f'val/{k}': v.compute().item() for k, v in learn.metrics.metrics.items()}
+        val_metrics['val/loss'] = learn.metrics.loss.compute().item()
+        val_metrics['epoch'] = learn.epoch
+
+        # Log validation metrics
+        mlflow.log_metrics(val_metrics, step=learn.iter_total)
+
+        # Save model checkpoint
+        if self.save_model_checkpoint:
+            epoch_checkpoint_path = self.checkpoint_path / f'model_epoch_{learn.epoch}.pth'
+            torch.save(learn.model.state_dict(), epoch_checkpoint_path)
+            mlflow.log_artifact(str(epoch_checkpoint_path))
+
+        # Save best model
+        current_metric = val_metrics.get(f'val/{self.monitor}', val_metrics.get('val/loss'))
+        if self.save_best_model and current_metric is not None:
+            if self.best_metric is None or self.operator(current_metric, self.best_metric):
+                self.best_metric = current_metric
+                best_checkpoint_path = self.checkpoint_path / 'best_model.pth'
+                torch.save(learn.model.state_dict(), best_checkpoint_path)
+                mlflow.log_artifact(str(best_checkpoint_path))
+                mlflow.log_metric(f'best_{self.monitor}', self.best_metric, step=learn.iter_total)
 
     def after_fit(self, learn):
+        # Finish MLflow run
         mlflow.end_run()
-
-
diff --git a/settings.ini b/settings.ini
index 2d787a0..07054bf 100755
--- a/settings.ini
+++ b/settings.ini
@@ -1,7 +1,7 @@
 [DEFAULT]
 repo = rapidai
 lib_name = rapidai
-version = 0.0.5
+version = 0.0.6
 min_python = 3.9
 license = apache2
 black_formatting = False
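For context, here is a minimal usage sketch (not part of the diff) of how the new callbacks might be attached to a `Learner`. Only `Learner`, `WandbCallback`, `MLflowCallback`, and their parameters come from the code above; the model, the `get_dls()` helper, and the `learn.fit(...)` call are placeholders assumed from the notebooks.

```python
# Hypothetical usage sketch -- not part of the diff.
# Assumes rapidai.learner exports Learner, WandbCallback and MLflowCallback,
# and that Learner exposes a fit() method as used in nbs/04_learner.ipynb.
from torch import nn, optim
import torch.nn.functional as F

from rapidai.learner import Learner, WandbCallback, MLflowCallback

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 50), nn.ReLU(), nn.Linear(50, 10))
dls = get_dls()  # placeholder for whatever returns the (train, valid) dataloaders

cbs = [
    WandbCallback(
        project_name="rapidai-demo",        # hypothetical W&B project
        config={"lr": 0.1, "opt": "Adam"},
        log_frequency=100,                  # log train/loss every 100 steps
        monitor="val_loss", mode="min",     # track the best model by validation loss
        checkpoint_dir="./models",
    ),
    MLflowCallback(
        experiment_name="rapidai-demo",     # hypothetical MLflow experiment
        log_frequency=100,
    ),
]

learn = Learner(model, dls=dls, loss_func=F.cross_entropy, lr=0.1, cbs=cbs, opt_func=optim.Adam)
learn.fit(5)  # assumed training entry point
```

Note that both callbacks write `model_epoch_{n}.pth` and `best_model.pth` to the same `checkpoint_dir` by default, so when running them together it likely makes sense to disable `save_model_checkpoint`/`save_best_model` on one of them or point them at different directories.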