
Commit

fixes
m-gopichand committed Aug 23, 2024
1 parent a11b754 commit d53ed05
Showing 4 changed files with 173 additions and 28 deletions.
2 changes: 1 addition & 1 deletion nbs/04_learner.ipynb
@@ -671,7 +671,7 @@
"outputs": [],
"source": [
"class Learner():\n",
" def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.SGD):\n",
" def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.Adam):\n",
" cbs = fc.L(cbs)\n",
" fc.store_attr()\n",
"\n",
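For context, a minimal usage sketch of the changed default (not part of this commit): model and dls are placeholders for a real nn.Module and dataloaders, and the Learner signature is the one shown in the hunk above.

# Hypothetical sketch, not from the repository; model and dls are placeholder objects.
from torch import optim
import torch.nn.functional as F

learn = Learner(model, dls, loss_func=F.cross_entropy, lr=1e-3)                    # new default: optim.Adam
learn_sgd = Learner(model, dls, loss_func=F.mse_loss, lr=0.1, opt_func=optim.SGD)  # previous default, still available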
2 changes: 1 addition & 1 deletion rapidai/__init__.py
@@ -1 +1 @@
__version__ = "0.0.5"
__version__ = "0.0.6"
195 changes: 170 additions & 25 deletions rapidai/learner.py
@@ -14,12 +14,11 @@
from copy import copy
import wandb
import mlflow

import mlflow.pytorch
from pathlib import Path
from torch import optim
import torch.nn.functional as F

from .conv import *

from fastprogress import progress_bar,master_bar

# %% ../nbs/04_learner.ipynb 14
@@ -241,52 +240,198 @@ def lr_find(self:Learner, gamma=1.3, max_mult=3, start_lr=1e-5, max_epochs=10):

# %% ../nbs/04_learner.ipynb 65
class WandbCallback(Callback):
def __init__(self, project_name, run_name=None, config=None):
self.project_name = project_name
self.run_name = run_name
self.config = config
def __init__(
self,
project_name: str,
run_name: str = None,
config: dict = None,
log_model: bool = True,
log_frequency: int = 100,
save_best_model: bool = True,
monitor: str = 'val_loss',
mode: str = 'min',
save_model_checkpoint: bool = True,
checkpoint_dir: str = './models',
log_gradients: bool = False,
log_preds: bool = False,
preds_frequency: int = 500,
):
"""
Initializes the WandbCallback.
Args:
project_name (str): Name of the W&B project.
run_name (str, optional): Name of the W&B run. Defaults to None.
config (dict, optional): Hyperparameters and configurations. Defaults to None.
log_model (bool, optional): Log model architecture to W&B. Defaults to True.
log_frequency (int, optional): Frequency (in steps) to log training metrics. Defaults to 100.
save_best_model (bool, optional): Save the best model during training. Defaults to True.
monitor (str, optional): Metric to monitor for best model saving. Defaults to 'val_loss'.
mode (str, optional): 'min' or 'max' to minimize or maximize the monitored metric. Defaults to 'min'.
save_model_checkpoint (bool, optional): Save model checkpoint after each epoch. Defaults to True.
checkpoint_dir (str, optional): Directory to save model checkpoints. Defaults to './models'.
log_gradients (bool, optional): Log gradient histograms. Defaults to False.
log_preds (bool, optional): Log model predictions. Defaults to False.
preds_frequency (int, optional): Frequency (in steps) to log predictions. Defaults to 500.
"""
fc.store_attr()

def before_fit(self, learn):
wandb.init(project=self.project_name, name=self.run_name, config=self.config)
learn.wandb_run = wandb.run
# Initialize W&B run
self.run = wandb.init(project=self.project_name, name=self.run_name, config=self.config)
self.best_metric = None
self.operator = (lambda a, b: a < b) if self.mode == 'min' else (lambda a, b: a > b)  # plain comparisons: monitored values are Python floats, not tensors
self.checkpoint_path = Path(self.checkpoint_dir)
self.checkpoint_path.mkdir(parents=True, exist_ok=True)

if self.log_model:
# Log model architecture
wandb.watch(learn.model, log='all' if self.log_gradients else 'parameters')

def after_batch(self, learn):
wandb.log({"train/loss": learn.loss.item(), "train/epoch": learn.epoch, "train/iter": learn.iter})
if learn.training and (learn.iter % self.log_frequency == 0):
metrics = {
'train/loss': learn.loss.item(),
'train/epoch': learn.epoch + (learn.iter / len(learn.dl))
}
# Log metrics to W&B
wandb.log(metrics, step=learn.iter_total)

if self.log_gradients:
# Log gradients
for name, param in learn.model.named_parameters():
if param.grad is not None:
wandb.log({f'gradients/{name}': wandb.Histogram(param.grad.cpu().numpy())}, step=learn.iter_total)

if self.log_preds and (learn.iter % self.preds_frequency == 0):
# Log predictions (assuming classification task)
inputs, targets = learn.batch[:2]
preds = torch.argmax(learn.preds, dim=1)
table = wandb.Table(columns=["input", "prediction", "target"])
for inp, pred, target in zip(inputs.cpu(), preds.cpu(), targets.cpu()):
table.add_data(wandb.Image(inp), pred.item(), target.item())
wandb.log({"predictions": table}, step=learn.iter_total)

def after_epoch(self, learn):
metrics = {f"train/{k}": v.compute().item() for k, v in learn.metrics.metrics.items()}
metrics.update({f"val/{k}": v.compute().item() for k, v in learn.metrics.metrics.items()})
wandb.log(metrics)
# Compute validation metrics
val_metrics = {f'val/{k}': v.compute().item() for k, v in learn.metrics.metrics.items()}
val_metrics['val/loss'] = learn.metrics.loss.compute().item()
val_metrics['epoch'] = learn.epoch
# Log validation metrics
wandb.log(val_metrics, step=learn.iter_total)

# Save model checkpoint
if self.save_model_checkpoint:
epoch_checkpoint_path = self.checkpoint_path / f'model_epoch_{learn.epoch}.pth'
torch.save(learn.model.state_dict(), epoch_checkpoint_path)
wandb.save(str(epoch_checkpoint_path))

# Save best model
current_metric = val_metrics.get(f'val/{self.monitor}', val_metrics.get('val/loss'))
if self.save_best_model and current_metric is not None:
if self.best_metric is None or self.operator(current_metric, self.best_metric):
self.best_metric = current_metric
best_checkpoint_path = self.checkpoint_path / 'best_model.pth'
torch.save(learn.model.state_dict(), best_checkpoint_path)
wandb.save(str(best_checkpoint_path))
wandb.run.summary[f'best_{self.monitor}'] = self.best_metric

def after_fit(self, learn):
# Finish W&B run
wandb.finish()
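
A hedged usage sketch of the expanded WandbCallback (not part of the diff): it assumes the Learner shown earlier accepts callbacks through its cbs argument and exposes a fit method; model and dls are placeholder objects.

# Hypothetical sketch, not from the repository; model, dls and fit(...) are assumed.
wandb_cb = WandbCallback(
    project_name='rapidai-demo',
    run_name='baseline',
    config={'lr': 1e-3, 'epochs': 3},
    log_gradients=True,        # per-parameter gradient histograms every log_frequency training steps
    checkpoint_dir='./models',
)
learn = Learner(model, dls, loss_func=F.cross_entropy, lr=1e-3, cbs=[wandb_cb])
learn.fit(3)                   # per-epoch checkpoints and best_model.pth are written to ./models and synced to W&B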




class MLflowCallback(Callback):
def __init__(self, experiment_name=None, run_name=None, tracking_uri=None, config=None):
self.experiment_name = experiment_name
self.run_name = run_name
self.tracking_uri = tracking_uri
self.config = config
def __init__(
self,
experiment_name: str,
run_name: str = None,
tracking_uri: str = None,
config: dict = None,
log_model: bool = True,
log_frequency: int = 100,
save_best_model: bool = True,
monitor: str = 'val_loss',
mode: str = 'min',
save_model_checkpoint: bool = True,
checkpoint_dir: str = './models',
):
"""
Initializes the MLflowCallback.
Args:
experiment_name (str): Name of the MLflow experiment.
run_name (str, optional): Name of the MLflow run. Defaults to None.
tracking_uri (str, optional): URI of the tracking server. Defaults to None (local server).
config (dict, optional): Hyperparameters and configurations. Defaults to None.
log_model (bool, optional): Log model architecture to MLflow. Defaults to True.
log_frequency (int, optional): Frequency (in steps) to log training metrics. Defaults to 100.
save_best_model (bool, optional): Save the best model during training. Defaults to True.
monitor (str, optional): Metric to monitor for best model saving. Defaults to 'val_loss'.
mode (str, optional): 'min' or 'max' to minimize or maximize the monitored metric. Defaults to 'min'.
save_model_checkpoint (bool, optional): Save model checkpoint after each epoch. Defaults to True.
checkpoint_dir (str, optional): Directory to save model checkpoints. Defaults to './models'.
"""
fc.store_attr()

def before_fit(self, learn):
# Set MLflow tracking URI if provided
if self.tracking_uri:
mlflow.set_tracking_uri(self.tracking_uri)
if self.experiment_name:
mlflow.set_experiment(self.experiment_name)

# Set experiment and run
mlflow.set_experiment(self.experiment_name)
self.run = mlflow.start_run(run_name=self.run_name)

# Log configuration parameters
if self.config:
mlflow.log_params(self.config)

# Prepare for saving checkpoints
self.best_metric = None
self.operator = (lambda a, b: a < b) if self.mode == 'min' else (lambda a, b: a > b)  # plain comparisons: monitored values are Python floats, not tensors
self.checkpoint_path = Path(self.checkpoint_dir)
self.checkpoint_path.mkdir(parents=True, exist_ok=True)

# Log model architecture if needed
if self.log_model:
mlflow.pytorch.log_model(learn.model, "model_architecture")

def after_batch(self, learn):
mlflow.log_metric("train/loss", learn.loss.item(), step=learn.iter)
if learn.training and (learn.iter % self.log_frequency == 0):
metrics = {
'train/loss': learn.loss.item(),
'train/epoch': learn.epoch + (learn.iter / len(learn.dl))
}
mlflow.log_metrics(metrics, step=learn.iter_total)

def after_epoch(self, learn):
metrics = {f"train/{k}": v.compute().item() for k, v in learn.metrics.metrics.items()}
mlflow.log_metrics(metrics, step=learn.epoch)
# Compute validation metrics
val_metrics = {f'val/{k}': v.compute().item() for k, v in learn.metrics.metrics.items()}
val_metrics['val/loss'] = learn.metrics.loss.compute().item()
val_metrics['epoch'] = learn.epoch

# Log validation metrics
mlflow.log_metrics(val_metrics, step=learn.iter_total)

# Save model checkpoint
if self.save_model_checkpoint:
epoch_checkpoint_path = self.checkpoint_path / f'model_epoch_{learn.epoch}.pth'
torch.save(learn.model.state_dict(), epoch_checkpoint_path)
mlflow.log_artifact(str(epoch_checkpoint_path))

# Save best model
current_metric = val_metrics.get(f'val/{self.monitor}', val_metrics.get('val/loss'))
if self.save_best_model and current_metric is not None:
if self.best_metric is None or self.operator(current_metric, self.best_metric):
self.best_metric = current_metric
best_checkpoint_path = self.checkpoint_path / 'best_model.pth'
torch.save(learn.model.state_dict(), best_checkpoint_path)
mlflow.log_artifact(str(best_checkpoint_path))
mlflow.log_metric(f'best_{self.monitor}', self.best_metric, step=learn.iter_total)

def after_fit(self, learn):
# Finish MLflow run
mlflow.end_run()
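
Similarly, a hedged sketch for MLflowCallback (not part of the diff): the experiment name and tracking URI are illustrative values only, and model and dls are placeholders as above.

# Hypothetical sketch, not from the repository; omit tracking_uri to use MLflow's default local ./mlruns store.
mlflow_cb = MLflowCallback(
    experiment_name='rapidai-demo',
    run_name='baseline',
    tracking_uri='http://localhost:5000',
    config={'lr': 1e-3, 'epochs': 3},
)
learn = Learner(model, dls, loss_func=F.cross_entropy, lr=1e-3, cbs=[mlflow_cb])
learn.fit(3)                   # params, validation metrics and per-epoch checkpoints are logged to the MLflow run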


2 changes: 1 addition & 1 deletion settings.ini
@@ -1,7 +1,7 @@
[DEFAULT]
repo = rapidai
lib_name = rapidai
-version = 0.0.5
+version = 0.0.6
min_python = 3.9
license = apache2
black_formatting = False

