Skip to content

Commit

Permalink
metrics[1]: auc boxplot [GSoC 2023 @ OpenVINO] (#1294)
Browse files Browse the repository at this point in the history
* add per-image overlap (pimo)

* modif plot pimo curves

* add warning about memory

* tiny bug

* add tuto ipynb

* make image classes a return

* fix ipynb

* add tests for binclf curve

* add test to binclf

* add aupimo tests

* ruff

* Configure readthedocs via `.readthedocs.yaml` file (#1229)

* Update binclf_curve.py

* 🚚 Refactor Benchmarking Script (#1216)

* New printing stuff

* Remove dead code + address codacy issues

* Refactor try/except + log to comet/wandb during runs

* pre-commit error

* third-party configuration

---------

Co-authored-by: Ashwin Vaidya <[email protected]>

* Update CODEOWNERS

* Enable training with only normal images for MVTec (#1241)

* ignore mask check when dataset has only normal samples

* update changelog

* Revert "🚚 Refactor Benchmarking Script" (#1239)

Revert "🚚 Refactor Benchmarking Script (#1216)"

This reverts commit 784767f.

* Update benchmarking notebook (#1242)

* Fix metadata path

* Update benchmarking notebook

* add per-image overlap (pimo)

* modif plot pimo curves

* add warning about memory

* tiny bug

* add tuto ipynb

* make image classes a return

* fix ipynb

* add tests for binclf curve

* add test to binclf

* add aupimo tests

* ruff

* Update binclf_curve.py

* refactor from future pr

* add auc boxplot

* Apply suggestions from code review

* update demo nb

* correct tests

* add test

* fix test

* add plots tests

* add tests to pimo

* fix plt warning

* fix docstring warning

* add tests to common

* add tests for plot module and small fixes

* --amend

* clear ouputs in notebook

* correct typo

* correct codacy stuff

* correct codacy stuff

* merge

* fix kernel spec in 502_perimg_metrics.ipynb

* fix types in boxplot

* Update src/anomalib/utils/metrics/perimg/pimo.py

Co-authored-by: Samet Akcay <[email protected]>

---------

Co-authored-by: Samet Akcay <[email protected]>
Co-authored-by: Ashwin Vaidya <[email protected]>
Co-authored-by: Ashwin Vaidya <[email protected]>
Co-authored-by: Dick Ameln <[email protected]>
  • Loading branch information
5 people authored Sep 12, 2023
1 parent 103c22c commit c7baf40
Show file tree
Hide file tree
Showing 7 changed files with 770 additions and 367 deletions.
544 changes: 198 additions & 346 deletions notebooks/500_use_cases/502_perimg_metrics.ipynb

Large diffs are not rendered by default.

123 changes: 123 additions & 0 deletions src/anomalib/utils/metrics/perimg/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import matplotlib as mpl
import numpy
import torch
from torch import Tensor

Expand Down Expand Up @@ -104,6 +106,36 @@ def _validate_image_classes(image_classes: Tensor):
)


def _validate_aucs(aucs: Tensor, nan_allowed: bool = False):
if not isinstance(aucs, Tensor):
raise ValueError(f"Expected argument `aucs` to be a Tensor, but got {type(aucs)}.")

if aucs.ndim != 1:
raise ValueError(f"Expected argument `aucs` to be a 1D tensor, but got {aucs.ndim}D tensor.")

if not torch.is_floating_point(aucs):
raise ValueError(f"Expected argument `aucs` to have dtype float, but got {aucs.dtype}.")

valid_aucs = aucs[~torch.isnan(aucs)] if nan_allowed else aucs

if torch.any((valid_aucs < 0) | (valid_aucs > 1)):
raise ValueError("Expected argument `aucs` to be in [0, 1], but got values outside this range.")


def _validate_image_class(image_class: int | None):
if image_class is None:
return

if not isinstance(image_class, int):
raise ValueError(f"Expected argument `image_class` to be either None or an int, but got {type(image_class)}.")

if image_class not in (0, 1):
raise ValueError(
"Expected argument `image_class` to be either 0, 1 or None (respec., 'normal', 'anomalous', or 'both') "
f"but got {image_class}."
)


def _validate_atleast_one_anomalous_image(image_classes: Tensor):
if (image_classes == 1).sum() == 0:
raise ValueError("Expected argument at least one anomalous image, but found none.")
Expand All @@ -112,3 +144,94 @@ def _validate_atleast_one_anomalous_image(image_classes: Tensor):
def _validate_atleast_one_normal_image(image_classes: Tensor):
if (image_classes == 0).sum() == 0:
raise ValueError("Expected argument at least one normal image, but found none.")


# =========================================== FUNCTIONAL ===========================================


def _perimg_boxplot_stats(
values: Tensor, image_classes: Tensor, only_class: int | None = None
) -> list[dict[str, str | int | float | None]]:
"""Compute boxplot statistics for a given tensor of values.
This function uses `matplotlib.cbook.boxplot_stats`, which is the same function used by `matplotlib.pyplot.boxplot`.
Args:
values (Tensor): Tensor of per-image values.
image_classes (Tensor): Tensor of image classes.
only_class (int | None): If not None, only compute statistics for images of the given class.
None means both image classes are used. Defaults to None.
Returns:
list[dict[str, str | int | float | None]]: List of boxplot statistics.
Each dictionary has the following keys:
- 'statistic': Name of the statistic.
- 'value': Value of the statistic (same units as `values`).
- 'nearest': Some statistics (e.g. 'mean') are not guaranteed to be in the tensor, so this is the
closest to the statistic in an actual image (i.e. in `values`).
- 'imgidx': Index of the image in `values` that has the `nearest` value to the statistic.
"""

_validate_image_classes(image_classes)
_validate_image_class(only_class)

if values.ndim != 1:
raise ValueError(f"Expected argument `values` to be a 1D tensor, but got {values.ndim}D tensor.")

if values.shape != image_classes.shape:
raise ValueError(
"Expected arguments `values` and `image_classes` to have the same shape, "
f"but got {values.shape} and {image_classes.shape}."
)

if only_class is not None and only_class not in image_classes:
raise ValueError(f"Argument `only_class` is {only_class}, but `image_classes` does not contain this class.")

# convert to numpy because of `matplotlib.cbook.boxplot_stats`
values = values.cpu().numpy()
image_classes = image_classes.cpu().numpy()

# only consider images of the given class
imgs_mask = numpy.ones_like(image_classes, dtype=bool) if only_class is None else (image_classes == only_class)
values = values[imgs_mask]
imgs_idxs = numpy.nonzero(imgs_mask)[0]

def arg_find_nearest(stat_value):
return (numpy.abs(values - stat_value)).argmin()

# function used in `matplotlib.boxplot`
boxplot_stats = mpl.cbook.boxplot_stats(values)[0] # [0] is for the only boxplot

records = []

def append_record(stat_, val_):
# make sure to use a value that is actually in the array
# because some statistics (e.g. 'mean') are not guaranteed to be in the array
invalues_idx = arg_find_nearest(val_)
nearest = values[invalues_idx]
imgidx = imgs_idxs[invalues_idx]
records.append(
dict(
statistic=stat_,
value=float(val_),
nearest=float(nearest),
imgidx=int(imgidx),
)
)

for stat, val in boxplot_stats.items():
if stat in ("iqr", "cilo", "cihi"):
continue

elif stat != "fliers":
append_record(stat, val)
continue

for val_ in val:
append_record(
"flierhi" if val_ > boxplot_stats["med"] else "flierlo",
val_,
)

records = sorted(records, key=lambda r: r["value"])
return records
92 changes: 86 additions & 6 deletions src/anomalib/utils/metrics/perimg/pimo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from collections import namedtuple

import matplotlib.pyplot as plt
import torch
from matplotlib.axes import Axes
from matplotlib.pyplot import Figure
Expand All @@ -25,11 +26,14 @@

from .binclf_curve import PerImageBinClfCurve
from .common import (
_perimg_boxplot_stats,
_validate_atleast_one_anomalous_image,
_validate_atleast_one_normal_image,
)
from .plot import (
plot_all_pimo_curves,
plot_aupimo_boxplot,
plot_boxplot_pimo_curves,
)

# =========================================== METRICS ===========================================
Expand All @@ -47,10 +51,10 @@
PImOResult.__doc__ = """PImO result (from `PImO.compute()`).
[0] thresholds: shape (num_thresholds,), a `float` dtype as given in update()
[1] fprs: shape (num_images, num_thresholds), dtype `float64`, \in [0, 1]
[2] shared_fpr: shape (num_thresholds,), dtype `float64`, \in [0, 1]
[3] tprs: shape (num_images, num_thresholds), dtype `float64`, \in [0, 1] for anom images, `nan` for norm images
[4] image_classes: shape (num_images,), dtype `int32`, \in {0, 1}
[1] fprs: shape (num_images, num_thresholds), dtype `float64`, \\in [0, 1]
[2] shared_fpr: shape (num_thresholds,), dtype `float64`, \\in [0, 1]
[3] tprs: shape (num_images, num_thresholds), dtype `float64`, \\in [0, 1] for anom images, `nan` for norm images
[4] image_classes: shape (num_images,), dtype `int32`, \\in {0, 1}
- `num_thresholds` is an attribute of `PImO` and is given in the constructor (from parent class).
- `num_images` depends on the data seen by the model at the update() calls.
Expand Down Expand Up @@ -144,7 +148,7 @@ def compute(self) -> tuple[PImOResult, Tensor]: # type: ignore
Returns: (PImOResult, aucs)
[0] PImOResult: PImOResult, see `anomalib.utils.metrics.perimg.pimo.PImOResult` for details.
[1] aucs: shape (num_images,), dtype `float64`, \in [0, 1]
[1] aucs: shape (num_images,), dtype `float64`, \\in [0, 1]
"""

if self.is_empty:
Expand Down Expand Up @@ -186,6 +190,56 @@ def plot_all_pimo_curves(
ax=ax,
)
ax.set_xlabel("Mean FPR on Normal Images")

return fig, ax

def boxplot_stats(self) -> list[dict[str, str | int | float | None]]:
"""Compute boxplot stats of AUPImO values (e.g. median, mean, quartiles, etc.).
Returns:
list[dict[str, str | int | float | None]]: List of AUCs statistics from a boxplot.
refer to `anomalib.utils.metrics.perimg.common._perimg_boxplot_stats()` for the keys and values.
"""
(_, __, ___, ____, image_classes), aucs = self.compute()
stats = _perimg_boxplot_stats(values=aucs, image_classes=image_classes, only_class=1)
return stats

def plot_boxplot_pimo_curves(
self,
ax: Axes | None = None,
) -> tuple[Figure | None, Axes]:
"""Plot shared FPR vs Per-Image Overlap (PImO) curves (boxplot images only).
The 'boxplot images' are those from the boxplot of AUPImO values (see `AUPImO.boxplot_stats()`).
Integration range is shown when `self.ubound < 1`.
"""

if self.is_empty:
return None, None

(thresholds, fprs, shared_fpr, tprs, image_classes), aucs = self.compute()
fig, ax = plot_boxplot_pimo_curves(
shared_fpr,
tprs,
image_classes,
self.boxplot_stats(),
ax=ax,
)
ax.set_xlabel("Mean FPR on Normal Images")

return fig, ax

def plot_boxplot(
self,
ax: Axes | None = None,
) -> tuple[Figure | None, Axes]:
"""Plot boxplot of AUPImO values."""

if self.is_empty:
return None, None

(thresholds, fprs, shared_fpr, tprs, image_classes), aucs = self.compute()
fig, ax = plot_aupimo_boxplot(aucs, image_classes, ax=ax)
return fig, ax

def plot(
Expand All @@ -194,7 +248,33 @@ def plot(
) -> tuple[Figure | None, Axes | ndarray]:
"""Plot AUPImO boxplot with its statistics' PImO curves."""

return self.plot_all_pimo_curves(ax)
if self.is_empty:
return None, None

if ax is None:
fig, ax = plt.subplots(1, 2, figsize=(14, 6), width_ratios=[6, 8])
fig.suptitle("Area Under the Per-Image Overlap (AUPImO) Curves")
fig.set_layout_engine("tight")
else:
fig, ax = (None, ax)

if isinstance(ax, Axes):
return self.plot_boxplot_pimo_curves(ax=ax)

if not isinstance(ax, ndarray):
raise ValueError(f"Expected argument `axes` to be a matplotlib Axes or ndarray, but got {type(ax)}.")

if ax.size != 2:
raise ValueError(
f"Expected argument `axes` , when type `ndarray`, to be of size 2, but got size {ax.size}."
)

ax = ax.flatten()
self.plot_boxplot(ax=ax[0])
ax[0].set_title("AUC Boxplot")
self.plot_boxplot_pimo_curves(ax=ax[1])
ax[1].set_title("Curves")
return fig, ax


class AULogPImO(PImO):
Expand Down
Loading

0 comments on commit c7baf40

Please sign in to comment.