Move the loading process of saturation config from dataset to metric
Signed-off-by: Willy Fitra Hendria <[email protected]>
willyfh committed Feb 11, 2024
1 parent 53aa07f commit d9a2333
Showing 6 changed files with 81 additions and 78 deletions.
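In short, the saturation config is no longer parsed by the MVTec LOCO datamodule and pushed into the pixel metrics by the metrics callback; the SPRO metric now loads the JSON file itself when it is constructed. A minimal sketch of the resulting usage (the dataset path below is hypothetical; the file name defects_config.json comes from the constant removed from the datamodule):

from anomalib.metrics.spro import SPRO

# Before this commit: the datamodule loaded defects_config.json and the metrics
# callback injected the parsed dict into the metric collection.
# After this commit: pass the JSON path directly to SPRO (path below is illustrative).
spro = SPRO(
    threshold=0.5,
    saturation_config="./datasets/MVTec_LOCO/pushpins/defects_config.json",
)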
13 changes: 6 additions & 7 deletions src/anomalib/callbacks/metrics.py
@@ -67,7 +67,7 @@ def setup(
pl_module (AnomalyModule): Anomalib Model that inherits pl LightningModule.
stage (str | None, optional): fit, validate, test or predict. Defaults to None.
"""
del stage # this variable is not used.
del stage, trainer  # these variables are not used.
image_metric_names = [] if self.image_metric_names is None else self.image_metric_names
if isinstance(image_metric_names, str):
image_metric_names = [image_metric_names]
@@ -97,8 +97,6 @@ def setup(
else:
pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_")
self._set_threshold(pl_module)
if hasattr(trainer.datamodule, "saturation_config"):
self._set_saturation_config(pl_module, trainer.datamodule.saturation_config)

def on_validation_epoch_start(
self,
@@ -173,9 +171,6 @@ def _set_threshold(self, pl_module: AnomalyModule) -> None:
pl_module.image_metrics.set_threshold(pl_module.image_threshold.value.item())
pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item())

def _set_saturation_config(self, pl_module: AnomalyModule, saturation_config: dict[int, Any]) -> None:
pl_module.pixel_metrics.set_saturation_config(saturation_config)

def _update_metrics(
self,
image_metric: AnomalibMetricCollection,
@@ -205,10 +200,14 @@ def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any
def _update_pixel_metrics(self, pixel_metric: AnomalibMetricCollection, output: STEP_OUTPUT) -> None:
"""Handle metric updates when the SPRO metric is used alongside other pixel-level metrics."""
update = False
for metric in pixel_metric.values(copy_state=False):
for name, metric in pixel_metric.items(copy_state=False):
if isinstance(metric, SPRO):
metric.update(torch.squeeze(output["anomaly_maps"]), output["masks"])
else:
logger.warning(
f"Metric {name} may not be suitable for a dataset with the region separated "
"in multiple ground-truth masks.",
)
metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int()))
update = True
pixel_metric.set_update_called(update)
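For context on the _update_pixel_metrics change above: SPRO consumes the list of per-region ground-truth masks (output["masks"]), while every other pixel metric still receives the single merged binary mask (output["mask"]) and now triggers the warning. An illustrative sketch of the two update paths (tensor shapes are assumptions, not taken from this diff):

import torch

# Hypothetical batch output for a segmentation model on MVTec LOCO-style data.
output = {
    "anomaly_maps": torch.rand(4, 1, 256, 256),             # predicted anomaly maps
    "mask": torch.zeros(4, 256, 256, dtype=torch.uint8),    # merged binary ground truth
    "masks": [torch.zeros(3, 256, 256) for _ in range(4)],  # per-image stack of region masks
}

# SPRO receives the per-region masks untouched:
#   metric.update(torch.squeeze(output["anomaly_maps"]), output["masks"])
# Other pixel metrics receive the squeezed merged mask (and the new warning):
#   metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int()))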
7 changes: 6 additions & 1 deletion src/anomalib/cli/cli.py
@@ -143,7 +143,12 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
parser.add_argument("--visualization.show", type=bool, default=False)
parser.add_argument("--task", type=TaskType, default=TaskType.SEGMENTATION)
parser.add_argument("--metrics.image", type=list[str] | str | None, default=["F1Score", "AUROC"])
parser.add_argument("--metrics.pixel", type=list[str] | str | None, default=None, required=False)
parser.add_argument(
"--metrics.pixel",
type=list[str] | str | dict[str, dict[str, Any]] | None,
default=None,
required=False,
)
parser.add_argument("--metrics.threshold", type=BaseThreshold, default="F1AdaptiveThreshold")
parser.add_argument("--logging.log_graph", type=bool, help="Log the model to the logger", default=False)
if hasattr(parser, "subcommand") and parser.subcommand != "predict": # Predict also accepts str and Path inputs
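The --metrics.pixel argument now additionally accepts a dict, presumably so that per-metric init arguments such as SPRO's saturation_config path can be supplied from the config. The exact schema is resolved by the downstream metric factory, which is not part of this diff; one plausible form, assuming keys are metric names and values are init kwargs (both the schema and the path are assumptions):

# Hypothetical value for metrics.pixel in the new dict form.
pixel_metrics: dict[str, dict[str, object]] = {
    "AUROC": {},
    "SPRO": {"saturation_config": "./datasets/MVTec_LOCO/pushpins/defects_config.json"},
}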
47 changes: 0 additions & 47 deletions src/anomalib/data/image/mvtec_loco.py
@@ -16,11 +16,9 @@
in: International Journal of Computer Vision (IJCV) 130, 947-969, 2022, DOI: 10.1007/s11263-022-01578-9
"""

import json
import logging
from collections.abc import Sequence
from pathlib import Path
from typing import Any

import albumentations as A # noqa: N812
import cv2
@@ -64,46 +62,6 @@
"splicing_connectors",
)

SATURATION_CONFIG_FILENAME = "defects_config.json"


def load_saturation_config(config_path: str | Path) -> dict[int, Any] | None:
"""Load saturation configurations from a JSON file.
Args:
config_path (str | Path): Path to the saturation configuration file.
Returns:
Dict | None: A dictionary with pixel values as keys and the corresponding configurations as values.
Return None if the config file is not found.
Example JSON format in the file:
[
{
"defect_name": "1_additional_pushpin",
"pixel_value": 255,
"saturation_threshold": 6300,
"relative_saturation": false
},
{
"defect_name": "2_additional_pushpins",
"pixel_value": 254,
"saturation_threshold": 12600,
"relative_saturation": false
},
...
]
"""
try:
config_path = validate_path(config_path)
with Path.open(config_path) as file:
configs = json.load(file)
# Create a dictionary with pixel values as keys
return {conf["pixel_value"]: conf for conf in configs}
except FileNotFoundError:
logger.warning("The saturation config file %s does not exist. Returning None.", config_path)
return None


def make_mvtec_loco_dataset(
root: str | Path,
@@ -448,10 +406,8 @@ def __init__(
val_split_ratio=val_split_ratio,
seed=seed,
)
self.saturation_config: dict[int, Any] | None
self.root = Path(root)
self.category = Path(category)
self.saturation_config = {}

transform_train = get_transforms(
config=transform_config_train,
@@ -549,6 +505,3 @@ def _setup(self, _stage: str | None = None) -> None:

self._create_test_split()
self._create_val_split()

saturation_path = self.root / self.category / SATURATION_CONFIG_FILENAME
self.saturation_config = load_saturation_config(saturation_path)
11 changes: 0 additions & 11 deletions src/anomalib/metrics/collection.py
@@ -25,17 +25,6 @@ def set_threshold(self, threshold_value: float) -> None:
if hasattr(metric, "threshold"):
metric.threshold = threshold_value

def set_saturation_config(self, saturation_config: dict) -> None:
"""Update the saturation config values for all metrics that have the saturation config attribute."""
for name, metric in self.items():
if hasattr(metric, "saturation_config"):
metric.saturation_config = saturation_config
else:
logger.warning(
f"Metric {name} may not be suitable for a dataset with the region separated "
"in multiple ground-truth masks.",
)

def set_update_called(self, val: bool) -> None:
"""Set the flag indicating whether the update method has been called."""
self._update_called = val
58 changes: 51 additions & 7 deletions src/anomalib/metrics/spro.py
@@ -3,11 +3,16 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import logging
from pathlib import Path
from typing import Any

import torch
from torchmetrics import Metric

from anomalib.data.utils import validate_path

logger = logging.getLogger(__name__)


@@ -20,7 +25,7 @@ class SPRO(Metric):
Args:
threshold (float): Threshold used to binarize the predictions.
Defaults to ``0.5``.
saturation_config (dict): Saturations configuration for each label (pixel value) as the keys.
saturation_config (str | Path): Path to the saturation configuration file.
Defaults to ``None``, in which case the score is equivalent to the PRO metric,
but with the regions separated by mask files.
kwargs: Additional arguments to the TorchMetrics base class.
@@ -50,10 +55,15 @@ class SPRO(Metric):
"""

def __init__(self, threshold: float = 0.5, saturation_config: dict | None = None, **kwargs) -> None:
def __init__(self, threshold: float = 0.5, saturation_config: str | Path | None = None, **kwargs) -> None:
super().__init__(**kwargs)
self.threshold = threshold
self.saturation_config = saturation_config
self.saturation_config = load_saturation_config(saturation_config) if saturation_config is not None else None
if self.saturation_config is None:
logger.warning(
"The saturation_config attribute is empty, the threshold is set to the defect area."
"This is equivalent to PRO metric but with the 'region' are separated by mask files",
)
self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

@@ -158,13 +168,47 @@ def spro_score(
saturation_threshold = defect_area
else:
# Handle case when saturation_config is empty
logger.warning(
"The saturation_config attribute is empty, the threshold is set to the defect area."
"This is equivalent to PRO metric but with the 'region' are separated by mask files",
)
saturation_threshold = defect_area

# Update score with minimum of true_pos/saturation_threshold and 1.0
score += torch.minimum(true_pos / saturation_threshold, torch.tensor(1.0))
total += 1
return score, total


def load_saturation_config(config_path: str | Path) -> dict[int, Any] | None:
"""Load saturation configurations from a JSON file.
Args:
config_path (str | Path): Path to the saturation configuration file.
Returns:
Dict | None: A dictionary with pixel values as keys and the corresponding configurations as values.
Returns None if the config file is not found.
Example JSON format in the config file of MVTec LOCO dataset:
[
{
"defect_name": "1_additional_pushpin",
"pixel_value": 255,
"saturation_threshold": 6300,
"relative_saturation": false
},
{
"defect_name": "2_additional_pushpins",
"pixel_value": 254,
"saturation_threshold": 12600,
"relative_saturation": false
},
...
]
"""
try:
config_path = validate_path(config_path)
with Path.open(config_path) as file:
configs = json.load(file)
# Create a dictionary with pixel values as keys
return {conf["pixel_value"]: conf for conf in configs}
except FileNotFoundError:
logger.warning("The saturation config file %s does not exist. Returning None.", config_path)
return None
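To make the saturation semantics concrete, here is a worked sketch of the per-region score min(true_pos / saturation_threshold, 1.0) computed in spro_score. The numbers are made up, and the relative case assumes the threshold is the configured ratio times the defect area; that interpretation of relative_saturation is not shown in this hunk:

import torch

defect_area = torch.tensor(1000.0)  # ground-truth pixels in one region
true_pos = torch.tensor(400.0)      # correctly predicted pixels in that region

# Absolute threshold, e.g. {"saturation_threshold": 300, "relative_saturation": false}
score_abs = torch.minimum(true_pos / 300.0, torch.tensor(1.0))                # 1.0 (saturated)

# Relative threshold, e.g. {"saturation_threshold": 0.5, "relative_saturation": true}
score_rel = torch.minimum(true_pos / (0.5 * defect_area), torch.tensor(1.0))  # 0.8

# No config entry: threshold falls back to the defect area (PRO-like behaviour)
score_pro = torch.minimum(true_pos / defect_area, torch.tensor(1.0))          # 0.4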
23 changes: 18 additions & 5 deletions tests/unit/metrics/test_spro.py
@@ -3,23 +3,33 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import pathlib
import tempfile

import torch

from anomalib.metrics.spro import SPRO


def test_spro() -> None:
"""Checks if SPRO metric computes the score utilizing the given saturation configs."""
saturation_config = {
255: {
saturation_config = [
{
"pixel_value": 255,
"saturation_threshold": 10,
"relative_saturation": False,
},
254: {
{
"pixel_value": 254,
"saturation_threshold": 0.5,
"relative_saturation": True,
},
}
]

with tempfile.NamedTemporaryFile(suffix=".json", mode="w", delete=False) as f:
json.dump(saturation_config, f)
saturation_config_json = f.name

masks = [
torch.Tensor(
@@ -60,11 +70,14 @@ def test_spro() -> None:
targets_wo_saturation = [1.0, 0.625, 0.5, 0.375, 0.0, 0.0]
for threshold, target, target_wo_saturation in zip(thresholds, targets, targets_wo_saturation, strict=True):
# test using saturation_config
spro = SPRO(threshold=threshold, saturation_config=saturation_config)
spro = SPRO(threshold=threshold, saturation_config=saturation_config_json)
spro.update(preds, masks)
assert spro.compute() == target

# test without saturation_config
spro_wo_saturation = SPRO(threshold=threshold)
spro_wo_saturation.update(preds, masks)
assert spro_wo_saturation.compute() == target_wo_saturation

# Remove the temporary config file
pathlib.Path(saturation_config_json).unlink()
