Skip to content

Commit

Permalink
Bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
m-gopichand committed Aug 24, 2024
1 parent ac448ee commit a67b00c
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 43 deletions.
7 changes: 0 additions & 7 deletions rapidai/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,6 @@
'rapidai.learner.TrainLearner.predict': ('learner.html#trainlearner.predict', 'rapidai/learner.py'),
'rapidai.learner.TrainLearner.step': ('learner.html#trainlearner.step', 'rapidai/learner.py'),
'rapidai.learner.TrainLearner.zero_grad': ('learner.html#trainlearner.zero_grad', 'rapidai/learner.py'),
'rapidai.learner.WandbCallback': ('learner.html#wandbcallback', 'rapidai/learner.py'),
'rapidai.learner.WandbCallback.__init__': ('learner.html#wandbcallback.__init__', 'rapidai/learner.py'),
'rapidai.learner.WandbCallback.after_epoch': ( 'learner.html#wandbcallback.after_epoch',
'rapidai/learner.py'),
'rapidai.learner.WandbCallback.after_fit': ('learner.html#wandbcallback.after_fit', 'rapidai/learner.py'),
'rapidai.learner.WandbCallback.before_fit': ( 'learner.html#wandbcallback.before_fit',
'rapidai/learner.py'),
'rapidai.learner.lr_find': ('learner.html#lr_find', 'rapidai/learner.py'),
'rapidai.learner.run_cbs': ('learner.html#run_cbs', 'rapidai/learner.py'),
'rapidai.learner.to_cpu': ('learner.html#to_cpu', 'rapidai/learner.py'),
Expand Down
37 changes: 1 addition & 36 deletions rapidai/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# %% auto 0
__all__ = ['CancelFitException', 'CancelBatchException', 'CancelEpochException', 'Callback', 'run_cbs', 'SingleBatchCB', 'to_cpu',
'MetricsCB', 'DeviceCB', 'TrainCB', 'ProgressCB', 'with_cbs', 'Learner', 'TrainLearner', 'MomentumLearner',
'LRFinderCB', 'lr_find', 'WandbCallback', 'EarlyStoppingCallback']
'LRFinderCB', 'lr_find', 'EarlyStoppingCallback']

# %% ../nbs/04_learner.ipynb 1
import math,torch,matplotlib.pyplot as plt
Expand Down Expand Up @@ -240,41 +240,6 @@ def lr_find(self:Learner, gamma=1.3, max_mult=3, start_lr=1e-5, max_epochs=10):
self.fit(max_epochs, lr=start_lr, cbs=LRFinderCB(gamma=gamma, max_mult=max_mult))

# %% ../nbs/04_learner.ipynb 65
class WandbCallback(Callback):
order = MetricsCB.order+1
def __init__(self, project_name, name=None):
self.project_name = project_name
self.entity = name
self.best_loss = float('inf')
self.best_model_state = None

def before_fit(self, learn):
wandb.init(project=self.project_name, entity=self.entity)
wandb.config.update({
'learning_rate': learn.lr,
'epochs': learn.n_epochs,
'batch_size': len(learn.dls.train.batch_size)
})

def after_epoch(self, learn):
metrics = {k: v.compute() for k, v in learn.metrics.all_metrics.items()}
metrics['epoch'] = learn.epoch
metrics['train'] = 'train' if learn.model.training else 'valid'
wandb.log(metrics)

# Save the best model based on loss
current_loss = learn.metrics.all_metrics['loss'].compute()
if learn.model.training and current_loss < self.best_loss:
self.best_loss = current_loss
self.best_model_state = learn.model.state_dict()
wandb.save('model.pth') # Save model to wandb

def after_fit(self, learn):
if self.best_model_state is not None:
torch.save(self.best_model_state, 'best_model.pth')
wandb.save('best_model.pth')


class EarlyStoppingCallback(Callback):
order = MetricsCB.order+1
def __init__(self, patience=3, min_delta=0.001):
Expand Down

0 comments on commit a67b00c

Please sign in to comment.