Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

πŸš€ Add support for MVTec LOCO dataset and sPRO metric #1686

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
c5c6bc7
add FROM_DIR option to `val split mode` to support a provided val dir…
willyfh Jan 13, 2024
4c76092
add a conditional check for the FROM_DIR option of val split mode
willyfh Jan 13, 2024
ecc7169
add the mvtec loco ad dataset classes
willyfh Jan 13, 2024
9d89a68
add the default config file for mvtec loco ad dataset
willyfh Jan 13, 2024
a312be0
update initialization files to include MVTec LOCO dataset
willyfh Jan 13, 2024
37a6f6b
remove unnecessary Path conversion
willyfh Jan 13, 2024
a733bf7
add mvtec_loco.yaml to the readme documentation of configs
willyfh Jan 13, 2024
38ad1f4
add dummy image generation for mvtec loco dataset
willyfh Jan 13, 2024
05e3558
add unit test for mvtec loco dataset
willyfh Jan 13, 2024
519780a
update changelog to include the addition of mvtec loco dataset
willyfh Jan 13, 2024
c3c6a39
add mvtec loco dataset to the sphinx-based docs
willyfh Jan 14, 2024
ca91b8e
fix the malformed table
willyfh Jan 14, 2024
37f4dbc
binarize the masks and avoid the possibility of the merge_mask is None
willyfh Jan 14, 2024
3559a7c
Merge the masks using sum operation without binarization
willyfh Jan 21, 2024
83539b4
override getitem method to handle binarization and to add additional …
willyfh Jan 21, 2024
ddd33c5
Add saturation config to the datamodule
willyfh Jan 21, 2024
3af6eeb
Update the saturation config on the metrics based on the loaded confi…
willyfh Jan 21, 2024
53f2297
add masks as a keyword args to the update method of the AnomalibMetri…
willyfh Jan 21, 2024
3d00d69
Shorten the comments to solve ruff issues
willyfh Jan 21, 2024
f024e91
Add sPro metric implementation
willyfh Jan 21, 2024
9b8ca3b
Change the saturation threshold to tensor
willyfh Jan 21, 2024
29aaf46
Handle case with only background/normal images in scoring
willyfh Jan 21, 2024
f432753
rename spro metric and change the default value of saturation_config …
willyfh Jan 31, 2024
7d348d8
add unit test for spro metric
willyfh Jan 31, 2024
5597bfc
fix pre-commit issues
willyfh Jan 31, 2024
7b02863
handle file not found error when loading saturation config
willyfh Jan 31, 2024
6048a08
validate path before processing
willyfh Feb 1, 2024
e237347
update changelog with new PR
willyfh Feb 1, 2024
f9b67b8
Update src/anomalib/data/image/mvtec_loco.py
willyfh Feb 6, 2024
63bf8de
Update src/anomalib/data/image/mvtec_loco.py
willyfh Feb 6, 2024
0310e1b
Update tests/helpers/data.py
willyfh Feb 6, 2024
4eb2ec3
change assert to raise error
willyfh Feb 10, 2024
f3bccb8
return list of masks instead of merging the multiple masks from the d…
willyfh Feb 10, 2024
ac1ecb1
collate masks as a list of tensor to avoid stack error due to unequal…
willyfh Feb 10, 2024
3bbc750
update spro to handle list of masks and remove the _ args
willyfh Feb 10, 2024
e8e7360
update unit test to use list of masks as the target
willyfh Feb 10, 2024
ccef95c
update type and docstring of spro_score function
willyfh Feb 10, 2024
2cdeb82
remove _saturation_config attribute from metric collection module
willyfh Feb 10, 2024
d2cbcf3
remove unnecessary lines
willyfh Feb 11, 2024
829289b
add unit test to make sure the `mask` is binary
willyfh Feb 11, 2024
53aa07f
add warning when the saturation threshold is larger than the defect area
willyfh Feb 11, 2024
d9a2333
Move the loading process of saturation config from dataset to metric
willyfh Feb 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Added

- Add support for MVTec LOCO AD dataset and sPRO metric by @willyfh in https://github.com/openvinotoolkit/anomalib/pull/1686

### Changed

- Changed default inference device to AUTO in https://github.com/openvinotoolkit/anomalib/pull/1534
Expand Down
8 changes: 8 additions & 0 deletions docs/source/markdown/guides/reference/data/image/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ Learn more about Kolektor dataset.
Learn more about MVTec 2D dataset
:::

:::{grid-item-card} MVTec LOCO
:link: ./mvtec_loco
:link-type: doc

Learn more about MVTec LOCO dataset
:::

