Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for segmentation extreme batch cases #1282

Merged
merged 21 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
5c45ee3
tested version
shaydeci Jul 16, 2023
14008cf
Merge branch 'master' into feature/SG-901_worst_samples_od_cb
shaydeci Jul 16, 2023
9bb4c9b
changed base to abc and abstractmethod
shaydeci Jul 17, 2023
392c52f
Merge remote-tracking branch 'origin/feature/SG-901_worst_samples_od_…
shaydeci Jul 17, 2023
f22ec75
Merge branch 'master' into feature/SG-901_worst_samples_od_cb
shaydeci Jul 17, 2023
bbd2f79
comments wip
shaydeci Jul 17, 2023
221cba8
Merge remote-tracking branch 'origin/feature/SG-901_worst_samples_od_…
shaydeci Jul 19, 2023
30d86aa
Merge remote-tracking branch 'origin/master' into feature/SG-901_wors…
shaydeci Jul 19, 2023
6344f83
refactoring, docs
shaydeci Jul 19, 2023
6db796d
Merge branch 'master' into feature/SG-901_worst_samples_od_cb
shaydeci Jul 20, 2023
2854f35
removed testing metric
shaydeci Jul 24, 2023
7b06812
Merge remote-tracking branch 'origin/feature/SG-901_worst_samples_od_…
shaydeci Jul 24, 2023
2b8674a
removed device arg from maybe all reduce
shaydeci Jul 24, 2023
0633b55
fixed metrice typo in docs
shaydeci Jul 24, 2023
b40599b
loss_name changed to loss_to_monitor
shaydeci Jul 24, 2023
3671565
Merge branch 'master' into feature/SG-901_worst_samples_od_cb
BloodAxe Jul 27, 2023
d0730f1
unit tests added
shaydeci Jul 27, 2023
39b36ac
Merge remote-tracking branch 'origin/feature/SG-901_worst_samples_od_…
shaydeci Jul 27, 2023
74827eb
Merge remote-tracking branch 'origin/master' into feature/SG-901_wors…
shaydeci Aug 1, 2023
432a184
Merge branch 'master' into feature/SG-901_worst_samples_od_cb
shaydeci Aug 1, 2023
2b17f02
Merge branch 'master' into feature/SG-901_worst_samples_od_cb
shaydeci Aug 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl
# COMPUTE THE LOSS FOR BACK PROP + EXTRA METRICS COMPUTED DURING THE LOSS FORWARD PASS
loss, loss_log_items = self._get_losses(outputs, targets)

context.update_context(preds=outputs, loss_log_items=loss_log_items)
context.update_context(preds=outputs, loss_log_items=loss_log_items, loss_logging_items_names=self.loss_logging_items_names)
self.phase_callback_handler.on_train_batch_loss_end(context)

if not self.ddp_silent_mode and batch_idx == 0:
Expand Down Expand Up @@ -1316,6 +1316,7 @@ def forward(self, inputs, targets):
metric_to_watch=self.metric_to_watch,
device=device_config.device,
ema_model=self.ema_model,
valid_metrics=self.valid_metrics,
)
self.phase_callback_handler.on_training_start(context)

