# Add support for MVTec LOCO dataset and sPRO metric (#1686)
New documentation page for the MVTec LOCO data module:

````diff
@@ -0,0 +1,7 @@
+# MVTec LOCO Data
+
+```{eval-rst}
+.. automodule:: anomalib.data.image.mvtec_loco
+   :members:
+   :show-inheritance:
+```
````
Changes to the metrics callback (`callbacks/metrics.py`), importing the new `SPRO` metric:

```diff
@@ -13,7 +13,7 @@
 from lightning.pytorch.utilities.types import STEP_OUTPUT

 from anomalib import TaskType
-from anomalib.metrics import AnomalibMetricCollection, create_metric_collection
+from anomalib.metrics import SPRO, AnomalibMetricCollection, create_metric_collection
 from anomalib.models import AnomalyModule

 logger = logging.getLogger(__name__)
```
```diff
@@ -67,8 +67,7 @@ def setup(
             pl_module (AnomalyModule): Anomalib Model that inherits pl LightningModule.
             stage (str | None, optional): fit, validate, test or predict. Defaults to None.
         """
-        del trainer, stage  # These variables are not used.
-
+        del stage  # this variable is not used.
         image_metric_names = [] if self.image_metric_names is None else self.image_metric_names
         if isinstance(image_metric_names, str):
            image_metric_names = [image_metric_names]
@@ -98,6 +97,8 @@ def setup(
         else:
             pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_")
         self._set_threshold(pl_module)
+        if hasattr(trainer.datamodule, "saturation_config"):
+            self._set_saturation_config(pl_module, trainer.datamodule.saturation_config)

     def on_validation_epoch_start(
         self,
```
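For orientation, MVTec LOCO ships a per-category `defects_config.json`, and the `saturation_config` attribute checked above plausibly maps each defect's mask pixel value to its saturation settings (the `dict[int, Any]` annotation in the next hunk points that way). A sketch with illustrative values:

```python
# Hedged sketch of a saturation_config as the callback might receive it.
# Keys are defect pixel values in the ground-truth masks; field names follow
# MVTec LOCO's defects_config.json, and all values here are illustrative.
saturation_config = {
    255: {
        "defect_name": "structural_anomaly",
        "saturation_threshold": 2000,  # absolute pixel count
        "relative_saturation": False,
    },
    254: {
        "defect_name": "logical_anomaly",
        "saturation_threshold": 0.5,   # fraction of the defect area
        "relative_saturation": True,
    },
}
```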
```diff
@@ -172,6 +173,9 @@ def _set_threshold(self, pl_module: AnomalyModule) -> None:
         pl_module.image_metrics.set_threshold(pl_module.image_threshold.value.item())
         pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item())

+    def _set_saturation_config(self, pl_module: AnomalyModule, saturation_config: dict[int, Any]) -> None:
+        pl_module.pixel_metrics.set_saturation_config(saturation_config)
+
     def _update_metrics(
         self,
         image_metric: AnomalibMetricCollection,
```
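The collection-level `set_saturation_config` used above is not shown in this diff; a hedged sketch of what it plausibly does, assuming it simply forwards the config to the `SPRO` instances in the collection:

```python
from typing import Any

from anomalib.metrics import SPRO, AnomalibMetricCollection


def set_saturation_config(collection: AnomalibMetricCollection, saturation_config: dict[int, Any]) -> None:
    """Sketch (not the PR's code): hand the saturation config to SPRO metrics only."""
    for metric in collection.values(copy_state=False):
        if isinstance(metric, SPRO):
            # Assumed attribute name; SPRO's real interface may differ.
            metric.saturation_config = saturation_config
```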
```diff
@@ -182,16 +186,33 @@ def _update_metrics(
         image_metric.update(output["pred_scores"], output["label"].int())
         if "mask" in output and "anomaly_maps" in output:
             pixel_metric.to(self.device)
-            pixel_metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int()))
+            if "masks" in output:
+                self._update_pixel_metrics(pixel_metric, output)
+            else:
+                pixel_metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int()))
```

> Not for this PR, but we need to be careful with these names, as we might forget why we have `mask` and `masks` as two separate keys. And it might not be clear to new users.
```diff
+    def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]:
+        if isinstance(output, dict):
+            for key, value in output.items():
+                output[key] = self._outputs_to_device(value)
+        elif isinstance(output, torch.Tensor):
+            output = output.to(self.device)
+        elif isinstance(output, list):
+            for i, value in enumerate(output):
+                output[i] = self._outputs_to_device(value)
+        return output
```
```diff
+    def _update_pixel_metrics(self, pixel_metric: AnomalibMetricCollection, output: STEP_OUTPUT) -> None:
+        """Handle metric updates when the SPRO metric is used alongside other pixel-level metrics."""
+        update = False
+        for metric in pixel_metric.values(copy_state=False):
+            if isinstance(metric, SPRO):
+                metric.update(torch.squeeze(output["anomaly_maps"]), output["masks"])
+            else:
+                metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int()))
+            update = True
+        pixel_metric.set_update_called(update)

     @staticmethod
     def _log_metrics(pl_module: AnomalyModule) -> None:
         """Log computed performance metrics."""
```

> Could this be handled inside the metric, so that we don't need metric-specific logic here?

> @djdameln this approach was actually to address this comment (#1686 (comment)); I provided a more detailed explanation in #1686 (comment). My initial approach could handle all metrics (including SPRO and others) without metric-specific logic in the `callbacks.metrics.py` file. The problem, however, is that it needs an additional argument in the metric collection's shared `update` call, so I changed the solution to avoid that.
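The earlier, kwargs-based approach mentioned above is not shown in this capture. A minimal sketch of what such a collection might look like (hypothetical, not the PR's code): `update` accepts extra keyword arguments such as `masks` and forwards each one only to the metrics whose `update` signature declares it.

```python
import inspect
from typing import Any

from torchmetrics import MetricCollection


class KwargsMetricCollection(MetricCollection):
    """Hypothetical collection: forward only the keyword arguments that each
    metric's ``update`` actually accepts (so SPRO could receive ``masks``)."""

    def update(self, *args: Any, **kwargs: Any) -> None:
        for metric in self.values(copy_state=False):
            accepted = inspect.signature(metric.update).parameters
            metric.update(*args, **{k: v for k, v in kwargs.items() if k in accepted})
```

With such a collection, the callback could call `pixel_metric.update(...)` once, without any `isinstance` checks; the trade-off, as the author notes, is the extra argument threaded through the shared `update` signature.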
Changes to the base datamodule, updating `collate_fn` for variable-length mask lists:

```diff
@@ -26,7 +26,9 @@
 def collate_fn(batch: list) -> dict[str, Any]:
     """Collate bounding boxes as lists.

-    Bounding boxes are collated as a list of tensors, while the default collate function is used for all other entries.
+    Bounding boxes and `masks` (not `mask`) are collated as a list of tensors. If `masks` exists,
+    `mask_path` is also collated as a list, since each element in the batch could be of unequal length.
+    For all other entries, the default collate function is used.

     Args:
         batch (List): list of items in the batch where len(batch) is equal to the batch size.
```
```diff
@@ -40,6 +42,10 @@ def collate_fn(batch: list) -> dict[str, Any]:
     if "boxes" in elem:
         # collate boxes as list
         out_dict["boxes"] = [item.pop("boxes") for item in batch]
+    if "masks" in elem:
+        # collate masks and mask_path as list
+        out_dict["masks"] = [item.pop("masks") for item in batch]
+        out_dict["mask_path"] = [item.pop("mask_path") for item in batch]
     # collate other data normally
     out_dict.update({key: default_collate([item[key] for item in batch]) for key in elem})
     return out_dict
```

> Why is this needed? The default collate function should already collate the mask paths as a list.

> @djdameln In the MVTec LOCO dataset, a single image can be paired with multiple mask paths stored as a list, where the number of mask paths per image can vary. With the default collate function, this raises an error because `torch.stack` cannot stack items of different lengths; this code avoids that error (see `anomalib/src/anomalib/data/image/mvtec_loco.py`, lines 81 to 85 in d9a2333).
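A minimal, self-contained illustration of that failure mode (tensor shapes here are made up for the example):

```python
import torch
from torch.utils.data import default_collate

# Two samples with different numbers of ground-truth masks, as in MVTec LOCO.
batch = [
    {"image": torch.zeros(3, 8, 8), "masks": torch.zeros(2, 8, 8)},
    {"image": torch.zeros(3, 8, 8), "masks": torch.zeros(3, 8, 8)},
]

# default_collate(batch) would raise here: it tries torch.stack on the
# "masks" entries, whose first dimensions (2 vs. 3) don't match.
masks = [item.pop("masks") for item in batch]  # collate as a plain list instead
out = default_collate(batch)                   # stacks the remaining keys
out["masks"] = masks
```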
```diff
@@ -164,6 +170,9 @@ def _create_val_split(self) -> None:
             # converted from random training sample
             self.train_data, normal_val_data = random_split(self.train_data, self.val_split_ratio, seed=self.seed)
             self.val_data = SyntheticAnomalyDataset.from_dataset(normal_val_data)
+        elif self.val_split_mode == ValSplitMode.FROM_DIR:
+            # the val_data is prepared in the subclass
+            pass
         elif self.val_split_mode != ValSplitMode.NONE:
             msg = f"Unknown validation split mode: {self.val_split_mode}"
             raise ValueError(msg)
```

> This could lead to problems with thresholding and normalization, as the validation set of MVTec LOCO only contains normal images by default. For accurate computation of the adaptive threshold and normalization statistics, we would have to add some anomalous images here. I'm not sure what the best approach would be, though, since we do want to stick to the predefined splits for reproducibility.

> @djdameln Yeah, actually, this is why the `FROM_DIR` option was implemented in the first place: to use the predefined validation split as it is. However, if a user doesn't want to use this option, for some reason or perhaps for the reason you mentioned, they can freely change it to `same_as_test` or `synthetic`.
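For illustration, choosing between the predefined split and the existing alternatives might look as follows; the `MVTecLoco` class name, the import paths, and the argument names are assumptions based on this PR's module path, not confirmed by the excerpt above.

```python
# Hypothetical usage; names are assumptions based on anomalib.data.image.mvtec_loco.
from anomalib.data.image.mvtec_loco import MVTecLoco
from anomalib.data.utils import ValSplitMode

# Predefined validation directory (reproducible, but normal images only):
datamodule = MVTecLoco(category="breakfast_box", val_split_mode=ValSplitMode.FROM_DIR)

# Synthetic split, if anomalous validation samples are needed for the
# adaptive threshold and normalization statistics:
datamodule = MVTecLoco(category="breakfast_box", val_split_mode=ValSplitMode.SYNTHETIC)
```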
> Would it be possible to somehow handle this within the SPRO metric itself? This is metric-specific; it would be great if we could sort it out in the metric.

> @samet-akcay do you mean loading the config directly in the metric? I think it would be possible as long as we introduce a new parameter in the `yaml`, so that we can load the config in the metric. I will do it that way if it is okay.

> Oh, but this approach means that changing the `category` in the data config (`.yaml`) is not enough, because we also need to change the model (metric) config, i.e., the `saturation_config` path.

> @ashwinvaidya17, can you confirm whether we could use `class_path` and `init_args` for the metrics above? The implementation would be quite straightforward in case it is possible.

> @samet-akcay @ashwinvaidya17 In the latest implementation, I have handled the loading of the saturation config in the metric itself; you can see what I modified in commit d9a2333. I found there is an existing function, `metric_collection_from_dicts`, which allows us to use `class_path` and `init_args` by passing a `dict[str, dict[str, Any]]` to the function (https://github.com/openvinotoolkit/anomalib/blob/main/src/anomalib/metrics/__init__.py#L184-L186). However, the current CLI doesn't support the dict type, so I modified that part (`anomalib/src/anomalib/cli/cli.py`, lines 146 to 151 in d9a2333) by adding the dict type to make it work. I have tested this configuration in the YAML file, and the saturation_config path can be loaded on the metric side:
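The tested YAML itself is cut off in this capture. A hypothetical reconstruction of what such a metrics section could look like, given that `metric_collection_from_dicts` consumes a `dict[str, dict[str, Any]]` of `class_path`/`init_args` entries (the `saturation_config` path is a placeholder):

```yaml
metrics:
  pixel:
    SPRO:
      class_path: anomalib.metrics.SPRO
      init_args:
        saturation_config: ./datasets/MVTec_LOCO/breakfast_box/defects_config.json
```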