Skip to content

Commit

Permalink
feat: option for evaluation with loss in Trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexandre-Delplanque committed Apr 28, 2023
1 parent ff94a5e commit d480319
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions animaloc/train/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Please contact the author Alexandre Delplanque ([email protected]) for any questions.
Last modification: March 29, 2023
Last modification: April 28, 2023
"""
__author__ = "Alexandre Delplanque"
__license__ = "CC BY-NC-SA 4.0"
Expand Down Expand Up @@ -221,7 +221,7 @@ def start(
warmup_iters: Optional[int] = None,
checkpoints: str = 'best',
select: str = 'min',
validate_on: str = 'recall',
validate_on: str = 'all',
wandb_flag: bool = False
) -> torch.nn.Module:
''' Start training from epoch 1
Expand All @@ -242,10 +242,11 @@ def start(
- 'min' (default), for selecting the epoch that yields to a minimum validation value,
- 'max', for selecting the epoch that yields to a maximum validation value.
Defaults to 'min'.
validate_on (str, optional): metrics used for validation (i.e. best model and auto-lr) when
custom evaluator is specified. Possible values are: 'recall', 'precision', 'f1_score',
'mse', 'mae', and 'rmse'.
Defauts to 'recall'
validate_on (str, optional): metrics/loss used for validation (i.e. best model and auto-lr).
For validation with losses, possible values are the names returned by the model, or 'all'
for using the sum of all losses (default). Possible values for evaluator are: 'recall',
'precision', 'f1_score', 'mse', 'mae', 'rmse', 'accuracy' or 'mAP'.
Defauts to 'all'
wandb_flag (bool, optional): set to True to log on Weight & Biases. Defaults to False.
Returns:
Expand Down Expand Up @@ -291,7 +292,7 @@ def start(

elif self.val_dataloader is not None:
val_flag = True
val_output = self.evaluate(epoch, wandb_flag=wandb_flag)
val_output = self.evaluate(epoch, wandb_flag=wandb_flag, returns=validate_on)
if wandb_flag:
wandb.log({'val_loss': val_output, 'epoch': epoch})

Expand Down Expand Up @@ -408,7 +409,7 @@ def resume(

elif self.val_dataloader is not None:
val_flag = True
val_output = self.evaluate(epoch, wandb_flag=wandb_flag)
val_output = self.evaluate(epoch, wandb_flag=wandb_flag, returns=validate_on)
if wandb_flag:
wandb.log({'val_loss': val_output, 'epoch': epoch})

Expand Down Expand Up @@ -443,7 +444,7 @@ def resume(
return self.model

@torch.no_grad()
def evaluate(self, epoch: int, reduction: str = 'mean', wandb_flag: bool = False) -> float:
def evaluate(self, epoch: int, reduction: str = 'mean', wandb_flag: bool = False, returns: str = 'all') -> float:

self.model.eval()

Expand All @@ -458,6 +459,8 @@ def evaluate(self, epoch: int, reduction: str = 'mean', wandb_flag: bool = False
output, loss_dict = self.model(images, targets)

losses = sum(loss for loss in loss_dict.values())
if returns != 'all':
losses = loss_dict[returns]

loss_dict_reduced = reduce_dict(loss_dict)
losses_reduced = sum(loss for loss in loss_dict_reduced.values())
Expand Down

0 comments on commit d480319

Please sign in to comment.