:::{grid-item-card} Visa
:link: ./visa
:link-type: doc
Expand All @@ -47,5 +54,6 @@ Learn more about Visa dataset.
./folder
./kolektor
./mvtec
./mvtec_loco
./visa
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# MVTec LOCO Data

```{eval-rst}
.. automodule:: anomalib.data.image.mvtec_loco
:members:
:show-inheritance:
```
28 changes: 24 additions & 4 deletions src/anomalib/callbacks/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from lightning.pytorch.utilities.types import STEP_OUTPUT

from anomalib import TaskType
from anomalib.metrics import AnomalibMetricCollection, create_metric_collection
from anomalib.metrics import SPRO, AnomalibMetricCollection, create_metric_collection
from anomalib.models import AnomalyModule

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,8 +67,7 @@ def setup(
pl_module (AnomalyModule): Anomalib Model that inherits pl LightningModule.
stage (str | None, optional): fit, validate, test or predict. Defaults to None.
"""
del trainer, stage # These variables are not used.

del stage, trainer # this variable is not used.
image_metric_names = [] if self.image_metric_names is None else self.image_metric_names
if isinstance(image_metric_names, str):
image_metric_names = [image_metric_names]
Expand Down Expand Up @@ -182,16 +181,37 @@ def _update_metrics(
image_metric.update(output["pred_scores"], output["label"].int())
if "mask" in output and "anomaly_maps" in output:
pixel_metric.to(self.device)
pixel_metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int()))
if "masks" in output:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for this PR but we need to be careful with these names as we might forget why we have mask and masks as two separate keys. And it might not be clear to new users

self._update_pixel_metrics(pixel_metric, output)
else:
pixel_metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int()))

def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]:
if isinstance(output, dict):
for key, value in output.items():
output[key] = self._outputs_to_device(value)
elif isinstance(output, torch.Tensor):
output = output.to(self.device)
elif isinstance(output, list):
for i, value in enumerate(output):
output[i] = self._outputs_to_device(value)
return output

def _update_pixel_metrics(self, pixel_metric: AnomalibMetricCollection, output: STEP_OUTPUT) -> None:
"""Handle metric updates when the SPRO metric is used alongside other pixel-level metrics."""
update = False
for name, metric in pixel_metric.items(copy_state=False):
if isinstance(metric, SPRO):
metric.update(torch.squeeze(output["anomaly_maps"]), output["masks"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be handled inside the metric so that we don't need metric-specific logic here?

Copy link
Contributor Author

@willyfh willyfh Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@djdameln this approach was actually to address this comment #1686 (comment) . I provided a more detailed explanation in this comment #1686 (comment)

My initial approach was as follows:

pixel_metric.update( 
         torch.squeeze(output["anomaly_maps"]), 
         torch.squeeze(output["mask"].int()), 
         masks=torch.squeeze(output["masks"]) if "masks" in output else None, 
     ) 

My initial approach can handle all metrics (including SPRO and others) without metric-specific logic in the callbacks.metrics.py file. However, the problem is that we need an additional argument, _:

def update(self, predictions: torch.Tensor, _: torch.Tensor, masks: torch.Tensor) -> None:

So, I changed the solution to avoid using _ by incorporating metric-specific logic. Is there a better idea to avoid metric-specific logic while also avoiding the use of _ at the same time?

else:
logger.warning(
f"Metric {name} may not be suitable for a dataset with the region separated "
"in multiple ground-truth masks.",
)
metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int()))
update = True
pixel_metric.set_update_called(update)

@staticmethod
def _log_metrics(pl_module: AnomalyModule) -> None:
"""Log computed performance metrics."""
Expand Down
7 changes: 6 additions & 1 deletion src/anomalib/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,12 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
parser.add_argument("--visualization.show", type=bool, default=False)
parser.add_argument("--task", type=TaskType, default=TaskType.SEGMENTATION)
parser.add_argument("--metrics.image", type=list[str] | str | None, default=["F1Score", "AUROC"])
parser.add_argument("--metrics.pixel", type=list[str] | str | None, default=None, required=False)
parser.add_argument(
"--metrics.pixel",
type=list[str] | str | dict[str, dict[str, Any]] | None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@djdameln This change is needed to address this comment #1686 (comment) .

I provided a more additional explanation here #1686 (comment)

Basically, we want to use class_path and init_args, so that we can pass saturation_config to the SPRO metric directly. via this config:

metrics:
  pixel:
    SPRO:
      class_path: anomalib.metrics.SPRO
      init_args:
        saturation_config: ./datasets/MVTec_LOCO/pushpins/defects_config.json

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make the same typing change to metrics.image?

default=None,
required=False,
)
parser.add_argument("--metrics.threshold", type=BaseThreshold, default="F1AdaptiveThreshold")
parser.add_argument("--logging.log_graph", type=bool, help="Log the model to the logger", default=False)
if hasattr(parser, "subcommand") and parser.subcommand != "predict": # Predict also accepts str and Path inputs
Expand Down
3 changes: 2 additions & 1 deletion src/anomalib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .base import AnomalibDataModule, AnomalibDataset
from .depth import DepthDataFormat, Folder3D, MVTec3D
from .image import BTech, Folder, ImageDataFormat, Kolektor, MVTec, Visa
from .image import BTech, Folder, ImageDataFormat, Kolektor, MVTec, MVTecLoco, Visa
from .predict import PredictDataset
from .video import Avenue, ShanghaiTech, UCSDped, VideoDataFormat

Expand Down Expand Up @@ -62,6 +62,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
"Kolektor",
"MVTec",
"MVTec3D",
"MVTecLoco",
"Avenue",
"UCSDped",
"ShanghaiTech",
Expand Down
11 changes: 10 additions & 1 deletion src/anomalib/data/base/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
def collate_fn(batch: list) -> dict[str, Any]:
"""Collate bounding boxes as lists.

Bounding boxes are collated as a list of tensors, while the default collate function is used for all other entries.
Bounding boxes and `masks` (not `mask`) are collated as a list of tensors. If `masks` is exist,
the `mask_path` is also collated as a list since each element in the batch could be unequal.
For all other entries, the default collate function is used.

Args:
batch (List): list of items in the batch where len(batch) is equal to the batch size.
Expand All @@ -40,6 +42,10 @@ def collate_fn(batch: list) -> dict[str, Any]:
if "boxes" in elem:
# collate boxes as list
out_dict["boxes"] = [item.pop("boxes") for item in batch]
if "masks" in elem:
# collate masks and mask_path as list
out_dict["masks"] = [item.pop("masks") for item in batch]
out_dict["mask_path"] = [item.pop("mask_path") for item in batch]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? The default collate function should already collate the mask paths as a list.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@djdameln In the MVTec LOCO dataset, a single image could be paired with multiple mask paths stored as a list, where the number of mask paths for each image can vary. When using the default collate function, it will raise an error because the Torch stack function cannot stack items with different lengths. The code is needed to avoid such error.

+---+---------------+-------+---------+-------------------------+-----------------------------+-------------+
| | path | split | label | image_path | mask_path | label_index |
+===+===============+=======+=========+===============+=======================================+=============+
| 0 | datasets/name | test | defect | path/to/image/file.png | [path/to/masks/file.png] | 1 |
+---+---------------+-------+---------+-------------------------+-----------------------------+-------------+

# collate other data normally
out_dict.update({key: default_collate([item[key] for item in batch]) for key in elem})
return out_dict
Expand Down Expand Up @@ -164,6 +170,9 @@ def _create_val_split(self) -> None:
# converted from random training sample
self.train_data, normal_val_data = random_split(self.train_data, self.val_split_ratio, seed=self.seed)
self.val_data = SyntheticAnomalyDataset.from_dataset(normal_val_data)
elif self.val_split_mode == ValSplitMode.FROM_DIR:
# the val_data is prepared in subclass
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could lead to problems with thresholding and normalization, as the validation set of MVTec Loco only contains normal images by default. For accurate computation of the adaptive threshold and normalization statistics we would have to add some anomalous images here. Not sure what would be the best approach here though, since we do want to stick to the predefined splits for reproducibility.

Copy link
Contributor Author

@willyfh willyfh Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we do want to stick to the predefined splits for reproducibility.

@djdameln Yeah, actually, this is why the FROM_DIR option was implemented in the first place, to use the predefined validation split as it is. However, if a user doesn't want to use this option for some reason, or perhaps for the reason you mentioned, can they freely change it to same_as_test or synthetic?

elif self.val_split_mode != ValSplitMode.NONE:
msg = f"Unknown validation split mode: {self.val_split_mode}"
raise ValueError(msg)
Expand Down
4 changes: 3 additions & 1 deletion src/anomalib/data/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .folder import Folder
from .kolektor import Kolektor
from .mvtec import MVTec
from .mvtec_loco import MVTecLoco
from .visa import Visa


Expand All @@ -21,11 +22,12 @@ class ImageDataFormat(str, Enum):

MVTEC = "mvtec"
MVTEC_3D = "mvtec_3d"
MVTEC_LOCO = "mvtec_loco"
BTECH = "btech"
KOLEKTOR = "kolektor"
FOLDER = "folder"
FOLDER_3D = "folder_3d"
VISA = "visa"


__all__ = ["BTech", "Folder", "Kolektor", "MVTec", "Visa"]
__all__ = ["BTech", "Folder", "Kolektor", "MVTec", "MVTecLoco", "Visa"]
Loading
Loading