From a6035197e89fb86342aac62c7533aaeb8cdd653b Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Wed, 27 Sep 2023 09:52:21 +0300
Subject: [PATCH] Fix docstrings

---
 .../training/utils/callbacks/callbacks.py    | 27 +++++++++++++++++++++------
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/src/super_gradients/training/utils/callbacks/callbacks.py b/src/super_gradients/training/utils/callbacks/callbacks.py
index 192d44e09a..59fb3117a9 100644
--- a/src/super_gradients/training/utils/callbacks/callbacks.py
+++ b/src/super_gradients/training/utils/callbacks/callbacks.py
@@ -1085,12 +1085,27 @@ def __init__(
         max_images: int = -1,
     ):
         """
+        :param metric: Metric, the metric to be monitored.
 
-        :param metric:
-        :param metric_component_name:
-        :param loss_to_monitor:
-        :param max:
-        :param freq: Frequency (in epochs) of performing this callback. 1 means every epoch. 2 means every other epoch. Default is 1.
+        :param metric_component_name: In case the metric returns multiple values (as a Mapping),
+         the value at metric.compute()[metric_component_name] will be the one monitored.
+
+        :param loss_to_monitor: str, name of the loss to monitor, corresponding to the 'criterion' passed through training_params in Trainer.train(...).
+         Monitoring a loss follows the same naming logic as metric_to_watch in Trainer.train(...), and should be:
+
+            if hasattr(criterion, "component_names") and criterion.forward(...) returns a tuple:
+                <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.
+
+            If a single item is returned rather than a tuple:
+                <LOSS_CLASS.__name__>.
+
+            When there is no such attribute and criterion.forward(...) returns a tuple:
+                <LOSS_CLASS.__name__>"/"Loss_<IDX>
+
+        :param max: bool, whether to take the batch corresponding to the max value of the metric/loss or
+         the minimum (default=False).
+
+        :param freq: int, frequency (in epochs) at which this callback is performed (default=1).
         :param enable_on_train_loader: Controls whether to enable this callback on the train loader. Default is False.
         :param enable_on_valid_loader: Controls whether to enable this callback on the valid loader. Default is True.
         :param max_images: Maximum images to save. If -1, save all images.
@@ -1182,7 +1197,7 @@ def _on_batch_end(self, context: PhaseContext) -> None:
             if self.metric_component_name is not None:
                 if not isinstance(score, Mapping) or (isinstance(score, Mapping) and self.metric_component_name not in score.keys()):
                     raise RuntimeError(
-                        f"metric_component_name: {self.metric_component_name} is not a component " f"of the monitored metric: {self.metric.__class__.__name__}"
+                        f"metric_component_name: {self.metric_component_name} is not a component of the monitored metric: {self.metric.__class__.__name__}"
                     )
                 score = score[self.metric_component_name]
             elif len(score) > 1:
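
For reviewers, a minimal runnable sketch of the loss_to_monitor naming rules documented in the patch above. The criterion class and component names below (DummyDetectionLoss, "loss_cls", "loss_iou", "loss_dfl") are stand-ins invented for illustration, not SuperGradients classes; the resulting string is what would be passed as loss_to_monitor to this callback.

# Sketch of the <LOSS_CLASS.__name__>"/"<COMPONENT_NAME> naming rule from the docstring above.
# DummyDetectionLoss is a stand-in defined here for illustration only.
class DummyDetectionLoss:
    component_names = ["loss_cls", "loss_iou", "loss_dfl"]  # names matching the tuple returned by forward(...)


criterion = DummyDetectionLoss()

if hasattr(criterion, "component_names"):
    # Multi-component loss: monitor one component as <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>
    loss_to_monitor = f"{criterion.__class__.__name__}/{criterion.component_names[1]}"
else:
    # Single-value loss: monitor it simply as <LOSS_CLASS.__name__>
    loss_to_monitor = criterion.__class__.__name__

print(loss_to_monitor)  # -> "DummyDetectionLoss/loss_iou"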