From 99def2b3ee2c208859c62997a987633c12cf0d25 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Fri, 3 Jun 2022 11:58:15 +0200 Subject: [PATCH] Add metrics configuration callback to benchmarking --- anomalib/utils/sweep/helpers/callbacks.py | 23 +++++++++++++++++++++-- tools/benchmarking/benchmark.py | 2 +- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/anomalib/utils/sweep/helpers/callbacks.py b/anomalib/utils/sweep/helpers/callbacks.py index e09267c91e..576d1dd17a 100644 --- a/anomalib/utils/sweep/helpers/callbacks.py +++ b/anomalib/utils/sweep/helpers/callbacks.py @@ -15,14 +15,16 @@ # and limitations under the License. -from typing import List +from typing import List, Union +from omegaconf import DictConfig, ListConfig from pytorch_lightning import Callback +from anomalib.utils.callbacks import MetricsConfigurationCallback from anomalib.utils.callbacks.timer import TimerCallback -def get_sweep_callbacks() -> List[Callback]: +def get_sweep_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]: """Gets callbacks relevant to sweep. Args: @@ -32,5 +34,22 @@ def get_sweep_callbacks() -> List[Callback]: List[Callback]: List of callbacks """ callbacks: List[Callback] = [TimerCallback()] + # Add metric configuration to the model via MetricsConfigurationCallback + image_metric_names = config.metrics.image if "image" in config.metrics.keys() else None + pixel_metric_names = config.metrics.pixel if "pixel" in config.metrics.keys() else None + image_threshold = ( + config.metrics.threshold.image_default if "image_default" in config.metrics.threshold.keys() else None + ) + pixel_threshold = ( + config.metrics.threshold.pixel_default if "pixel_default" in config.metrics.threshold.keys() else None + ) + metrics_callback = MetricsConfigurationCallback( + config.metrics.threshold.adaptive, + image_threshold, + pixel_threshold, + image_metric_names, + pixel_metric_names, + ) + callbacks.append(metrics_callback) return callbacks diff --git a/tools/benchmarking/benchmark.py b/tools/benchmarking/benchmark.py index 753b2e8af2..e5f87a413f 100644 --- a/tools/benchmarking/benchmark.py +++ b/tools/benchmarking/benchmark.py @@ -100,7 +100,7 @@ def get_single_model_metrics(model_config: Union[DictConfig, ListConfig], openvi datamodule = get_datamodule(model_config) model = get_model(model_config) - callbacks = get_sweep_callbacks() + callbacks = get_sweep_callbacks(model_config) trainer = Trainer(**model_config.trainer, logger=None, callbacks=callbacks)