From 5c45ee31538f761ed5e1e67541991f2dc4e33fe4 Mon Sep 17 00:00:00 2001
From: shayaharon
Date: Sun, 16 Jul 2023 14:37:32 +0300
Subject: [PATCH 1/9] tested version

---
 .../training/sg_trainer/sg_trainer.py         |   5 +-
 .../utils/callbacks/base_callbacks.py         |   2 +
 .../training/utils/callbacks/callbacks.py     | 209 +++++++++++++++++-
 3 files changed, 214 insertions(+), 2 deletions(-)

diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py
index 61d5893897..c948e5c77a 100755
--- a/src/super_gradients/training/sg_trainer/sg_trainer.py
+++ b/src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -441,7 +441,7 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl
 
         # COMPUTE THE LOSS FOR BACK PROP + EXTRA METRICS COMPUTED DURING THE LOSS FORWARD PASS
         loss, loss_log_items = self._get_losses(outputs, targets)
-        context.update_context(preds=outputs, loss_log_items=loss_log_items)
+        context.update_context(preds=outputs, loss_log_items=loss_log_items, loss_logging_items_names=self.loss_logging_items_names)
         self.phase_callback_handler.on_train_batch_loss_end(context)
 
         if not self.ddp_silent_mode and batch_idx == 0:
@@ -1316,6 +1316,7 @@ def forward(self, inputs, targets):
             metric_to_watch=self.metric_to_watch,
             device=device_config.device,
             ema_model=self.ema_model,
+            valid_metrics=self.valid_metrics,
         )
         self.phase_callback_handler.on_training_start(context)
 
@@ -1986,6 +1987,7 @@ def evaluate(
         lr_warmup_epochs = self.training_params.lr_warmup_epochs if self.training_params else None
 
         context = PhaseContext(
+            net=self.net,
             epoch=epoch,
             metrics_compute_fn=metrics,
             loss_avg_meter=loss_avg_meter,
@@ -1995,6 +1997,7 @@ def evaluate(
             sg_logger=self.sg_logger,
             train_loader=self.train_loader,
             valid_loader=self.valid_loader,
+            loss_logging_items_names=self.loss_logging_items_names,
         )
 
         with tqdm(data_loader, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode) as progress_bar_data_loader:
diff --git a/src/super_gradients/training/utils/callbacks/base_callbacks.py b/src/super_gradients/training/utils/callbacks/base_callbacks.py
index b6edacc75c..4f11d882dd 100644
--- a/src/super_gradients/training/utils/callbacks/base_callbacks.py
+++ b/src/super_gradients/training/utils/callbacks/base_callbacks.py
@@ -53,6 +53,7 @@ def __init__(
         metric_to_watch=None,
         valid_metrics=None,
         ema_model=None,
+        loss_logging_items_names=None,
     ):
         self.epoch = epoch
         self.batch_idx = batch_idx
@@ -82,6 +83,7 @@ def __init__(
         self.metric_to_watch = metric_to_watch
         self.valid_metrics = valid_metrics
         self.ema_model = ema_model
+        self.loss_logging_items_names = loss_logging_items_names
 
     def update_context(self, **kwargs):
         for attr, attr_val in kwargs.items():
diff --git a/src/super_gradients/training/utils/callbacks/callbacks.py b/src/super_gradients/training/utils/callbacks/callbacks.py
index e5eab1eeba..4f0e06a88d 100644
--- a/src/super_gradients/training/utils/callbacks/callbacks.py
+++ b/src/super_gradients/training/utils/callbacks/callbacks.py
@@ -12,10 +12,13 @@
 import onnxruntime
 import torch
 from deprecated import deprecated
+from torch.distributed import gather_object, get_rank
 from torch.utils.data import DataLoader
+from torchmetrics import MetricCollection
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
-from super_gradients.common.environment.ddp_utils import multi_process_safe
+from super_gradients.common.environment.ddp_utils import multi_process_safe, is_distributed
+from super_gradients.common.environment.device_utils import device_config
 from super_gradients.common.plugins.deci_client import DeciClient
 from super_gradients.common.registry.registry import register_lr_scheduler, register_lr_warmup, register_callback, LR_SCHEDULERS_CLS_DICT, TORCH_LR_SCHEDULERS
 from super_gradients.common.object_names import LRSchedulers, LRWarmups, Callbacks
@@ -23,9 +26,11 @@
 from super_gradients.training.utils import get_param
 from super_gradients.training.utils.callbacks.base_callbacks import PhaseCallback, PhaseContext, Phase, Callback
 from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionPostPredictionCallback
+from super_gradients.training.utils.distributed_training_utils import distributed_all_reduce_tensor_average, get_world_size
 from super_gradients.training.utils.segmentation_utils import BinarySegmentationVisualization
 from super_gradients.common.environment.checkpoints_dir_utils import get_project_checkpoints_dir_path
 from super_gradients.training.utils.utils import unwrap_model
+from torchvision.utils import draw_segmentation_masks
 
 logger = get_logger(__name__)
 
@@ -948,3 +953,205 @@ def create_lr_scheduler_callback(
         raise ValueError(f"Unknown lr_mode: {lr_mode}")
 
     return sg_lr_callback
+
+
+class ExtremeBatchCaseVisualizationCallback(Callback):
+    """
+    ExtremeBatchCaseVisualizationCallback
+
+    A base class for visualizing worst/best validation batches in an epoch
+    according to some metric or loss value, with Full DDP support.
+
+    Images are saved with training_hyperparams.sg_logger.
+
+    :param metric_name: str,will be the metric which the model checkpoint will be saved according to, and can be set to any
+        of the following:
+
+        a metric name (str) of one of the metric objects from the training_hyperparams.valid_metrics_list
+
+        a "component_name" if some metric in valid_metrics_list has an attribute component_names. In such cas it
+        is a list referring to the names of each entry in the output metric (torch tensor of size n).
+
+        one of "loss_logging_items_names" i.e which will correspond to an item returned during the
+        loss function's forward pass (see loss docs in Trainer.train(..)).
+
+    :param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
+        the minimum (default=False).
+
+    :param freq: int, epoch frequency to perform all of the above (default=1).
+
+    Inheritors should implement process_extreme_batch which returns an image, as an np.array (uint8) with shape BCHW.
+    """
+
+    def __init__(self, metric_name: str, max: bool = False, freq: int = 1):
+        self.metric_name = metric_name
+        self.metric = None
+        self.max = max
+        self.freq = freq
+        self.extreme_score = -1 * np.inf if max else np.inf
+
+        self.extreme_batch = None
+        self.extreme_preds = None
+        self.extreme_targets = None
+
+        self._first_call = True
+        self._idx_loss_tuple = None
+        self._tag = f"max_{self.metric_name}_batch" if self.max else f"min_{self.metric_name}_batch"
+
+        super(ExtremeBatchCaseVisualizationCallback, self).__init__()
+
+    def process_extreme_batch(self) -> np.array:
+        raise NotImplementedError
+
+    def on_training_start(self, context: PhaseContext) -> None:
+        """
+        On train start we set the metric (if the metric_name does not corresponf to a loss).
+
+        :param context: Phase context
+        :return:
+        """
+        if not hasattr(context.valid_metrics, self.metric_name):
+            for metric_name, metric in context.valid_metrics.items():
+                if hasattr(metric, "greater_component_is_better") and self.metric_name in metric.greater_component_is_better.keys():
+                    # WRAP METRIC WITH METRIC COLLECTION TO FILTER ONLY THE NEEDED ARGUMENTS FOR THE METRIC UPDATE
+                    self.metric = MetricCollection(copy.deepcopy(metric))
+                    self.metric.to(device_config.device)
+        else:
+            self.metric = MetricCollection(copy.deepcopy(getattr(context.valid_metrics, self.metric_name)))
+            self.metric.to(device_config.device)
+
+    def on_validation_batch_end(self, context: PhaseContext) -> None:
+        if context.epoch % self.freq == 0:
+            # FOR METRIC OBJECTS, RESET THEM AND COMPUTE SCORE ONLY ON BATCH.
+            if self.metric is not None:
+                self.metric.reset()
+                self.metric.update(**context.__dict__)
+                score = self.metric.compute()[self.metric_name]
+            else:
+
+                # FOR LOSS VALUES, GET THE RIGHT COMPONENT, DERRIVE IT ON THE FIRST PASS
+                loss_tuple = context.loss_log_items
+                if self._first_call:
+                    self._init_loss_attributes(context, loss_tuple)
+                score = loss_tuple[self._idx_loss_tuple]
+
+                # CONTRARY TO METRICS - LOSS VALUES NEED TO BE REDUCED IN DDP
+                if is_distributed():
+                    device = next(context.net.parameters()).device
+                    score = distributed_all_reduce_tensor_average(tensor=score.to(device), n=torch.distributed.get_world_size())
+
+            if self._is_more_extreme(score):
+                self.extreme_score = score
+                self.extreme_batch = context.inputs
+                self.extreme_preds = context.preds
+                self.extreme_targets = context.target
+
+    def _init_loss_attributes(self, context: PhaseContext, loss_tuple: tuple):
+        if self.metric_name not in context.loss_logging_items_names:
+            raise ValueError(f"{self.metric_name} not a validation metric, loss or loss component.")
+        self._idx_loss_tuple = context.loss_logging_items_names.index(self.metric_name)
+        self._first_call = False
+
+    def on_validation_loader_end(self, context: PhaseContext) -> None:
+        if context.epoch % self.freq == 0:
+            images_to_save = self.process_extreme_batch()
+            #
+            if is_distributed():
+                rank = get_rank()
+                output_container = [None for _ in range(get_world_size())]
+                gather_object(images_to_save, output_container if rank == 0 else None, dst=0)
+                if rank == 0:
+                    images_to_save = np.concatenate(output_container, 0)
+            if not context.ddp_silent_mode:
+                context.sg_logger.add_images(tag=self._tag, images=images_to_save, global_step=context.epoch)
+
+            self._reset()
+
+    def _reset(self):
+        self.extreme_score = -1 * np.inf if self.max else np.inf
+        self.extreme_batch = None
+        self.extreme_preds = None
+        self.extreme_targets = None
+        if self.metric is not None:
+            self.metric.reset()
+
+    def _is_more_extreme(self, score: float) -> bool:
+        if self.max:
+            return self.extreme_score < score
+        else:
+            return self.extreme_score > score
+
+
+@register_callback("ExtremeBatchSegVisualizationCallback")
+class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback):
+    """
+    ExtremeBatchSegVisualizationCallback
+
+    Visualizes worst/best batch in an epoch, for segmentation.
+    Assumes context.preds in validation is a score tensor of shape BCHW, or a tuple whose first item is one.
+
+    True predictions will be marked with green, false ones with red.
+
+    Example usage in training_params definition:
+
+        training_hyperparams ={
+            ...
+ "phase_callbacks": + [ExtremeBatchSegVisualizationCallback( + metric_name=IoU' + max=False + ignore_idx=19), + ExtremeBatchSegVisualizationCallback( + metric_name="LabelSmoothingCrossEntropyLoss" + max=True + ignore_idx=19)] + ...} + + + :param metric_name: str,will be the metric which the model checkpoint will be saved according to, and can be set to any + of the following: + + a metric name (str) of one of the metric objects from the training_hyperparams.valid_metrics_list + + a "component_name" if some metric in valid_metrics_list has an attribute component_names. In such cas it + is a list referring to the names of each entry in the output metric (torch tensor of size n). + + one of "loss_logging_items_names" i.e which will correspond to an item returned during the + loss function's forward pass (see loss docs in Trainer.train(..)). + + :param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or + the minimum (default=False). + + :param freq: int, epoch frequency to perform all of the above (default=1). + + + :param ignore_idx: int, any prediction of a coordinate in the output image, s.t the ground truth of it is this + value will not be colored in green or in red (default=-1). + + + """ + + def __init__(self, metric_name: str, max: bool = False, freq: int = 1, ignore_idx: int = -1): + super(ExtremeBatchSegVisualizationCallback, self).__init__(metric_name=metric_name, max=max, freq=freq) + self.ignore_idx = ignore_idx + + def process_extreme_batch(self) -> np.array: + inputs = self.extreme_batch + inputs -= inputs.min() + inputs /= inputs.max() + inputs *= 255 + inputs = inputs.to(torch.uint8) + preds = self.extreme_preds + if isinstance(preds, tuple): + preds = preds[0] + preds = preds.argmax(1) + p_mask = preds == self.extreme_targets + n_mask = preds != self.extreme_targets + p_mask[self.extreme_targets == self.ignore_idx] = False + n_mask[self.extreme_targets == self.ignore_idx] = False + overlay = torch.cat([p_mask.unsqueeze(1), n_mask.unsqueeze(1)], 1) + colors = ["green", "red"] + images_to_save = [] + for i in range(len(inputs)): + images_to_save.append(draw_segmentation_masks(inputs[i].cpu(), overlay[i], colors=colors, alpha=0.4).detach().numpy()) + images_to_save = np.array(images_to_save) + return images_to_save From 9bb4c9b0957b9cffb27e10a96450a0c47c5235ec Mon Sep 17 00:00:00 2001 From: shayaharon Date: Mon, 17 Jul 2023 14:24:55 +0300 Subject: [PATCH 2/9] changed base to abc and abstractmethod --- src/super_gradients/training/utils/callbacks/callbacks.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/super_gradients/training/utils/callbacks/callbacks.py b/src/super_gradients/training/utils/callbacks/callbacks.py index 4f0e06a88d..faa5b83f55 100644 --- a/src/super_gradients/training/utils/callbacks/callbacks.py +++ b/src/super_gradients/training/utils/callbacks/callbacks.py @@ -3,6 +3,7 @@ import os import signal import time +from abc import ABC, abstractmethod from typing import List, Union, Optional, Sequence, Mapping import csv @@ -955,7 +956,7 @@ def create_lr_scheduler_callback( return sg_lr_callback -class ExtremeBatchCaseVisualizationCallback(Callback): +class ExtremeBatchCaseVisualizationCallback(Callback, ABC): """ ExtremeBatchCaseVisualizationCallback @@ -1000,6 +1001,7 @@ def __init__(self, metric_name: str, max: bool = False, freq: int = 1): super(ExtremeBatchCaseVisualizationCallback, self).__init__() + @abstractmethod def process_extreme_batch(self) -> np.array: raise NotImplementedError From 
bbd2f797f030774c4cca32596c51ae7110f04081 Mon Sep 17 00:00:00 2001 From: shayaharon Date: Mon, 17 Jul 2023 20:10:29 +0300 Subject: [PATCH 3/9] comments wip --- .../training/utils/callbacks/callbacks.py | 19 ++++------- .../utils/distributed_training_utils.py | 34 +++++++++++++++++++ 2 files changed, 40 insertions(+), 13 deletions(-) diff --git a/src/super_gradients/training/utils/callbacks/callbacks.py b/src/super_gradients/training/utils/callbacks/callbacks.py index faa5b83f55..98ce8bd092 100644 --- a/src/super_gradients/training/utils/callbacks/callbacks.py +++ b/src/super_gradients/training/utils/callbacks/callbacks.py @@ -13,12 +13,11 @@ import onnxruntime import torch from deprecated import deprecated -from torch.distributed import gather_object, get_rank from torch.utils.data import DataLoader from torchmetrics import MetricCollection from super_gradients.common.abstractions.abstract_logger import get_logger -from super_gradients.common.environment.ddp_utils import multi_process_safe, is_distributed +from super_gradients.common.environment.ddp_utils import multi_process_safe from super_gradients.common.environment.device_utils import device_config from super_gradients.common.plugins.deci_client import DeciClient from super_gradients.common.registry.registry import register_lr_scheduler, register_lr_warmup, register_callback, LR_SCHEDULERS_CLS_DICT, TORCH_LR_SCHEDULERS @@ -27,7 +26,7 @@ from super_gradients.training.utils import get_param from super_gradients.training.utils.callbacks.base_callbacks import PhaseCallback, PhaseContext, Phase, Callback from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionPostPredictionCallback -from super_gradients.training.utils.distributed_training_utils import distributed_all_reduce_tensor_average, get_world_size +from super_gradients.training.utils.distributed_training_utils import maybe_all_reduce, maybe_all_gather from super_gradients.training.utils.segmentation_utils import BinarySegmentationVisualization from super_gradients.common.environment.checkpoints_dir_utils import get_project_checkpoints_dir_path from super_gradients.training.utils.utils import unwrap_model @@ -1037,9 +1036,8 @@ def on_validation_batch_end(self, context: PhaseContext) -> None: score = loss_tuple[self._idx_loss_tuple] # IN CONTRARY TO METRICS - LOSS VALUES NEED TO BE REDUCES IN DDP - if is_distributed(): - device = next(context.net.parameters()).device - score = distributed_all_reduce_tensor_average(tensor=score.to(device), n=torch.distributed.get_world_size()) + device = next(context.net.parameters()).device + score = maybe_all_reduce(device, score) if self._is_more_extreme(score): self.extreme_score = score @@ -1047,7 +1045,7 @@ def on_validation_batch_end(self, context: PhaseContext) -> None: self.extreme_preds = context.preds self.extreme_targets = context.target - def _init_loss_attributes(self, context: PhaseContext, loss_tuple: tuple): + def _init_loss_attributes(self, context: PhaseContext): if self.metric_name not in context.loss_logging_items_names: raise ValueError(f"{self.metric_name} not a validation metric, loss or loss component.") self._idx_loss_tuple = context.loss_logging_items_names.index(self.metric_name) @@ -1057,12 +1055,7 @@ def on_validation_loader_end(self, context: PhaseContext) -> None: if context.epoch % self.freq == 0: images_to_save = self.process_extreme_batch() # - if is_distributed(): - rank = get_rank() - output_container = [None for _ in range(get_world_size())] - gather_object(images_to_save, 
output_container if rank == 0 else None, dst=0) - if rank == 0: - images_to_save = np.concatenate(output_container, 0) + images_to_save = maybe_all_gather(images_to_save) if not context.ddp_silent_mode: context.sg_logger.add_images(tag=self._tag, images=images_to_save, global_step=context.epoch) diff --git a/src/super_gradients/training/utils/distributed_training_utils.py b/src/super_gradients/training/utils/distributed_training_utils.py index d3e12f8365..3fdb7caef4 100755 --- a/src/super_gradients/training/utils/distributed_training_utils.py +++ b/src/super_gradients/training/utils/distributed_training_utils.py @@ -4,10 +4,12 @@ from typing import List, Tuple from contextlib import contextmanager +import numpy as np import torch import torch.nn as nn from torch import distributed as dist from torch.cuda.amp import autocast +from torch.distributed import get_rank, gather_object from torch.distributed.elastic.multiprocessing import Std from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.launcher.api import LaunchConfig, elastic_launch @@ -412,3 +414,35 @@ def __init__(self): ">>> setup_device(multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, num_gpus=...)" ) super().__init__(self.message) + + +def maybe_all_reduce(device: str, tensor: torch.Tensor) -> torch.Tensor: + """ + When in DDP- mean-reduces tensor from all devices. + When not in DDP - returns the input tensor. + + :param device: + :param tensor: + :return: + """ + if is_distributed(): + tensor = distributed_all_reduce_tensor_average(tensor=tensor.to(device), n=torch.distributed.get_world_size()) + return tensor + + +def maybe_all_gather(tensor: torch.Tensor) -> torch.Tensor: + """ + When in DDP- gathers tensor from all devices to rank 0. Returns the gathered tensor on rank 0 and the + Ingathered one on other ranks. + When not in DDP - returns the input tensor. + + :param tensor: torch.Tensor, the local rank's tensor to gather + :return: torch.Tensor, the gathered (or original if not in DDP) tensor. 
+ """ + if is_distributed(): + rank = get_rank() + output_container = [None for _ in range(get_world_size())] + gather_object(tensor, output_container if rank == 0 else None, dst=0) + if rank == 0: + tensor = np.concatenate(output_container, 0) + return tensor From 6344f834a381dcf0a043151c665816fdf78097ec Mon Sep 17 00:00:00 2001 From: shayaharon Date: Wed, 19 Jul 2023 16:59:46 +0300 Subject: [PATCH 4/9] refactoring, docs --- .../training/metrics/segmentation_metrics.py | 7 + .../training/utils/callbacks/callbacks.py | 157 ++++++++++++------ .../utils/distributed_training_utils.py | 21 +-- 3 files changed, 121 insertions(+), 64 deletions(-) diff --git a/src/super_gradients/training/metrics/segmentation_metrics.py b/src/super_gradients/training/metrics/segmentation_metrics.py index d934bc1d27..a2409e46f8 100755 --- a/src/super_gradients/training/metrics/segmentation_metrics.py +++ b/src/super_gradients/training/metrics/segmentation_metrics.py @@ -331,6 +331,13 @@ def update(self, preds, target: torch.Tensor): super().update(preds=preds, target=target) +@register_metric("DIOU") +class DIOU(IoU): + def compute(self): + diou = super(DIOU, self).compute() + return {"diou": diou, "diou_minus": -1 * diou} + + @register_metric(Metrics.DICE) class Dice(torchmetrics.JaccardIndex): """ diff --git a/src/super_gradients/training/utils/callbacks/callbacks.py b/src/super_gradients/training/utils/callbacks/callbacks.py index 98ce8bd092..adfc1a13cf 100644 --- a/src/super_gradients/training/utils/callbacks/callbacks.py +++ b/src/super_gradients/training/utils/callbacks/callbacks.py @@ -14,11 +14,13 @@ import torch from deprecated import deprecated from torch.utils.data import DataLoader -from torchmetrics import MetricCollection +from torchmetrics import MetricCollection, Metric from super_gradients.common.abstractions.abstract_logger import get_logger +from super_gradients.common.decorators.factory_decorator import resolve_param from super_gradients.common.environment.ddp_utils import multi_process_safe from super_gradients.common.environment.device_utils import device_config +from super_gradients.common.factories.metrics_factory import MetricsFactory from super_gradients.common.plugins.deci_client import DeciClient from super_gradients.common.registry.registry import register_lr_scheduler, register_lr_warmup, register_callback, LR_SCHEDULERS_CLS_DICT, TORCH_LR_SCHEDULERS from super_gradients.common.object_names import LRSchedulers, LRWarmups, Callbacks @@ -26,7 +28,7 @@ from super_gradients.training.utils import get_param from super_gradients.training.utils.callbacks.base_callbacks import PhaseCallback, PhaseContext, Phase, Callback from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionPostPredictionCallback -from super_gradients.training.utils.distributed_training_utils import maybe_all_reduce, maybe_all_gather +from super_gradients.training.utils.distributed_training_utils import maybe_all_reduce_tensor_average, maybe_all_gather_np_images from super_gradients.training.utils.segmentation_utils import BinarySegmentationVisualization from super_gradients.common.environment.checkpoints_dir_utils import get_project_checkpoints_dir_path from super_gradients.training.utils.utils import unwrap_model @@ -964,16 +966,22 @@ class ExtremeBatchCaseVisualizationCallback(Callback, ABC): Images are saved with training_hyperparams.sg_logger. 
-    :param metric_name: str,will be the metric which the model checkpoint will be saved according to, and can be set to any
-        of the following:
+    :param metric: Metric, will be the metric which is monitored.
 
-        a metric name (str) of one of the metric objects from the training_hyperparams.valid_metrics_list
+    :param metric_component_name: In case metric returns multiple values (as Mapping),
+        the value at metric.compute()[metric_component_name] will be the one monitored.
 
-        a "component_name" if some metric in valid_metrics_list has an attribute component_names. In such cas it
-        is a list referring to the names of each entry in the output metric (torch tensor of size n).
+    :param loss_name: str, loss_name corresponding to the 'criterion' passed through training_params in Trainer.train(...).
+        Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be:
 
-        one of "loss_logging_items_names" i.e which will correspond to an item returned during the
-        loss function's forward pass (see loss docs in Trainer.train(..)).
+        if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
+            "<LOSS_CLASS.__name__>"/"<COMPONENT_NAME>".
+
+        If a single item is returned rather than a tuple:
+            "<LOSS_CLASS.__name__>".
+
+        When there is no such attribute and criterion.forward(..) returns a tuple:
+            "<LOSS_CLASS.__name__>"/"Loss_"<IDX>
 
     :param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
         the minimum (default=False).
@@ -983,9 +991,29 @@ class ExtremeBatchCaseVisualizationCallback(Callback, ABC):
     Inheritors should implement process_extreme_batch which returns an image, as an np.array (uint8) with shape BCHW.
     """
 
-    def __init__(self, metric_name: str, max: bool = False, freq: int = 1):
-        self.metric_name = metric_name
-        self.metric = None
+    @resolve_param("metric", MetricsFactory())
+    def __init__(
+        self,
+        metric: Optional[Union[Metric, Mapping, str]] = None,
+        metric_component_name: Optional[str] = None,
+        loss_name: Optional[str] = None,
+        max: bool = False,
+        freq: int = 1,
+    ):
+        super(ExtremeBatchCaseVisualizationCallback, self).__init__()
+
+        if (metric and loss_name) or (metric is None and loss_name is None):
+            raise RuntimeError("Must pass exactly one of: loss, metric != None")
+
+        self._set_tag_attr(loss_name, max, metric, metric_component_name)
+        self.metric = metric
+        if self.metric:
+            self.metric = MetricCollection(self.metric)
+            self.metric.to(device_config.device)
+
+        self.metric_component_name = metric_component_name
+
+        self.loss_name = loss_name
         self.max = max
         self.freq = freq
         self.extreme_score = -1 * np.inf if max else np.inf
@@ -996,48 +1024,51 @@ def __init__(self, metric_name: str, max: bool = False, freq: int = 1):
 
         self._first_call = True
         self._idx_loss_tuple = None
-        self._tag = f"max_{self.metric_name}_batch" if self.max else f"min_{self.metric_name}_batch"
 
-        super(ExtremeBatchCaseVisualizationCallback, self).__init__()
+    def _set_tag_attr(self, loss_name, max, metric, metric_component_name):
+        if metric_component_name:
+            monitored_val_name = metric_component_name
+        elif metric:
+            monitored_val_name = metric.__class__.__name__
+        else:
+            monitored_val_name = loss_name
+        self._tag = f"max_{monitored_val_name}_batch" if max else f"min_{monitored_val_name}_batch"
 
     @abstractmethod
-    def process_extreme_batch(self) -> np.array:
-        raise NotImplementedError
-
-    def on_training_start(self, context: PhaseContext) -> None:
+    def process_extreme_batch(self) -> np.ndarray:
         """
-        On train start we set the metric (if the metric_name does not corresponf to a loss).
-
-        :param context: Phase context
-        :return:
+        This method is called right before adding the images to the SGLogger (inside the on_validation_loader_end call).
+        It should process self.extreme_batch, self.extreme_preds and self.extreme_targets and output the images, as an np.ndarray.
+        Output should be of shape N,3,H,W and uint8.
+        :return: images to save, np.ndarray
         """
-        if not hasattr(context.valid_metrics, self.metric_name):
-            for metric_name, metric in context.valid_metrics.items():
-                if hasattr(metric, "greater_component_is_better") and self.metric_name in metric.greater_component_is_better.keys():
-                    # WRAP METRIC WITH METRIC COLLECTION TO FILTER ONLY THE NEEDED ARGUMENTS FOR THE METRIC UPDATE
-                    self.metric = MetricCollection(copy.deepcopy(metric))
-                    self.metric.to(device_config.device)
-        else:
-            self.metric = MetricCollection(copy.deepcopy(getattr(context.valid_metrics, self.metric_name)))
-            self.metric.to(device_config.device)
+        raise NotImplementedError
 
     def on_validation_batch_end(self, context: PhaseContext) -> None:
         if context.epoch % self.freq == 0:
             # FOR METRIC OBJECTS, RESET THEM AND COMPUTE SCORE ONLY ON BATCH.
             if self.metric is not None:
-                self.metric.reset()
                 self.metric.update(**context.__dict__)
-                score = self.metric.compute()[self.metric_name]
+                score = self.metric.compute()
+                if self.metric_component_name is not None:
+                    if not isinstance(score, Mapping) or (isinstance(score, Mapping) and self.metric_component_name not in score.keys()):
+                        raise RuntimeError(
+                            f"metric_component_name: {self.metric_component_name} is not a component "
+                            f"of the monitored metric: {self.metric.__class__.__name__}"
+                        )
+                    score = score[self.metric_component_name]
+                self.metric.reset()
             else:
 
-                # FOR LOSS VALUES, GET THE RIGHT COMPONENT, DERRIVE IT ON THE FIRST PASS
+                # FOR LOSS VALUES, GET THE RIGHT COMPONENT, DERIVE IT ON THE FIRST PASS
                 loss_tuple = context.loss_log_items
                 if self._first_call:
-                    self._init_loss_attributes(context, loss_tuple)
+                    self._init_loss_attributes(context)
                 score = loss_tuple[self._idx_loss_tuple]
 
                 # CONTRARY TO METRICS - LOSS VALUES NEED TO BE REDUCED IN DDP
                 device = next(context.net.parameters()).device
-                score = maybe_all_reduce(device, score)
+                score = maybe_all_reduce_tensor_average(device, score)
 
             if self._is_more_extreme(score):
                 self.extreme_score = score
@@ -1046,16 +1077,15 @@ def on_validation_batch_end(self, context: PhaseContext) -> None:
             self.extreme_targets = context.target
 
     def _init_loss_attributes(self, context: PhaseContext):
-        if self.metric_name not in context.loss_logging_items_names:
-            raise ValueError(f"{self.metric_name} not a validation metric, loss or loss component.")
+        if self.loss_name not in context.loss_logging_items_names:
+            raise ValueError(f"{self.loss_name} not a loss or loss component.")
         self._idx_loss_tuple = context.loss_logging_items_names.index(self.metric_name)
         self._first_call = False
 
     def on_validation_loader_end(self, context: PhaseContext) -> None:
         if context.epoch % self.freq == 0:
             images_to_save = self.process_extreme_batch()
-            #
-            images_to_save = maybe_all_gather(images_to_save)
+            images_to_save = maybe_all_gather_np_images(images_to_save)
             if not context.ddp_silent_mode:
                 context.sg_logger.add_images(tag=self._tag, images=images_to_save, global_step=context.epoch)
 
@@ -1092,26 +1122,39 @@ class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback
         ...
"phase_callbacks": [ExtremeBatchSegVisualizationCallback( - metric_name=IoU' + metrice=IoU(20, ignore_idx=19) max=False ignore_idx=19), ExtremeBatchSegVisualizationCallback( - metric_name="LabelSmoothingCrossEntropyLoss" + loss_name="LabelSmoothingCrossEntropyLoss" max=True ignore_idx=19)] ...} + Example usage in Yaml config: + + training_hyperparams: + phase_callbacks: + - ExtremeBatchSegVisualizationCallback: + loss_name: DiceCEEdgeLoss/aux_loss0 + ignore_idx: 19 - :param metric_name: str,will be the metric which the model checkpoint will be saved according to, and can be set to any - of the following: + :param metric: Metric, will be the metric which is monitored. - a metric name (str) of one of the metric objects from the training_hyperparams.valid_metrics_list + :param metric_component_name: In case metric returns multiple values (as Mapping), + the value at metric.compute()[metric_component_name] will be the one monitored. - a "component_name" if some metric in valid_metrics_list has an attribute component_names. In such cas it - is a list referring to the names of each entry in the output metric (torch tensor of size n). + :param loss_name: str, loss_name corresponfing to the 'criterion' passed through training_params in Trainer.train(...). + Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be: - one of "loss_logging_items_names" i.e which will correspond to an item returned during the - loss function's forward pass (see loss docs in Trainer.train(..)). + if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple: + "/". + + If a single item is returned rather then a tuple: + . + + When there is no such attributesand criterion.forward(..) returns a tuple: + "/"Loss_" :param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or the minimum (default=False). @@ -1119,14 +1162,20 @@ class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback :param freq: int, epoch frequency to perform all of the above (default=1). - :param ignore_idx: int, any prediction of a coordinate in the output image, s.t the ground truth of it is this - value will not be colored in green or in red (default=-1). 
- - """ - def __init__(self, metric_name: str, max: bool = False, freq: int = 1, ignore_idx: int = -1): - super(ExtremeBatchSegVisualizationCallback, self).__init__(metric_name=metric_name, max=max, freq=freq) + def __init__( + self, + metric: Optional[Union[Metric, Mapping, str]] = None, + metric_component_name: Optional[str] = None, + loss_name: Optional[str] = None, + max: bool = False, + freq: int = 1, + ignore_idx: int = -1, + ): + super(ExtremeBatchSegVisualizationCallback, self).__init__( + metric=metric, metric_component_name=metric_component_name, loss_name=loss_name, max=max, freq=freq + ) self.ignore_idx = ignore_idx def process_extreme_batch(self) -> np.array: diff --git a/src/super_gradients/training/utils/distributed_training_utils.py b/src/super_gradients/training/utils/distributed_training_utils.py index 3fdb7caef4..f5cfe89551 100755 --- a/src/super_gradients/training/utils/distributed_training_utils.py +++ b/src/super_gradients/training/utils/distributed_training_utils.py @@ -9,7 +9,7 @@ import torch.nn as nn from torch import distributed as dist from torch.cuda.amp import autocast -from torch.distributed import get_rank, gather_object +from torch.distributed import get_rank, all_gather_object from torch.distributed.elastic.multiprocessing import Std from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.launcher.api import LaunchConfig, elastic_launch @@ -416,7 +416,7 @@ def __init__(self): super().__init__(self.message) -def maybe_all_reduce(device: str, tensor: torch.Tensor) -> torch.Tensor: +def maybe_all_reduce_tensor_average(device: str, tensor: torch.Tensor) -> torch.Tensor: """ When in DDP- mean-reduces tensor from all devices. When not in DDP - returns the input tensor. @@ -430,19 +430,20 @@ def maybe_all_reduce(device: str, tensor: torch.Tensor) -> torch.Tensor: return tensor -def maybe_all_gather(tensor: torch.Tensor) -> torch.Tensor: +def maybe_all_gather_np_images(image: np.ndarray) -> np.ndarray: """ - When in DDP- gathers tensor from all devices to rank 0. Returns the gathered tensor on rank 0 and the - Ingathered one on other ranks. + When in DDP- gathers images (as np.ndarray objects) from all processes. + Returns the concatenated np.array across dim=0. When not in DDP - returns the input tensor. - :param tensor: torch.Tensor, the local rank's tensor to gather - :return: torch.Tensor, the gathered (or original if not in DDP) tensor. 
+ :param image: np.ndarray, the local rank's tensor to gather + + :return: np.ndarray, the output image as described above """ if is_distributed(): rank = get_rank() output_container = [None for _ in range(get_world_size())] - gather_object(tensor, output_container if rank == 0 else None, dst=0) + all_gather_object(output_container, image) if rank == 0: - tensor = np.concatenate(output_container, 0) - return tensor + image = np.concatenate(output_container, 0) + return image From 2854f3515b3c9e6a51afc43275f76aaffa145fce Mon Sep 17 00:00:00 2001 From: shayaharon Date: Mon, 24 Jul 2023 10:43:56 +0300 Subject: [PATCH 5/9] removed testing metric --- .../training/metrics/segmentation_metrics.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/super_gradients/training/metrics/segmentation_metrics.py b/src/super_gradients/training/metrics/segmentation_metrics.py index a2409e46f8..d934bc1d27 100755 --- a/src/super_gradients/training/metrics/segmentation_metrics.py +++ b/src/super_gradients/training/metrics/segmentation_metrics.py @@ -331,13 +331,6 @@ def update(self, preds, target: torch.Tensor): super().update(preds=preds, target=target) -@register_metric("DIOU") -class DIOU(IoU): - def compute(self): - diou = super(DIOU, self).compute() - return {"diou": diou, "diou_minus": -1 * diou} - - @register_metric(Metrics.DICE) class Dice(torchmetrics.JaccardIndex): """ From 2b8674a4f34b48dae6485f8b4d4dcc1aee57135a Mon Sep 17 00:00:00 2001 From: shayaharon Date: Mon, 24 Jul 2023 10:48:40 +0300 Subject: [PATCH 6/9] removed device arg from maybe all reduce --- src/super_gradients/training/utils/callbacks/callbacks.py | 3 ++- .../training/utils/distributed_training_utils.py | 7 +++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/super_gradients/training/utils/callbacks/callbacks.py b/src/super_gradients/training/utils/callbacks/callbacks.py index adfc1a13cf..0a3f3ff8c7 100644 --- a/src/super_gradients/training/utils/callbacks/callbacks.py +++ b/src/super_gradients/training/utils/callbacks/callbacks.py @@ -1068,7 +1068,8 @@ def on_validation_batch_end(self, context: PhaseContext) -> None: # IN CONTRARY TO METRICS - LOSS VALUES NEED TO BE REDUCES IN DDP device = next(context.net.parameters()).device - score = maybe_all_reduce_tensor_average(device, score) + score.to(device) + score = maybe_all_reduce_tensor_average(score) if self._is_more_extreme(score): self.extreme_score = score diff --git a/src/super_gradients/training/utils/distributed_training_utils.py b/src/super_gradients/training/utils/distributed_training_utils.py index f5cfe89551..3e97caf553 100755 --- a/src/super_gradients/training/utils/distributed_training_utils.py +++ b/src/super_gradients/training/utils/distributed_training_utils.py @@ -416,17 +416,16 @@ def __init__(self): super().__init__(self.message) -def maybe_all_reduce_tensor_average(device: str, tensor: torch.Tensor) -> torch.Tensor: +def maybe_all_reduce_tensor_average(tensor: torch.Tensor) -> torch.Tensor: """ When in DDP- mean-reduces tensor from all devices. When not in DDP - returns the input tensor. 
 
-    :param device:
-    :param tensor:
+    :param tensor: tensor to (maybe) reduce
     :return:
     """
     if is_distributed():
-        tensor = distributed_all_reduce_tensor_average(tensor=tensor.to(device), n=torch.distributed.get_world_size())
+        tensor = distributed_all_reduce_tensor_average(tensor=tensor, n=torch.distributed.get_world_size())
     return tensor

From 0633b5556cee46e5962bcf5bdf6216720293ab91 Mon Sep 17 00:00:00 2001
From: shayaharon
Date: Mon, 24 Jul 2023 11:16:29 +0300
Subject: [PATCH 7/9] fixed metrice typo in docs

---
 src/super_gradients/training/utils/callbacks/callbacks.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/super_gradients/training/utils/callbacks/callbacks.py b/src/super_gradients/training/utils/callbacks/callbacks.py
index 0a3f3ff8c7..5c12257dc4 100644
--- a/src/super_gradients/training/utils/callbacks/callbacks.py
+++ b/src/super_gradients/training/utils/callbacks/callbacks.py
@@ -994,7 +994,7 @@ class ExtremeBatchCaseVisualizationCallback(Callback, ABC):
     @resolve_param("metric", MetricsFactory())
     def __init__(
         self,
-        metric: Optional[Union[Metric, Mapping, str]] = None,
+        metric: Optional[Metric] = None,
         metric_component_name: Optional[str] = None,
         loss_name: Optional[str] = None,
         max: bool = False,
@@ -1123,7 +1123,7 @@ class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback
         ...
             "phase_callbacks":
                 [ExtremeBatchSegVisualizationCallback(
-                    metrice=IoU(20, ignore_idx=19),
+                    metric=IoU(20, ignore_idx=19),
                     max=False,
                     ignore_idx=19),
                 ExtremeBatchSegVisualizationCallback(
@@ -1167,7 +1167,7 @@ class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback
 
     def __init__(
         self,
-        metric: Optional[Union[Metric, Mapping, str]] = None,
+        metric: Optional[Metric] = None,
         metric_component_name: Optional[str] = None,
         loss_name: Optional[str] = None,
         max: bool = False,

From b40599b5cefbf05fbcceb69ac9bceb7a90f9d0c7 Mon Sep 17 00:00:00 2001
From: shayaharon
Date: Mon, 24 Jul 2023 11:17:49 +0300
Subject: [PATCH 8/9] loss_name changed to loss_to_monitor

---
 .../training/utils/callbacks/callbacks.py     | 28 +++++++++----------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/src/super_gradients/training/utils/callbacks/callbacks.py b/src/super_gradients/training/utils/callbacks/callbacks.py
index 5c12257dc4..6246b8bc49 100644
--- a/src/super_gradients/training/utils/callbacks/callbacks.py
+++ b/src/super_gradients/training/utils/callbacks/callbacks.py
@@ -971,7 +971,7 @@ class ExtremeBatchCaseVisualizationCallback(Callback, ABC):
     :param metric_component_name: In case metric returns multiple values (as Mapping),
         the value at metric.compute()[metric_component_name] will be the one monitored.
 
-    :param loss_name: str, loss_name corresponding to the 'criterion' passed through training_params in Trainer.train(...).
+    :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
         Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be:
 
         if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
@@ -996,16 +996,16 @@ def __init__(
         self,
         metric: Optional[Metric] = None,
         metric_component_name: Optional[str] = None,
-        loss_name: Optional[str] = None,
+        loss_to_monitor: Optional[str] = None,
         max: bool = False,
         freq: int = 1,
     ):
         super(ExtremeBatchCaseVisualizationCallback, self).__init__()
 
-        if (metric and loss_name) or (metric is None and loss_name is None):
+        if (metric and loss_to_monitor) or (metric is None and loss_to_monitor is None):
             raise RuntimeError("Must pass exactly one of: loss, metric != None")
 
-        self._set_tag_attr(loss_name, max, metric, metric_component_name)
+        self._set_tag_attr(loss_to_monitor, max, metric, metric_component_name)
         self.metric = metric
         if self.metric:
             self.metric = MetricCollection(self.metric)
             self.metric.to(device_config.device)
@@ -1013,7 +1013,7 @@ def __init__(
 
         self.metric_component_name = metric_component_name
 
-        self.loss_name = loss_name
+        self.loss_to_monitor = loss_to_monitor
         self.max = max
         self.freq = freq
         self.extreme_score = -1 * np.inf if max else np.inf
@@ -1025,13 +1025,13 @@ def __init__(
         self._first_call = True
         self._idx_loss_tuple = None
 
-    def _set_tag_attr(self, loss_name, max, metric, metric_component_name):
+    def _set_tag_attr(self, loss_to_monitor, max, metric, metric_component_name):
         if metric_component_name:
             monitored_val_name = metric_component_name
         elif metric:
             monitored_val_name = metric.__class__.__name__
         else:
-            monitored_val_name = loss_name
+            monitored_val_name = loss_to_monitor
         self._tag = f"max_{monitored_val_name}_batch" if max else f"min_{monitored_val_name}_batch"
 
     @abstractmethod
@@ -1078,8 +1078,8 @@ def on_validation_batch_end(self, context: PhaseContext) -> None:
             self.extreme_targets = context.target
 
     def _init_loss_attributes(self, context: PhaseContext):
-        if self.loss_name not in context.loss_logging_items_names:
-            raise ValueError(f"{self.loss_name} not a loss or loss component.")
+        if self.loss_to_monitor not in context.loss_logging_items_names:
+            raise ValueError(f"{self.loss_to_monitor} not a loss or loss component.")
         self._idx_loss_tuple = context.loss_logging_items_names.index(self.metric_name)
         self._first_call = False
 
@@ -1127,7 +1127,7 @@ class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback
                     max=False,
                     ignore_idx=19),
                 ExtremeBatchSegVisualizationCallback(
-                    loss_name="LabelSmoothingCrossEntropyLoss",
+                    loss_to_monitor="LabelSmoothingCrossEntropyLoss",
                     max=True,
                     ignore_idx=19)]
             ...}
@@ -1137,7 +1137,7 @@ class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback
         training_hyperparams:
             phase_callbacks:
               - ExtremeBatchSegVisualizationCallback:
-                    loss_name: DiceCEEdgeLoss/aux_loss0
+                    loss_to_monitor: DiceCEEdgeLoss/aux_loss0
                     ignore_idx: 19
 
     :param metric: Metric, will be the metric which is monitored.
@@ -1145,7 +1145,7 @@ class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback
     :param metric_component_name: In case metric returns multiple values (as Mapping),
         the value at metric.compute()[metric_component_name] will be the one monitored.
 
-    :param loss_name: str, loss_name corresponding to the 'criterion' passed through training_params in Trainer.train(...).
+    :param loss_to_monitor: str, loss_to_monitor corresponding to the 'criterion' passed through training_params in Trainer.train(...).
         Monitoring loss follows the same logic as metric_to_watch in Trainer.train(..), when watching the loss and should be:
 
         if hasattr(criterion, "component_names") and criterion.forward(..) returns a tuple:
@@ -1169,13 +1169,13 @@ def __init__(
         self,
         metric: Optional[Metric] = None,
         metric_component_name: Optional[str] = None,
-        loss_name: Optional[str] = None,
+        loss_to_monitor: Optional[str] = None,
         max: bool = False,
         freq: int = 1,
         ignore_idx: int = -1,
     ):
         super(ExtremeBatchSegVisualizationCallback, self).__init__(
-            metric=metric, metric_component_name=metric_component_name, loss_name=loss_name, max=max, freq=freq
+            metric=metric, metric_component_name=metric_component_name, loss_to_monitor=loss_to_monitor, max=max, freq=freq
         )
         self.ignore_idx = ignore_idx

From d0730f1ad5fd8536a7796d4d71259f5ff3198096 Mon Sep 17 00:00:00 2001
From: shayaharon
Date: Thu, 27 Jul 2023 13:40:23 +0300
Subject: [PATCH 9/9] unit tests added

---
 .../training/utils/callbacks/callbacks.py |  7 +-
 tests/deci_core_unit_test_suite_runner.py |  2 +
 tests/unit_tests/extreme_batch_cb_test.py | 71 +++++++++++++++++++
 3 files changed, 79 insertions(+), 1 deletion(-)
 create mode 100644 tests/unit_tests/extreme_batch_cb_test.py

diff --git a/src/super_gradients/training/utils/callbacks/callbacks.py b/src/super_gradients/training/utils/callbacks/callbacks.py
index 6246b8bc49..ceff2b73e1 100644
--- a/src/super_gradients/training/utils/callbacks/callbacks.py
+++ b/src/super_gradients/training/utils/callbacks/callbacks.py
@@ -1057,7 +1057,12 @@ def on_validation_batch_end(self, context: PhaseContext) -> None:
                             f"of the monitored metric: {self.metric.__class__.__name__}"
                         )
                     score = score[self.metric_component_name]
+                elif len(score) > 1:
+                    raise RuntimeError(f"returned multiple values from {self.metric} but no metric_component_name has been passed to __init__.")
+                else:
+                    score = score.pop(list(score.keys())[0])
                 self.metric.reset()
+
             else:
 
                 # FOR LOSS VALUES, GET THE RIGHT COMPONENT, DERIVE IT ON THE FIRST PASS
@@ -1080,7 +1085,7 @@ def on_validation_batch_end(self, context: PhaseContext) -> None:
     def _init_loss_attributes(self, context: PhaseContext):
         if self.loss_to_monitor not in context.loss_logging_items_names:
             raise ValueError(f"{self.loss_to_monitor} not a loss or loss component.")
-        self._idx_loss_tuple = context.loss_logging_items_names.index(self.metric_name)
+        self._idx_loss_tuple = context.loss_logging_items_names.index(self.loss_to_monitor)
         self._first_call = False
 
diff --git a/tests/deci_core_unit_test_suite_runner.py b/tests/deci_core_unit_test_suite_runner.py
index c1d37602ca..68d5191d3a 100644
--- a/tests/deci_core_unit_test_suite_runner.py
+++ b/tests/deci_core_unit_test_suite_runner.py
@@ -27,6 +27,7 @@
 from tests.unit_tests.detection_utils_test import TestDetectionUtils
 from tests.unit_tests.detection_dataset_test import DetectionDatasetTest
 from tests.unit_tests.export_onnx_test import TestModelsONNXExport
+from tests.unit_tests.extreme_batch_cb_test import ExtremeBatchSanityTest
 from tests.unit_tests.load_checkpoint_test import LoadCheckpointTest
 from tests.unit_tests.local_ckpt_head_replacement_test import LocalCkptHeadReplacementTest
 from tests.unit_tests.max_batches_loop_break_test import MaxBatchesLoopBreakTest
@@ -145,6 +146,7 @@ def _add_modules_to_unit_tests_suite(self):
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestPostPredictionCallback))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestSegmentationMetricsMultipleIgnored))
         self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TrainWithTorchSchedulerTest))
+        self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(ExtremeBatchSanityTest))
 
     def _add_modules_to_end_to_end_tests_suite(self):
         """
diff --git a/tests/unit_tests/extreme_batch_cb_test.py b/tests/unit_tests/extreme_batch_cb_test.py
new file mode 100644
index 0000000000..902571eb57
--- /dev/null
+++ b/tests/unit_tests/extreme_batch_cb_test.py
@@ -0,0 +1,71 @@
+import unittest
+
+from super_gradients import Trainer
+from super_gradients.common.object_names import Models
+from super_gradients.training import models
+from super_gradients.training.dataloaders.dataloaders import segmentation_test_dataloader
+from super_gradients.training.losses.ddrnet_loss import DDRNetLoss
+from super_gradients.training.metrics import IoU
+from super_gradients.training.utils.callbacks.callbacks import ExtremeBatchSegVisualizationCallback
+
+
+# Helper method to set up Trainer and model with common parameters
+def setup_trainer_and_model(experiment_name: str):
+    trainer = Trainer(experiment_name)
+    model = models.get(Models.DDRNET_23, arch_params={"use_aux_heads": True}, pretrained_weights="cityscapes")
+    return trainer, model
+
+
+class DummyIOU(IoU):
+    """
+    Metric for testing the segmentation callback works with compound metrics
+    """
+
+    def compute(self):
+        diou = super(DummyIOU, self).compute()
+        return {"diou": diou, "diou_minus": -1 * diou}
+
+
+class ExtremeBatchSanityTest(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.training_params = {
+            "max_epochs": 3,
+            "initial_lr": 1e-2,
+            "loss": DDRNetLoss(),
+            "lr_mode": "poly",
+            "ema": True,
+            "average_best_models": True,
+            "optimizer": "SGD",
+            "mixed_precision": False,
+            "optimizer_params": {"weight_decay": 5e-4, "momentum": 0.9},
+            "load_opt_params": False,
+            "train_metrics_list": [IoU(5)],
+            "valid_metrics_list": [IoU(5)],
+            "metric_to_watch": "IoU",
+            "greater_metric_to_watch_is_better": True,
+        }
+
+    def test_segmentation_extreme_batch_with_metric_sanity(self):
+        trainer, model = setup_trainer_and_model("test_segmentation_extreme_batch_with_metric_sanity")
+        self.training_params["phase_callbacks"] = [ExtremeBatchSegVisualizationCallback(IoU(5))]
+        trainer.train(
+            model=model, training_params=self.training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader()
+        )
+
+    def test_segmentation_extreme_batch_with_compound_metric_sanity(self):
+        trainer, model = setup_trainer_and_model("test_segmentation_extreme_batch_with_compound_metric_sanity")
+        self.training_params["phase_callbacks"] = [ExtremeBatchSegVisualizationCallback(DummyIOU(5), metric_component_name="diou_minus")]
+        trainer.train(
+            model=model, training_params=self.training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader()
+        )
+
+    def test_segmentation_extreme_batch_with_loss_sanity(self):
+        trainer, model = setup_trainer_and_model("test_segmentation_extreme_batch_with_loss_sanity")
+        self.training_params["phase_callbacks"] = [ExtremeBatchSegVisualizationCallback(loss_to_monitor="DDRNetLoss/aux_loss1")]
+        trainer.train(
+            model=model, training_params=self.training_params, train_loader=segmentation_test_dataloader(), valid_loader=segmentation_test_dataloader()
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
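
Usage sketch (not part of the patch series): after patch 9, a task-specific visualization callback
only needs to subclass ExtremeBatchCaseVisualizationCallback and implement process_extreme_batch.
The subclass below is a hypothetical illustration; the class name and the normalization details are
assumptions that mirror ExtremeBatchSegVisualizationCallback, and the attachment via training_params
follows the unit tests above.

    import numpy as np
    import torch

    from super_gradients.training.utils.callbacks.callbacks import ExtremeBatchCaseVisualizationCallback


    class ExtremeBatchPlainImageCallback(ExtremeBatchCaseVisualizationCallback):
        # Hypothetical subclass: logs the raw images of the extreme batch, without overlays.
        def process_extreme_batch(self) -> np.ndarray:
            inputs = self.extreme_batch.detach().cpu()
            # Min-max normalize to [0, 255], as the segmentation subclass does, guarding
            # against an all-constant batch.
            inputs = inputs - inputs.min()
            inputs = inputs / max(float(inputs.max()), 1e-8)
            inputs = (inputs * 255).to(torch.uint8)
            # The base class expects an N,3,H,W uint8 np.ndarray.
            return inputs.numpy()

Attached to training the same way as in the tests above, monitoring exactly one of a metric or a
loss component:

    training_params["phase_callbacks"] = [ExtremeBatchPlainImageCallback(loss_to_monitor="DDRNetLoss/aux_loss1", max=True, freq=1)]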