diff --git a/src/anomalib/callbacks/metrics.py b/src/anomalib/callbacks/metrics.py index 9b8b0e3bda..9e2bcdb7ec 100644 --- a/src/anomalib/callbacks/metrics.py +++ b/src/anomalib/callbacks/metrics.py @@ -67,7 +67,7 @@ def setup( pl_module (AnomalyModule): Anomalib Model that inherits pl LightningModule. stage (str | None, optional): fit, validate, test or predict. Defaults to None. """ - del stage # this variable is not used. + del stage, trainer # this variable is not used. image_metric_names = [] if self.image_metric_names is None else self.image_metric_names if isinstance(image_metric_names, str): image_metric_names = [image_metric_names] @@ -97,8 +97,6 @@ def setup( else: pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_") self._set_threshold(pl_module) - if hasattr(trainer.datamodule, "saturation_config"): - self._set_saturation_config(pl_module, trainer.datamodule.saturation_config) def on_validation_epoch_start( self, @@ -173,9 +171,6 @@ def _set_threshold(self, pl_module: AnomalyModule) -> None: pl_module.image_metrics.set_threshold(pl_module.image_threshold.value.item()) pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item()) - def _set_saturation_config(self, pl_module: AnomalyModule, saturation_config: dict[int, Any]) -> None: - pl_module.pixel_metrics.set_saturation_config(saturation_config) - def _update_metrics( self, image_metric: AnomalibMetricCollection, @@ -205,10 +200,14 @@ def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any def _update_pixel_metrics(self, pixel_metric: AnomalibMetricCollection, output: STEP_OUTPUT) -> None: """Handle metric updates when the SPRO metric is used alongside other pixel-level metrics.""" update = False - for metric in pixel_metric.values(copy_state=False): + for name, metric in pixel_metric.items(copy_state=False): if isinstance(metric, SPRO): metric.update(torch.squeeze(output["anomaly_maps"]), output["masks"]) else: + logger.warning( + f"Metric {name} may not be suitable for a dataset with the region separated " + "in multiple ground-truth masks.", + ) metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int())) update = True pixel_metric.set_update_called(update) diff --git a/src/anomalib/cli/cli.py b/src/anomalib/cli/cli.py index d56e6d578b..736acbd830 100644 --- a/src/anomalib/cli/cli.py +++ b/src/anomalib/cli/cli.py @@ -143,7 +143,12 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: parser.add_argument("--visualization.show", type=bool, default=False) parser.add_argument("--task", type=TaskType, default=TaskType.SEGMENTATION) parser.add_argument("--metrics.image", type=list[str] | str | None, default=["F1Score", "AUROC"]) - parser.add_argument("--metrics.pixel", type=list[str] | str | None, default=None, required=False) + parser.add_argument( + "--metrics.pixel", + type=list[str] | str | dict[str, dict[str, Any]] | None, + default=None, + required=False, + ) parser.add_argument("--metrics.threshold", type=BaseThreshold, default="F1AdaptiveThreshold") parser.add_argument("--logging.log_graph", type=bool, help="Log the model to the logger", default=False) if hasattr(parser, "subcommand") and parser.subcommand != "predict": # Predict also accepts str and Path inputs diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index 65dbbbb726..8864ada0cd 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -16,11 +16,9 @@ in: International Journal of Computer Vision (IJCV) 130, 947-969, 2022, DOI: 10.1007/s11263-022-01578-9 """ -import json import logging from collections.abc import Sequence from pathlib import Path -from typing import Any import albumentations as A # noqa: N812 import cv2 @@ -64,46 +62,6 @@ "splicing_connectors", ) -SATURATION_CONFIG_FILENAME = "defects_config.json" - - -def load_saturation_config(config_path: str | Path) -> dict[int, Any] | None: - """Load saturation configurations from a JSON file. - - Args: - config_path (str | Path): Path to the saturation configuration file. - - Returns: - Dict | None: A dictionary with pixel values as keys and the corresponding configurations as values. - Return None if the config file is not found. - - Example JSON format in the file: - [ - { - "defect_name": "1_additional_pushpin", - "pixel_value": 255, - "saturation_threshold": 6300, - "relative_saturation": false - }, - { - "defect_name": "2_additional_pushpins", - "pixel_value": 254, - "saturation_threshold": 12600, - "relative_saturation": false - }, - ... - ] - """ - try: - config_path = validate_path(config_path) - with Path.open(config_path) as file: - configs = json.load(file) - # Create a dictionary with pixel values as keys - return {conf["pixel_value"]: conf for conf in configs} - except FileNotFoundError: - logger.warning("The saturation config file %s does not exist. Returning None.", config_path) - return None - def make_mvtec_loco_dataset( root: str | Path, @@ -448,10 +406,8 @@ def __init__( val_split_ratio=val_split_ratio, seed=seed, ) - self.saturation_config: dict[int, Any] | None self.root = Path(root) self.category = Path(category) - self.saturation_config = {} transform_train = get_transforms( config=transform_config_train, @@ -549,6 +505,3 @@ def _setup(self, _stage: str | None = None) -> None: self._create_test_split() self._create_val_split() - - saturation_path = self.root / self.category / SATURATION_CONFIG_FILENAME - self.saturation_config = load_saturation_config(saturation_path) diff --git a/src/anomalib/metrics/collection.py b/src/anomalib/metrics/collection.py index 27399041dc..47c17a3a44 100644 --- a/src/anomalib/metrics/collection.py +++ b/src/anomalib/metrics/collection.py @@ -25,17 +25,6 @@ def set_threshold(self, threshold_value: float) -> None: if hasattr(metric, "threshold"): metric.threshold = threshold_value - def set_saturation_config(self, saturation_config: dict) -> None: - """Update the saturation config values for all metrics that have the saturation config attribute.""" - for name, metric in self.items(): - if hasattr(metric, "saturation_config"): - metric.saturation_config = saturation_config - else: - logger.warning( - f"Metric {name} may not be suitable for a dataset with the region separated " - "in multiple ground-truth masks.", - ) - def set_update_called(self, val: bool) -> None: """Set the flag indicating whether the update method has been called.""" self._update_called = val diff --git a/src/anomalib/metrics/spro.py b/src/anomalib/metrics/spro.py index 9128ded884..9269ce5188 100644 --- a/src/anomalib/metrics/spro.py +++ b/src/anomalib/metrics/spro.py @@ -3,11 +3,16 @@ # Copyright (C) 2022 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import json import logging +from pathlib import Path +from typing import Any import torch from torchmetrics import Metric +from anomalib.data.utils import validate_path + logger = logging.getLogger(__name__) @@ -20,7 +25,7 @@ class SPRO(Metric): Args: threshold (float): Threshold used to binarize the predictions. Defaults to ``0.5``. - saturation_config (dict): Saturations configuration for each label (pixel value) as the keys. + saturation_config (str | Path): Path to the saturation configuration file. Defaults: ``None`` (which the score is equivalent to PRO metric, but with the 'region' are separated by mask files. kwargs: Additional arguments to the TorchMetrics base class. @@ -50,10 +55,15 @@ class SPRO(Metric): """ - def __init__(self, threshold: float = 0.5, saturation_config: dict | None = None, **kwargs) -> None: + def __init__(self, threshold: float = 0.5, saturation_config: str | Path | None = None, **kwargs) -> None: super().__init__(**kwargs) self.threshold = threshold - self.saturation_config = saturation_config + self.saturation_config = load_saturation_config(saturation_config) if saturation_config is not None else None + if self.saturation_config is None: + logger.warning( + "The saturation_config attribute is empty, the threshold is set to the defect area." + "This is equivalent to PRO metric but with the 'region' are separated by mask files", + ) self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") @@ -158,13 +168,47 @@ def spro_score( saturation_threshold = defect_area else: # Handle case when saturation_config is empty - logger.warning( - "The saturation_config attribute is empty, the threshold is set to the defect area." - "This is equivalent to PRO metric but with the 'region' are separated by mask files", - ) saturation_threshold = defect_area # Update score with minimum of true_pos/saturation_threshold and 1.0 score += torch.minimum(true_pos / saturation_threshold, torch.tensor(1.0)) total += 1 return score, total + + +def load_saturation_config(config_path: str | Path) -> dict[int, Any] | None: + """Load saturation configurations from a JSON file. + + Args: + config_path (str | Path): Path to the saturation configuration file. + + Returns: + Dict | None: A dictionary with pixel values as keys and the corresponding configurations as values. + Return None if the config file is not found. + + Example JSON format in the config file of MVTec LOCO dataset: + [ + { + "defect_name": "1_additional_pushpin", + "pixel_value": 255, + "saturation_threshold": 6300, + "relative_saturation": false + }, + { + "defect_name": "2_additional_pushpins", + "pixel_value": 254, + "saturation_threshold": 12600, + "relative_saturation": false + }, + ... + ] + """ + try: + config_path = validate_path(config_path) + with Path.open(config_path) as file: + configs = json.load(file) + # Create a dictionary with pixel values as keys + return {conf["pixel_value"]: conf for conf in configs} + except FileNotFoundError: + logger.warning("The saturation config file %s does not exist. Returning None.", config_path) + return None diff --git a/tests/unit/metrics/test_spro.py b/tests/unit/metrics/test_spro.py index f6f5826419..37ee536433 100644 --- a/tests/unit/metrics/test_spro.py +++ b/tests/unit/metrics/test_spro.py @@ -3,6 +3,10 @@ # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import json +import pathlib +import tempfile + import torch from anomalib.metrics.spro import SPRO @@ -10,16 +14,22 @@ def test_spro() -> None: """Checks if SPRO metric computes the score utilizing the given saturation configs.""" - saturation_config = { - 255: { + saturation_config = [ + { + "pixel_value": 255, "saturation_threshold": 10, "relative_saturation": False, }, - 254: { + { + "pixel_value": 254, "saturation_threshold": 0.5, "relative_saturation": True, }, - } + ] + + with tempfile.NamedTemporaryFile(suffix=".json", mode="w", delete=False) as f: + json.dump(saturation_config, f) + saturation_config_json = f.name masks = [ torch.Tensor( @@ -60,7 +70,7 @@ def test_spro() -> None: targets_wo_saturation = [1.0, 0.625, 0.5, 0.375, 0.0, 0.0] for threshold, target, target_wo_saturation in zip(thresholds, targets, targets_wo_saturation, strict=True): # test using saturation_cofig - spro = SPRO(threshold=threshold, saturation_config=saturation_config) + spro = SPRO(threshold=threshold, saturation_config=saturation_config_json) spro.update(preds, masks) assert spro.compute() == target @@ -68,3 +78,6 @@ def test_spro() -> None: spro_wo_saturaton = SPRO(threshold=threshold) spro_wo_saturaton.update(preds, masks) assert spro_wo_saturaton.compute() == target_wo_saturation + + # Remove the temporary config file + pathlib.Path(saturation_config_json).unlink()