From bf13606319947f7534389fe8a3a41357d2a844b8 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Mon, 25 Sep 2023 17:58:23 +0300 Subject: [PATCH 1/3] Improve visualization callbacks --- .../training/utils/callbacks/callbacks.py | 237 ++++++++++++------ src/super_gradients/training/utils/utils.py | 25 +- 2 files changed, 168 insertions(+), 94 deletions(-) diff --git a/src/super_gradients/training/utils/callbacks/callbacks.py b/src/super_gradients/training/utils/callbacks/callbacks.py index e0b52fa327..192d44e09a 100644 --- a/src/super_gradients/training/utils/callbacks/callbacks.py +++ b/src/super_gradients/training/utils/callbacks/callbacks.py @@ -1,12 +1,12 @@ import copy +import csv import math import os import signal import time from abc import ABC, abstractmethod -from typing import List, Union, Optional, Sequence, Mapping, Tuple +from typing import List, Union, Optional, Sequence, Mapping -import csv import cv2 import numpy as np import onnx @@ -18,22 +18,21 @@ from super_gradients.common.abstractions.abstract_logger import get_logger from super_gradients.common.decorators.factory_decorator import resolve_param +from super_gradients.common.deprecate import deprecated +from super_gradients.common.environment.checkpoints_dir_utils import get_project_checkpoints_dir_path from super_gradients.common.environment.ddp_utils import multi_process_safe from super_gradients.common.environment.device_utils import device_config from super_gradients.common.factories.metrics_factory import MetricsFactory +from super_gradients.common.object_names import LRSchedulers, LRWarmups, Callbacks from super_gradients.common.plugins.deci_client import DeciClient from super_gradients.common.registry.registry import register_lr_scheduler, register_lr_warmup, register_callback, LR_SCHEDULERS_CLS_DICT, TORCH_LR_SCHEDULERS -from super_gradients.common.object_names import LRSchedulers, LRWarmups, Callbacks from super_gradients.common.sg_loggers.time_units import GlobalBatchStepNumber, EpochNumber from super_gradients.training.utils import get_param from super_gradients.training.utils.callbacks.base_callbacks import PhaseCallback, PhaseContext, Phase, Callback from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionPostPredictionCallback, cxcywh2xyxy, xyxy2cxcywh from super_gradients.training.utils.distributed_training_utils import maybe_all_reduce_tensor_average, maybe_all_gather_np_images from super_gradients.training.utils.segmentation_utils import BinarySegmentationVisualization -from super_gradients.common.environment.checkpoints_dir_utils import get_project_checkpoints_dir_path -from super_gradients.training.utils.utils import unwrap_model -from super_gradients.common.deprecate import deprecated - +from super_gradients.training.utils.utils import unwrap_model, infer_model_device, tensor_container_to_device logger = get_logger(__name__) @@ -1070,7 +1069,7 @@ class ExtremeBatchCaseVisualizationCallback(Callback, ABC): :param freq: int, epoch frequency to perform all of the above (default=1). - Inheritors should implement process_extreme_batch which returns an image, as an np.array (uint8) with shape BCHW. + Inheritors should implement process_extreme_batch which returns an image, as np.array (uint8) with shape BCHW. 
""" @resolve_param("metric", MetricsFactory()) @@ -1081,7 +1080,21 @@ def __init__( loss_to_monitor: Optional[str] = None, max: bool = False, freq: int = 1, + enable_on_train_loader: bool = False, + enable_on_valid_loader: bool = True, + max_images: int = -1, ): + """ + + :param metric: + :param metric_component_name: + :param loss_to_monitor: + :param max: + :param freq: Frequency (in epochs) of performing this callback. 1 means every epoch. 2 means every other epoch. Default is 1. + :param enable_on_train_loader: Controls whether to enable this callback on the train loader. Default is False. + :param enable_on_valid_loader: Controls whether to enable this callback on the valid loader. Default is True. + :param max_images: Maximum images to save. If -1, save all images. + """ super(ExtremeBatchCaseVisualizationCallback, self).__init__() if (metric and loss_to_monitor) or (metric is None and loss_to_monitor is None): @@ -1098,8 +1111,8 @@ def __init__( self.loss_to_monitor = loss_to_monitor self.max = max self.freq = freq - self.extreme_score = -1 * np.inf if max else np.inf + self.extreme_score = None self.extreme_batch = None self.extreme_preds = None self.extreme_targets = None @@ -1107,6 +1120,10 @@ def __init__( self._first_call = True self._idx_loss_tuple = None + self.enable_on_train_loader = enable_on_train_loader + self.enable_on_valid_loader = enable_on_valid_loader + self.max_images = max_images + def _set_tag_attr(self, loss_to_monitor, max, metric, metric_component_name): if metric_component_name: monitored_val_name = metric_component_name @@ -1126,43 +1143,72 @@ def process_extreme_batch(self) -> np.ndarray: """ raise NotImplementedError + def on_train_loader_start(self, context: PhaseContext) -> None: + self._reset() + + def on_train_batch_end(self, context: PhaseContext) -> None: + if self.enable_on_train_loader and context.epoch % self.freq == 0: + self._on_batch_end(context) + + def on_train_loader_end(self, context: PhaseContext) -> None: + if self.enable_on_train_loader and context.epoch % self.freq == 0: + self._gather_extreme_batch_images_and_log(context, "train") + self._reset() + + def on_validation_loader_start(self, context: PhaseContext) -> None: + self._reset() + def on_validation_batch_end(self, context: PhaseContext) -> None: - if context.epoch % self.freq == 0: - # FOR METRIC OBJECTS, RESET THEM AND COMPUTE SCORE ONLY ON BATCH. 
-            if self.metric is not None:
-                self.metric.update(**context.__dict__)
-                score = self.metric.compute()
-                if self.metric_component_name is not None:
-                    if not isinstance(score, Mapping) or (isinstance(score, Mapping) and self.metric_component_name not in score.keys()):
-                        raise RuntimeError(
-                            f"metric_component_name: {self.metric_component_name} is not a component "
-                            f"of the monitored metric: {self.metric.__class__.__name__}"
-                        )
-                    score = score[self.metric_component_name]
-                elif len(score) > 1:
-                    raise RuntimeError(f"returned multiple values from {self.metric} but no metric_component_name has been passed to __init__.")
-                else:
-                    score = score.pop(list(score.keys())[0])
-                self.metric.reset()
+        if self.enable_on_valid_loader and context.epoch % self.freq == 0:
+            self._on_batch_end(context)
+
+    def on_validation_loader_end(self, context: PhaseContext) -> None:
+        if self.enable_on_valid_loader and context.epoch % self.freq == 0:
+            self._gather_extreme_batch_images_and_log(context, "valid")
+            self._reset()
+
+    def _gather_extreme_batch_images_and_log(self, context, loader_name: str):
+        images_to_save = self.process_extreme_batch()
+        images_to_save = maybe_all_gather_np_images(images_to_save)
+        if self.max_images > 0:
+            images_to_save = images_to_save[: self.max_images]
+        if not context.ddp_silent_mode:
+            context.sg_logger.add_images(tag=f"{loader_name}/{self._tag}", images=images_to_save, global_step=context.epoch, data_format="NHWC")
+
+    def _on_batch_end(self, context: PhaseContext) -> None:
+        if self.metric is not None:
+            self.metric.update(**context.__dict__)
+            score = self.metric.compute()
+            if self.metric_component_name is not None:
+                if not isinstance(score, Mapping) or (isinstance(score, Mapping) and self.metric_component_name not in score.keys()):
+                    raise RuntimeError(
+                        f"metric_component_name: {self.metric_component_name} is not a component " f"of the monitored metric: {self.metric.__class__.__name__}"
+                    )
+                score = score[self.metric_component_name]
+            elif len(score) > 1:
+                raise RuntimeError(f"returned multiple values from {self.metric} but no metric_component_name has been passed to __init__.")
             else:
+                score = score.pop(list(score.keys())[0])
+            self.metric.reset()
+
+        else:
-                # FOR LOSS VALUES, GET THE RIGHT COMPONENT, DERIVE IT ON THE FIRST PASS
-                loss_tuple = context.loss_log_items
-                if self._first_call:
-                    self._init_loss_attributes(context)
-                score = loss_tuple[self._idx_loss_tuple]
+            # FOR LOSS VALUES, GET THE RIGHT COMPONENT, DERIVE IT ON THE FIRST PASS
+            loss_tuple = context.loss_log_items
+            if self._first_call:
+                self._init_loss_attributes(context)
+            score = loss_tuple[self._idx_loss_tuple].detach().cpu().item()
 
-                # IN CONTRARY TO METRICS - LOSS VALUES NEED TO BE REDUCES IN DDP
-                device = next(context.net.parameters()).device
-                score.to(device)
-                score = maybe_all_reduce_tensor_average(score)
+            # CONTRARY TO METRICS, LOSS VALUES NEED TO BE REDUCED IN DDP
+            device = infer_model_device(context.net)
+            score = torch.tensor(score, device=device)
+            score = maybe_all_reduce_tensor_average(score)
 
-            if self._is_more_extreme(score):
-                self.extreme_score = score
-                self.extreme_batch = context.inputs
-                self.extreme_preds = context.preds
-                self.extreme_targets = context.target
+        if self._is_more_extreme(score):
+            self.extreme_score = tensor_container_to_device(score, device="cpu", detach=True, non_blocking=False)
+            self.extreme_batch = tensor_container_to_device(context.inputs, device="cpu", detach=True, non_blocking=False)
+            self.extreme_preds = tensor_container_to_device(context.preds, device="cpu", detach=True, non_blocking=False)
+            self.extreme_targets = tensor_container_to_device(context.target, device="cpu", detach=True, non_blocking=False)
 
     def _init_loss_attributes(self, context: PhaseContext):
         if self.loss_to_monitor not in context.loss_logging_items_names:
@@ -1170,24 +1216,28 @@ def _init_loss_attributes(self, context: PhaseContext):
         self._idx_loss_tuple = context.loss_logging_items_names.index(self.loss_to_monitor)
         self._first_call = False
 
-    def on_validation_loader_end(self, context: PhaseContext) -> None:
-        if context.epoch % self.freq == 0:
-            images_to_save = self.process_extreme_batch()
-            images_to_save = maybe_all_gather_np_images(images_to_save)
-            if not context.ddp_silent_mode:
-                context.sg_logger.add_images(tag=self._tag, images=images_to_save, global_step=context.epoch)
-
-            self._reset()
-
     def _reset(self):
-        self.extreme_score = -1 * np.inf if self.max else np.inf
+        self.extreme_score = None
         self.extreme_batch = None
         self.extreme_preds = None
         self.extreme_targets = None
         if self.metric is not None:
             self.metric.reset()
 
-    def _is_more_extreme(self, score: float) -> bool:
+    def _is_more_extreme(self, score: Union[float, torch.Tensor]) -> bool:
+        """
+        Checks whether the computed score is more extreme than the current extreme score.
+        If the current extreme score is None (first call), returns True.
+        :param score: A newly computed score.
+        :return: True if score is more extreme than the current extreme score, False otherwise.
+        """
+        # A score can be NaN/Inf (a rare but possible event when training diverges).
+        # In that case both the < and > operators would return False according to IEEE 754.
+        # As a consequence, self.extreme_batch / self.extreme_preds would not be updated,
+        # which would crash at the attempt to visualize the batch.
+        if self.extreme_score is None:
+            return True
+
         if self.max:
             return self.extreme_score < score
         else:
@@ -1254,18 +1304,21 @@ class ExtremeBatchDetectionVisualizationCallback(ExtremeBatchCaseVisualizationCa
             When there is no such attributes and criterion.forward(..) returns a tuple:
                 <LOSS_CLASS.__name__>"/"Loss_<IDX>
 
-    :param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
-     the minimum (default=False).
+    :param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
+        the minimum (default=False).
 
-    :param freq: int, epoch frequency to perform all of the above (default=1).
+    :param freq: int, epoch frequency to perform all of the above (default=1).
 
-    :param classes: List[str], a list of class names corresponding to the class indices for display.
-     When None, will try to fetch this through a "classes" attribute of the valdiation dataset. If such attribute does
-     not exist an error will be raised (default=None).
+    :param classes: List[str], a list of class names corresponding to the class indices for display.
+        When None, will try to fetch this through a "classes" attribute of the validation dataset. If such attribute does
+        not exist an error will be raised (default=None).
 
-    :param normalize_targets: bool, whether to scale the target bboxes. If the bboxes returned by the validation data loader
+    :param normalize_targets: bool, whether to scale the target bboxes. If the bboxes returned by the validation data loader
         are in pixel values range, this needs to be set to True (default=False)
 
+    :param enable_on_train_loader: Controls whether to enable this callback on the train loader. Default is False.
+    :param enable_on_valid_loader: Controls whether to enable this callback on the valid loader. Default is True.
+    :param max_images: Maximum images to save. If -1, save all images.
     """
 
     def __init__(
@@ -1278,9 +1331,19 @@ def __init__(
         freq: int = 1,
         classes: Optional[List[str]] = None,
         normalize_targets: bool = False,
+        enable_on_train_loader: bool = False,
+        enable_on_valid_loader: bool = True,
+        max_images: int = -1,
     ):
         super(ExtremeBatchDetectionVisualizationCallback, self).__init__(
-            metric=metric, metric_component_name=metric_component_name, loss_to_monitor=loss_to_monitor, max=max, freq=freq
+            metric=metric,
+            metric_component_name=metric_component_name,
+            loss_to_monitor=loss_to_monitor,
+            max=max,
+            freq=freq,
+            enable_on_valid_loader=enable_on_valid_loader,
+            enable_on_train_loader=enable_on_train_loader,
+            max_images=max_images,
         )
         self.post_prediction_callback = post_prediction_callback
         if classes is None:
@@ -1307,10 +1370,12 @@ def universal_undo_preprocessing_fn(inputs: torch.Tensor) -> np.ndarray:
         inputs = np.ascontiguousarray(inputs, dtype=np.uint8)
         return inputs
 
-    def process_extreme_batch(self) -> Tuple[np.ndarray, np.ndarray]:
+    def process_extreme_batch(self) -> np.ndarray:
         """
-        Processes the extreme batch, and returns 2 image batches for visualization - one with predictions and one with GT boxes.
-        :return:Tuple[np.ndarray, np.ndarray], the predictions batch, the GT batch
+        Processes the extreme batch, and returns a list of images for visualization.
+        The default implementation stacks GT and prediction overlays horizontally.
+
+        :return: np.ndarray A 4D tensor of [BHWC] shape with visualizations of the extreme batch.
         """
         inputs = self.extreme_batch
         preds = self.post_prediction_callback(self.extreme_preds, self.extreme_batch.device)
@@ -1334,25 +1399,16 @@ def process_extreme_batch(self) -> Tuple[np.ndarray, np.ndarray]:
         )
         images_to_save_gt = np.stack(images_to_save_gt)
 
-        return images_to_save_preds, images_to_save_gt
+        # Stack the GT and prediction images side by side (concatenate along the width axis)
+        return np.concatenate([images_to_save_gt, images_to_save_preds], axis=2)
 
-    def on_validation_loader_end(self, context: PhaseContext) -> None:
+    def on_validation_loader_start(self, context: PhaseContext) -> None:
         if self.classes is None:
             if hasattr(context.valid_loader.dataset, "classes"):
                 self.classes = context.valid_loader.dataset.classes
-
             else:
                 raise RuntimeError("Couldn't fetch classes from valid_loader, please pass classes explicitly")
 
-        if context.epoch % self.freq == 0:
-            images_to_save_preds, images_to_save_gt = self.process_extreme_batch()
-            images_to_save_preds = maybe_all_gather_np_images(images_to_save_preds)
-            images_to_save_gt = maybe_all_gather_np_images(images_to_save_gt)
-
-            if not context.ddp_silent_mode:
-                context.sg_logger.add_images(tag=f"{self._tag}_preds", images=images_to_save_preds, global_step=context.epoch, data_format="NHWC")
-                context.sg_logger.add_images(tag=f"{self._tag}_GT", images=images_to_save_gt, global_step=context.epoch, data_format="NHWC")
-
-            self._reset()
+        super().on_validation_loader_start(context)
 
 
 @register_callback("ExtremeBatchSegVisualizationCallback")
@@ -1405,12 +1461,14 @@ class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback
         When there is no such attributesand criterion.forward(..) returns a tuple:
             <LOSS_CLASS.__name__>"/"Loss_<IDX>
 
-    :param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
-     the minimum (default=False).
-
-    :param freq: int, epoch frequency to perform all of the above (default=1).
+    :param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
+        the minimum (default=False).
 
+    :param freq: int, epoch frequency to perform all of the above (default=1).
 
+    :param enable_on_train_loader: Controls whether to enable this callback on the train loader. Default is False.
+    :param enable_on_valid_loader: Controls whether to enable this callback on the valid loader. Default is True.
+    :param max_images: Maximum images to save. If -1, save all images.
     """
 
     def __init__(
@@ -1421,13 +1479,24 @@ def __init__(
         max: bool = False,
         freq: int = 1,
         ignore_idx: int = -1,
+        enable_on_train_loader: bool = False,
+        enable_on_valid_loader: bool = True,
+        max_images: int = -1,
     ):
         super(ExtremeBatchSegVisualizationCallback, self).__init__(
-            metric=metric, metric_component_name=metric_component_name, loss_to_monitor=loss_to_monitor, max=max, freq=freq
+            metric=metric,
+            metric_component_name=metric_component_name,
+            loss_to_monitor=loss_to_monitor,
+            max=max,
+            freq=freq,
+            enable_on_valid_loader=enable_on_valid_loader,
+            enable_on_train_loader=enable_on_train_loader,
+            max_images=max_images,
         )
         self.ignore_idx = ignore_idx
 
-    def process_extreme_batch(self) -> np.array:
+    @torch.no_grad()
+    def process_extreme_batch(self) -> np.ndarray:
         inputs = self.extreme_batch
         inputs -= inputs.min()
         inputs /= inputs.max()
@@ -1445,6 +1514,8 @@ def process_extreme_batch(self) -> np.array:
         colors = ["green", "red"]
         images_to_save = []
         for i in range(len(inputs)):
-            images_to_save.append(draw_segmentation_masks(inputs[i].cpu(), overlay[i], colors=colors, alpha=0.4).detach().numpy())
-        images_to_save = np.array(images_to_save)
+            image = draw_segmentation_masks(inputs[i].cpu(), overlay[i].cpu(), colors=colors, alpha=0.4).numpy()
+            image = np.transpose(image, (1, 2, 0))
+            images_to_save.append(image)
+        images_to_save = np.stack(images_to_save)
         return images_to_save
 
diff --git a/src/super_gradients/training/utils/utils.py b/src/super_gradients/training/utils/utils.py
index 48f7545227..c04f3e04ac 100755
--- a/src/super_gradients/training/utils/utils.py
+++ b/src/super_gradients/training/utils/utils.py
@@ -216,23 +216,26 @@ def average(self):
     #  else tuple((self._sum / self._count).cpu().numpy())
 
 
-def tensor_container_to_device(obj: Union[torch.Tensor, tuple, list, dict], device: str, non_blocking=True):
+def tensor_container_to_device(obj: Union[torch.Tensor, tuple, list, dict], device: str, non_blocking=True, detach: bool = False):
     """
-    recursively send compounded objects to device (sending all tensors to device and maintaining structure)
-    :param obj           the object to send to device (list / tuple / tensor / dict)
-    :param device:       device to send the tensors to
-    :param non_blocking: used for DistributedDataParallel
-    :returns             an object with the same structure (tensors, lists, tuples) with the device pointers (like
-                         the return value of Tensor.to(device)
+    Recursively send compound objects to device (sending all tensors to device and maintaining structure)
+    :param obj:          the object to send to device (list / tuple / tensor / dict)
+    :param device:       device to send the tensors to
+    :param non_blocking: used for DistributedDataParallel
+    :param detach:       whether to detach the tensors from the computation graph
+    :returns:            an object with the same structure (tensors, lists, tuples) with the device pointers (like
+                         the return value of Tensor.to(device))
     """
     if isinstance(obj, torch.Tensor):
+        if detach:
+            obj = obj.detach()
         return obj.to(device, non_blocking=non_blocking)
     elif isinstance(obj, tuple):
-        return tuple(tensor_container_to_device(x, device, non_blocking=non_blocking) for x in obj)
+        return tuple(tensor_container_to_device(x, device, non_blocking=non_blocking, detach=detach) for x in obj)
     elif isinstance(obj, list):
-        return [tensor_container_to_device(x, device, non_blocking=non_blocking) for x in obj]
-    elif isinstance(obj, dict):
-        return {k: tensor_container_to_device(v, device, non_blocking=non_blocking) for k, v in obj.items()}
+        return [tensor_container_to_device(x, device, non_blocking=non_blocking, detach=detach) for x in obj]
+    elif isinstance(obj, (dict, typing.Mapping)):
+        return {k: tensor_container_to_device(v, device, non_blocking=non_blocking, detach=detach) for k, v in obj.items()}
     else:
         return obj
 

From a6035197e89fb86342aac62c7533aaeb8cdd653b Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Wed, 27 Sep 2023 09:52:21 +0300
Subject: [PATCH 2/3] Fix docstrings

---
 .../training/utils/callbacks/callbacks.py | 27 ++++++++++++++-----
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/src/super_gradients/training/utils/callbacks/callbacks.py b/src/super_gradients/training/utils/callbacks/callbacks.py
index 192d44e09a..59fb3117a9 100644
--- a/src/super_gradients/training/utils/callbacks/callbacks.py
+++ b/src/super_gradients/training/utils/callbacks/callbacks.py
@@ -1085,12 +1085,27 @@ def __init__(
         max_images: int = -1,
     ):
         """
+        :param metric: Metric, will be the metric which is monitored.
 
-        :param metric:
-        :param metric_component_name:
-        :param loss_to_monitor:
-        :param max:
-        :param freq: Frequency (in epochs) of performing this callback. 1 means every epoch. 2 means every other epoch. Default is 1.
+        :param metric_component_name: In case metric returns multiple values (as Mapping),
+            the value at metric.compute()[metric_component_name] will be the one monitored.
+
+        :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
+            Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be:
+
+            if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
+                <LOSS_CLASS.__name__>"/"<COMPONENT_NAME>.
+
+            If a single item is returned rather than a tuple:
+                <LOSS_CLASS.__name__>.
+
+            When there are no such attributes and criterion.forward(..) returns a tuple:
+                <LOSS_CLASS.__name__>"/"Loss_<IDX>
+
+        :param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
+            the minimum (default=False).
+
+        :param freq: int, epoch frequency to perform all of the above (default=1).
         :param enable_on_train_loader: Controls whether to enable this callback on the train loader. Default is False.
         :param enable_on_valid_loader: Controls whether to enable this callback on the valid loader. Default is True.
         :param max_images: Maximum images to save. If -1, save all images.
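To make the loss_to_monitor naming scheme above concrete, here is a minimal sketch of a criterion with named components. MyLoss and its component names are hypothetical, and the (total_loss, log_items) return contract is a simplification of what super_gradients criteria actually return:

    import torch
    import torch.nn.functional as F

    class MyLoss(torch.nn.Module):
        # Exposing component_names enables the <LOSS_CLASS.__name__>"/"<COMPONENT_NAME> naming
        component_names = ["aux_loss1", "main_loss"]

        def forward(self, preds: torch.Tensor, target: torch.Tensor):
            aux = F.mse_loss(preds, target)
            main = F.l1_loss(preds, target)
            total = aux + main
            # Returns (total loss, per-component log items) - the tuple form referred to above
            return total, torch.stack([aux, main]).detach()

    # Valid monitored names would then be "MyLoss/aux_loss1" and "MyLoss/main_loss".
    # Without component_names, tuple outputs fall back to "MyLoss/Loss_0", "MyLoss/Loss_1", ...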
@@ -1182,7 +1197,7 @@ def _on_batch_end(self, context: PhaseContext) -> None:
             if self.metric_component_name is not None:
                 if not isinstance(score, Mapping) or (isinstance(score, Mapping) and self.metric_component_name not in score.keys()):
                     raise RuntimeError(
-                        f"metric_component_name: {self.metric_component_name} is not a component " f"of the monitored metric: {self.metric.__class__.__name__}"
+                        f"metric_component_name: {self.metric_component_name} is not a component of the monitored metric: {self.metric.__class__.__name__}"
                     )
                 score = score[self.metric_component_name]
             elif len(score) > 1:

From 08c643f19ffa35147b60787aa46db9269076b163 Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Wed, 27 Sep 2023 13:57:57 +0300
Subject: [PATCH 3/3] Added tests and improved documentation to reflect what is
 the expected return type and layout (channels last) of the image tensor for
 visualization

---
 .../training/utils/callbacks/callbacks.py | 18 ++++++++++--------
 tests/unit_tests/extreme_batch_cb_test.py | 18 ++++++++++++++++++
 2 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/src/super_gradients/training/utils/callbacks/callbacks.py b/src/super_gradients/training/utils/callbacks/callbacks.py
index 59fb3117a9..afdcdad04f 100644
--- a/src/super_gradients/training/utils/callbacks/callbacks.py
+++ b/src/super_gradients/training/utils/callbacks/callbacks.py
@@ -1069,7 +1069,7 @@ class ExtremeBatchCaseVisualizationCallback(Callback, ABC):
 
     :param freq: int, epoch frequency to perform all of the above (default=1).
 
-    Inheritors should implement process_extreme_batch which returns an image, as np.array (uint8) with shape BCHW.
+    Inheritors should implement process_extreme_batch which returns an image, as np.ndarray (uint8) with shape BHWC.
     """
 
     @resolve_param("metric", MetricsFactory())
@@ -1153,7 +1153,7 @@ def process_extreme_batch(self) -> np.ndarray:
         """
         This method is called right before adding the images to the in SGLoggger (inside the on_validation_loader_end call).
         It should process self.extreme_batch, self.extreme_preds and self.extreme_targets and output the images, as np.ndarrray.
-        Output should be of shape N,3,H,W and uint8.
+        Output should be of shape N,H,W,3 and uint8.
         :return: images to save, np.ndarray
         """
         raise NotImplementedError
@@ -1373,11 +1373,13 @@ def universal_undo_preprocessing_fn(inputs: torch.Tensor) -> np.ndarray:
         """
         A universal reversing of preprocessing to be passed to DetectionVisualization.visualize_batch's undo_preprocessing_func kwarg.
-        :param inputs:
-        :return:
+        This function scales the input tensor to the 0..255 range and casts it to uint8 dtype.
+
+        :param inputs: Input 4D tensor of images in BCHW format with unknown normalization.
+        :return: Numpy 4D tensor of images in BHWC format, normalized to 0..255 range (uint8).
         """
         inputs -= inputs.min()
-        inputs /= inputs.max()
+        inputs /= inputs.max() + 1e-8
         inputs *= 255
         inputs = inputs.to(torch.uint8)
         inputs = inputs.cpu().numpy()
@@ -1390,7 +1392,7 @@ def process_extreme_batch(self) -> np.ndarray:
         Processes the extreme batch, and returns a list of images for visualization.
         The default implementation stacks GT and prediction overlays horizontally.
 
-        :return: np.ndarray A 4D tensor of [BHWC] shape with visualizations of the extreme batch.
+        :return: np.ndarray A 4D tensor of BHWC shape with visualizations of the extreme batch.
""" inputs = self.extreme_batch preds = self.post_prediction_callback(self.extreme_preds, self.extreme_batch.device) @@ -1464,7 +1466,7 @@ class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback :param metric_component_name: In case metric returns multiple values (as Mapping), the value at metric.compute()[metric_component_name] will be the one monitored. - :param loss_to_monitor: str, loss_to_monitor corresponfing to the 'criterion' passed through training_params in Trainer.train(...). + :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...). Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be: if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple: @@ -1473,7 +1475,7 @@ class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback If a single item is returned rather then a tuple: . - When there is no such attributesand criterion.forward(..) returns a tuple: + When there is no such attributes and criterion.forward(..) returns a tuple: "/"Loss_" :param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or diff --git a/tests/unit_tests/extreme_batch_cb_test.py b/tests/unit_tests/extreme_batch_cb_test.py index 26bfd636a2..8816cf39c9 100644 --- a/tests/unit_tests/extreme_batch_cb_test.py +++ b/tests/unit_tests/extreme_batch_cb_test.py @@ -122,6 +122,24 @@ def test_segmentation_extreme_batch_with_loss_sanity(self): model=model, training_params=self.seg_training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader() ) + def test_segmentation_extreme_batch_train_only(self): + trainer, model = setup_trainer_and_model_seg("test_segmentation_extreme_batch_train_only") + self.seg_training_params["phase_callbacks"] = [ + ExtremeBatchSegVisualizationCallback(loss_to_monitor="DDRNetLoss/aux_loss1", enable_on_train_loader=True, enable_on_valid_loader=False) + ] + trainer.train( + model=model, training_params=self.seg_training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader() + ) + + def test_segmentation_extreme_batch_train_and_valid(self): + trainer, model = setup_trainer_and_model_seg("test_segmentation_extreme_batch_train_and_valid") + self.seg_training_params["phase_callbacks"] = [ + ExtremeBatchSegVisualizationCallback(loss_to_monitor="DDRNetLoss/aux_loss1", enable_on_train_loader=True, enable_on_valid_loader=True) + ] + trainer.train( + model=model, training_params=self.seg_training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader() + ) + if __name__ == "__main__": unittest.main()