From a67b00cc7fd96fcd3933eb56e6e991e5cb281b3c Mon Sep 17 00:00:00 2001 From: Gopichand Madala Date: Sat, 24 Aug 2024 10:41:30 +0530 Subject: [PATCH] Bugfixes --- rapidai/_modidx.py | 7 ------- rapidai/learner.py | 37 +------------------------------------ 2 files changed, 1 insertion(+), 43 deletions(-) diff --git a/rapidai/_modidx.py b/rapidai/_modidx.py index 18c32e8..a37789b 100644 --- a/rapidai/_modidx.py +++ b/rapidai/_modidx.py @@ -220,13 +220,6 @@ 'rapidai.learner.TrainLearner.predict': ('learner.html#trainlearner.predict', 'rapidai/learner.py'), 'rapidai.learner.TrainLearner.step': ('learner.html#trainlearner.step', 'rapidai/learner.py'), 'rapidai.learner.TrainLearner.zero_grad': ('learner.html#trainlearner.zero_grad', 'rapidai/learner.py'), - 'rapidai.learner.WandbCallback': ('learner.html#wandbcallback', 'rapidai/learner.py'), - 'rapidai.learner.WandbCallback.__init__': ('learner.html#wandbcallback.__init__', 'rapidai/learner.py'), - 'rapidai.learner.WandbCallback.after_epoch': ( 'learner.html#wandbcallback.after_epoch', - 'rapidai/learner.py'), - 'rapidai.learner.WandbCallback.after_fit': ('learner.html#wandbcallback.after_fit', 'rapidai/learner.py'), - 'rapidai.learner.WandbCallback.before_fit': ( 'learner.html#wandbcallback.before_fit', - 'rapidai/learner.py'), 'rapidai.learner.lr_find': ('learner.html#lr_find', 'rapidai/learner.py'), 'rapidai.learner.run_cbs': ('learner.html#run_cbs', 'rapidai/learner.py'), 'rapidai.learner.to_cpu': ('learner.html#to_cpu', 'rapidai/learner.py'), diff --git a/rapidai/learner.py b/rapidai/learner.py index eee34e7..bcfbe87 100755 --- a/rapidai/learner.py +++ b/rapidai/learner.py @@ -3,7 +3,7 @@ # %% auto 0 __all__ = ['CancelFitException', 'CancelBatchException', 'CancelEpochException', 'Callback', 'run_cbs', 'SingleBatchCB', 'to_cpu', 'MetricsCB', 'DeviceCB', 'TrainCB', 'ProgressCB', 'with_cbs', 'Learner', 'TrainLearner', 'MomentumLearner', - 'LRFinderCB', 'lr_find', 'WandbCallback', 'EarlyStoppingCallback'] + 'LRFinderCB', 'lr_find', 'EarlyStoppingCallback'] # %% ../nbs/04_learner.ipynb 1 import math,torch,matplotlib.pyplot as plt @@ -240,41 +240,6 @@ def lr_find(self:Learner, gamma=1.3, max_mult=3, start_lr=1e-5, max_epochs=10): self.fit(max_epochs, lr=start_lr, cbs=LRFinderCB(gamma=gamma, max_mult=max_mult)) # %% ../nbs/04_learner.ipynb 65 -class WandbCallback(Callback): - order = MetricsCB.order+1 - def __init__(self, project_name, name=None): - self.project_name = project_name - self.entity = name - self.best_loss = float('inf') - self.best_model_state = None - - def before_fit(self, learn): - wandb.init(project=self.project_name, entity=self.entity) - wandb.config.update({ - 'learning_rate': learn.lr, - 'epochs': learn.n_epochs, - 'batch_size': len(learn.dls.train.batch_size) - }) - - def after_epoch(self, learn): - metrics = {k: v.compute() for k, v in learn.metrics.all_metrics.items()} - metrics['epoch'] = learn.epoch - metrics['train'] = 'train' if learn.model.training else 'valid' - wandb.log(metrics) - - # Save the best model based on loss - current_loss = learn.metrics.all_metrics['loss'].compute() - if learn.model.training and current_loss < self.best_loss: - self.best_loss = current_loss - self.best_model_state = learn.model.state_dict() - wandb.save('model.pth') # Save model to wandb - - def after_fit(self, learn): - if self.best_model_state is not None: - torch.save(self.best_model_state, 'best_model.pth') - wandb.save('best_model.pth') - - class EarlyStoppingCallback(Callback): order = MetricsCB.order+1 def __init__(self, patience=3, min_delta=0.001):