-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: option for evaluation with loss in Trainer
- Loading branch information
1 parent
ff94a5e
commit d480319
Showing
1 changed file
with
12 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,7 @@ | |
Please contact the author Alexandre Delplanque ([email protected]) for any questions. | ||
Last modification: March 29, 2023 | ||
Last modification: April 28, 2023 | ||
""" | ||
__author__ = "Alexandre Delplanque" | ||
__license__ = "CC BY-NC-SA 4.0" | ||
|
@@ -221,7 +221,7 @@ def start( | |
warmup_iters: Optional[int] = None, | ||
checkpoints: str = 'best', | ||
select: str = 'min', | ||
validate_on: str = 'recall', | ||
validate_on: str = 'all', | ||
wandb_flag: bool = False | ||
) -> torch.nn.Module: | ||
''' Start training from epoch 1 | ||
|
@@ -242,10 +242,11 @@ def start( | |
- 'min' (default), for selecting the epoch that yields to a minimum validation value, | ||
- 'max', for selecting the epoch that yields to a maximum validation value. | ||
Defaults to 'min'. | ||
validate_on (str, optional): metrics used for validation (i.e. best model and auto-lr) when | ||
custom evaluator is specified. Possible values are: 'recall', 'precision', 'f1_score', | ||
'mse', 'mae', and 'rmse'. | ||
Defauts to 'recall' | ||
validate_on (str, optional): metrics/loss used for validation (i.e. best model and auto-lr). | ||
For validation with losses, possible values are the names returned by the model, or 'all' | ||
for using the sum of all losses (default). Possible values for evaluator are: 'recall', | ||
'precision', 'f1_score', 'mse', 'mae', 'rmse', 'accuracy' or 'mAP'. | ||
Defauts to 'all' | ||
wandb_flag (bool, optional): set to True to log on Weight & Biases. Defaults to False. | ||
Returns: | ||
|
@@ -291,7 +292,7 @@ def start( | |
|
||
elif self.val_dataloader is not None: | ||
val_flag = True | ||
val_output = self.evaluate(epoch, wandb_flag=wandb_flag) | ||
val_output = self.evaluate(epoch, wandb_flag=wandb_flag, returns=validate_on) | ||
if wandb_flag: | ||
wandb.log({'val_loss': val_output, 'epoch': epoch}) | ||
|
||
|
@@ -408,7 +409,7 @@ def resume( | |
|
||
elif self.val_dataloader is not None: | ||
val_flag = True | ||
val_output = self.evaluate(epoch, wandb_flag=wandb_flag) | ||
val_output = self.evaluate(epoch, wandb_flag=wandb_flag, returns=validate_on) | ||
if wandb_flag: | ||
wandb.log({'val_loss': val_output, 'epoch': epoch}) | ||
|
||
|
@@ -443,7 +444,7 @@ def resume( | |
return self.model | ||
|
||
@torch.no_grad() | ||
def evaluate(self, epoch: int, reduction: str = 'mean', wandb_flag: bool = False) -> float: | ||
def evaluate(self, epoch: int, reduction: str = 'mean', wandb_flag: bool = False, returns: str = 'all') -> float: | ||
|
||
self.model.eval() | ||
|
||
|
@@ -458,6 +459,8 @@ def evaluate(self, epoch: int, reduction: str = 'mean', wandb_flag: bool = False | |
output, loss_dict = self.model(images, targets) | ||
|
||
losses = sum(loss for loss in loss_dict.values()) | ||
if returns != 'all': | ||
losses = loss_dict[returns] | ||
|
||
loss_dict_reduced = reduce_dict(loss_dict) | ||
losses_reduced = sum(loss for loss in loss_dict_reduced.values()) | ||
|