Expand Down Expand Up @@ -1986,6 +1987,7 @@ def evaluate(

lr_warmup_epochs = self.training_params.lr_warmup_epochs if self.training_params else None
context = PhaseContext(
net=self.net,
epoch=epoch,
metrics_compute_fn=metrics,
loss_avg_meter=loss_avg_meter,
Expand All @@ -1995,6 +1997,7 @@ def evaluate(
sg_logger=self.sg_logger,
train_loader=self.train_loader,
valid_loader=self.valid_loader,
loss_logging_items_names=self.loss_logging_items_names,
)

with tqdm(data_loader, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode) as progress_bar_data_loader:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
metric_to_watch=None,
valid_metrics=None,
ema_model=None,
loss_logging_items_names=None,
):
self.epoch = epoch
self.batch_idx = batch_idx
Expand Down Expand Up @@ -82,6 +83,7 @@ def __init__(
self.metric_to_watch = metric_to_watch
self.valid_metrics = valid_metrics
self.ema_model = ema_model
self.loss_logging_items_names = loss_logging_items_names

def update_context(self, **kwargs):
for attr, attr_val in kwargs.items():
Expand Down
209 changes: 208 additions & 1 deletion src/super_gradients/training/utils/callbacks/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,25 @@
import onnxruntime
import torch
from deprecated import deprecated
from torch.distributed import gather_object, get_rank
from torch.utils.data import DataLoader
from torchmetrics import MetricCollection

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.environment.ddp_utils import multi_process_safe
from super_gradients.common.environment.ddp_utils import multi_process_safe, is_distributed
from super_gradients.common.environment.device_utils import device_config
from super_gradients.common.plugins.deci_client import DeciClient
from super_gradients.common.registry.registry import register_lr_scheduler, register_lr_warmup, register_callback, LR_SCHEDULERS_CLS_DICT, TORCH_LR_SCHEDULERS
from super_gradients.common.object_names import LRSchedulers, LRWarmups, Callbacks
from super_gradients.common.sg_loggers.time_units import GlobalBatchStepNumber, EpochNumber
from super_gradients.training.utils import get_param
from super_gradients.training.utils.callbacks.base_callbacks import PhaseCallback, PhaseContext, Phase, Callback
from super_gradients.training.utils.detection_utils import DetectionVisualization, DetectionPostPredictionCallback
from super_gradients.training.utils.distributed_training_utils import distributed_all_reduce_tensor_average, get_world_size
from super_gradients.training.utils.segmentation_utils import BinarySegmentationVisualization
from super_gradients.common.environment.checkpoints_dir_utils import get_project_checkpoints_dir_path
from super_gradients.training.utils.utils import unwrap_model
from torchvision.utils import draw_segmentation_masks

logger = get_logger(__name__)

Expand Down Expand Up @@ -948,3 +953,205 @@ def create_lr_scheduler_callback(
raise ValueError(f"Unknown lr_mode: {lr_mode}")

return sg_lr_callback


class ExtremeBatchCaseVisualizationCallback(Callback):
"""
ExtremeBatchCaseVisualizationCallback

A base class for visualizing worst/best validation batches in an epoch
according to some metric or loss value, with Full DDP support.

Images are saved with training_hyperparams.sg_logger.

:param metric_name: str,will be the metric which the model checkpoint will be saved according to, and can be set to any
of the following:

a metric name (str) of one of the metric objects from the training_hyperparams.valid_metrics_list

a "component_name" if some metric in valid_metrics_list has an attribute component_names. In such cas it
is a list referring to the names of each entry in the output metric (torch tensor of size n).

one of "loss_logging_items_names" i.e which will correspond to an item returned during the
loss function's forward pass (see loss docs in Trainer.train(..)).

:param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
the minimum (default=False).

:param freq: int, epoch frequency to perform all of the above (default=1).

Inheritors should implement process_extreme_batch which returns an image, as an np.array (uint8) with shape BCHW.
"""

def __init__(self, metric_name: str, max: bool = False, freq: int = 1):
self.metric_name = metric_name
self.metric = None
self.max = max
self.freq = freq
self.extreme_score = -1 * np.inf if max else np.inf

self.extreme_batch = None
self.extreme_preds = None
self.extreme_targets = None

self._first_call = True
self._idx_loss_tuple = None
self._tag = f"max_{self.metric_name}_batch" if self.max else f"min_{self.metric_name}_batch"

super(ExtremeBatchCaseVisualizationCallback, self).__init__()

def process_extreme_batch(self) -> np.array:
raise NotImplementedError

def on_training_start(self, context: PhaseContext) -> None:
"""
On train start we set the metric (if the metric_name does not corresponf to a loss).
:param context: Phase context
:return:
"""
if not hasattr(context.valid_metrics, self.metric_name):
for metric_name, metric in context.valid_metrics.items():
if hasattr(metric, "greater_component_is_better") and self.metric_name in metric.greater_component_is_better.keys():
# WRAP METRIC WITH METRIC COLLECTION TO FILTER ONLY THE NEEDED ARGUMENTS FOR THE METRIC UPDATE
self.metric = MetricCollection(copy.deepcopy(metric))
self.metric.to(device_config.device)
else:
self.metric = MetricCollection(copy.deepcopy(getattr(context.valid_metrics, self.metric_name)))
self.metric.to(device_config.device)

def on_validation_batch_end(self, context: PhaseContext) -> None:
if context.epoch % self.freq == 0:
# FOR METRIC OBJECTS, RESET THEM AND COMPUTE SCORE ONLY ON BATCH.
if self.metric is not None:
self.metric.reset()
self.metric.update(**context.__dict__)
score = self.metric.compute()[self.metric_name]
else:

# FOR LOSS VALUES, GET THE RIGHT COMPONENT, DERRIVE IT ON THE FIRST PASS
loss_tuple = context.loss_log_items
if self._first_call:
self._init_loss_attributes(context, loss_tuple)
score = loss_tuple[self._idx_loss_tuple]

# IN CONTRARY TO METRICS - LOSS VALUES NEED TO BE REDUCES IN DDP
if is_distributed():
device = next(context.net.parameters()).device
score = distributed_all_reduce_tensor_average(tensor=score.to(device), n=torch.distributed.get_world_size())

if self._is_more_extreme(score):
self.extreme_score = score
self.extreme_batch = context.inputs
self.extreme_preds = context.preds
self.extreme_targets = context.target

def _init_loss_attributes(self, context: PhaseContext, loss_tuple: tuple):
if self.metric_name not in context.loss_logging_items_names:
raise ValueError(f"{self.metric_name} not a validation metric, loss or loss component.")
self._idx_loss_tuple = context.loss_logging_items_names.index(self.metric_name)
self._first_call = False

def on_validation_loader_end(self, context: PhaseContext) -> None:
if context.epoch % self.freq == 0:
images_to_save = self.process_extreme_batch()
#
if is_distributed():
rank = get_rank()
output_container = [None for _ in range(get_world_size())]
gather_object(images_to_save, output_container if rank == 0 else None, dst=0)
if rank == 0:
images_to_save = np.concatenate(output_container, 0)
if not context.ddp_silent_mode:
context.sg_logger.add_images(tag=self._tag, images=images_to_save, global_step=context.epoch)

self._reset()

def _reset(self):
self.extreme_score = -1 * np.inf if self.max else np.inf
self.extreme_batch = None
self.extreme_preds = None
self.extreme_targets = None
if self.metric is not None:
self.metric.reset()

def _is_more_extreme(self, score: float) -> bool:
if self.max:
return self.extreme_score < score
else:
return self.extreme_score > score


@register_callback("ExtremeBatchSegVisualizationCallback")
class ExtremeBatchSegVisualizationCallback(ExtremeBatchCaseVisualizationCallback):
"""
ExtremeBatchSegVisualizationCallback

Visualizes worst/best batch in an epoch, for segmentation.
Assumes context.preds in validation is a score tensor of shape BCHW, or a tuple whose first item is one.

True predictions will be marked with green, false ones with red.

Example usage in training_params definition:

training_hyperparams ={
...
"phase_callbacks":
[ExtremeBatchSegVisualizationCallback(
metric_name=IoU'
max=False
ignore_idx=19),
ExtremeBatchSegVisualizationCallback(
metric_name="LabelSmoothingCrossEntropyLoss"
max=True
ignore_idx=19)]
...}


:param metric_name: str,will be the metric which the model checkpoint will be saved according to, and can be set to any
of the following:

a metric name (str) of one of the metric objects from the training_hyperparams.valid_metrics_list

a "component_name" if some metric in valid_metrics_list has an attribute component_names. In such cas it
is a list referring to the names of each entry in the output metric (torch tensor of size n).

one of "loss_logging_items_names" i.e which will correspond to an item returned during the
loss function's forward pass (see loss docs in Trainer.train(..)).

:param max: bool, Whether to take the batch corresponding to the max value of the metric/loss or
the minimum (default=False).

:param freq: int, epoch frequency to perform all of the above (default=1).


:param ignore_idx: int, any prediction of a coordinate in the output image, s.t the ground truth of it is this
value will not be colored in green or in red (default=-1).


"""

def __init__(self, metric_name: str, max: bool = False, freq: int = 1, ignore_idx: int = -1):
super(ExtremeBatchSegVisualizationCallback, self).__init__(metric_name=metric_name, max=max, freq=freq)
self.ignore_idx = ignore_idx

def process_extreme_batch(self) -> np.array:
inputs = self.extreme_batch
inputs -= inputs.min()
inputs /= inputs.max()
inputs *= 255
inputs = inputs.to(torch.uint8)
preds = self.extreme_preds
if isinstance(preds, tuple):
preds = preds[0]
preds = preds.argmax(1)
p_mask = preds == self.extreme_targets
n_mask = preds != self.extreme_targets
p_mask[self.extreme_targets == self.ignore_idx] = False
n_mask[self.extreme_targets == self.ignore_idx] = False
overlay = torch.cat([p_mask.unsqueeze(1), n_mask.unsqueeze(1)], 1)
colors = ["green", "red"]
images_to_save = []
for i in range(len(inputs)):
images_to_save.append(draw_segmentation_masks(inputs[i].cpu(), overlay[i], colors=colors, alpha=0.4).detach().numpy())
images_to_save = np.array(images_to_save)
return images_to_save