Skip to content

Commit

Permalink
Minor fixes: Update callbacks to AnomalyModule (#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
ashwinvaidya17 authored Apr 11, 2022
1 parent 9c6e93e commit c16d14c
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 21 deletions.
13 changes: 7 additions & 6 deletions anomalib/utils/callbacks/cdf_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torch.distributions import LogNormal

from anomalib.models import get_model
from anomalib.models.components import AnomalyModule
from anomalib.post_processing.normalization.cdf import normalize, standardize


Expand All @@ -32,12 +33,12 @@ def __init__(self):
self.image_dist: Optional[LogNormal] = None
self.pixel_dist: Optional[LogNormal] = None

def on_test_start(self, _trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
def on_test_start(self, _trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
"""Called when the test begins."""
pl_module.image_metrics.F1.threshold = 0.5
pl_module.pixel_metrics.F1.threshold = 0.5

def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: AnomalyModule) -> None:
"""Called when the validation starts after training.
Use the current model to compute the anomaly score distributions
Expand All @@ -49,7 +50,7 @@ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightn
def on_validation_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
pl_module: AnomalyModule,
outputs: Optional[STEP_OUTPUT],
_batch: Any,
_batch_idx: int,
Expand All @@ -61,7 +62,7 @@ def on_validation_batch_end(
def on_test_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
pl_module: AnomalyModule,
outputs: Optional[STEP_OUTPUT],
_batch: Any,
_batch_idx: int,
Expand All @@ -74,7 +75,7 @@ def on_test_batch_end(
def on_predict_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
pl_module: AnomalyModule,
outputs: Dict,
_batch: Any,
_batch_idx: int,
Expand Down Expand Up @@ -120,7 +121,7 @@ def _standardize_batch(outputs: STEP_OUTPUT, pl_module) -> None:
)

@staticmethod
def _normalize_batch(outputs: STEP_OUTPUT, pl_module: pl.LightningModule) -> None:
def _normalize_batch(outputs: STEP_OUTPUT, pl_module: AnomalyModule) -> None:
outputs["pred_scores"] = normalize(outputs["pred_scores"], pl_module.image_threshold.value)
if "anomaly_maps" in outputs.keys():
outputs["anomaly_maps"] = normalize(outputs["anomaly_maps"], pl_module.pixel_threshold.value)
9 changes: 5 additions & 4 deletions anomalib/utils/callbacks/min_max_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,22 @@
from pytorch_lightning import Callback
from pytorch_lightning.utilities.types import STEP_OUTPUT

from anomalib.models.components import AnomalyModule
from anomalib.post_processing.normalization.min_max import normalize


class MinMaxNormalizationCallback(Callback):
"""Callback that normalizes the image-level and pixel-level anomaly scores using min-max normalization."""

def on_test_start(self, _trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
def on_test_start(self, _trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
"""Called when the test begins."""
pl_module.image_metrics.F1.threshold = 0.5
pl_module.pixel_metrics.F1.threshold = 0.5

def on_validation_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
pl_module: AnomalyModule,
outputs: STEP_OUTPUT,
_batch: Any,
_batch_idx: int,
Expand All @@ -49,7 +50,7 @@ def on_validation_batch_end(
def on_test_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
pl_module: AnomalyModule,
outputs: STEP_OUTPUT,
_batch: Any,
_batch_idx: int,
Expand All @@ -61,7 +62,7 @@ def on_test_batch_end(
def on_predict_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
pl_module: AnomalyModule,
outputs: Dict,
_batch: Any,
_batch_idx: int,
Expand Down
6 changes: 4 additions & 2 deletions anomalib/utils/callbacks/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
# and limitations under the License.

import torch
from pytorch_lightning import Callback, LightningModule
from pytorch_lightning import Callback

from anomalib.models.components import AnomalyModule


class LoadModelCallback(Callback):
Expand All @@ -24,7 +26,7 @@ class LoadModelCallback(Callback):
def __init__(self, weights_path):
self.weights_path = weights_path

def on_test_start(self, trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
def on_test_start(self, trainer, pl_module: AnomalyModule) -> None: # pylint: disable=W0613
"""Call when the test begins.
Loads the model weights from ``weights_path`` into the PyTorch module.
Expand Down
7 changes: 3 additions & 4 deletions anomalib/utils/callbacks/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
# and limitations under the License.

import os
from typing import Tuple, cast
from typing import Tuple

from pytorch_lightning import Callback, LightningModule
from pytorch_lightning import Callback

from anomalib.deploy import export_convert
from anomalib.models.components import AnomalyModule
Expand All @@ -39,15 +39,14 @@ def __init__(self, input_size: Tuple[int, int], dirpath: str, filename: str):
self.dirpath = dirpath
self.filename = filename

def on_train_end(self, trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
def on_train_end(self, trainer, pl_module: AnomalyModule) -> None: # pylint: disable=W0613
"""Call when the train ends.
Converts the model to ``onnx`` format and then calls OpenVINO's model optimizer to get the
``.xml`` and ``.bin`` IR files.
"""
os.makedirs(self.dirpath, exist_ok=True)
onnx_path = os.path.join(self.dirpath, self.filename + ".onnx")
pl_module = cast(AnomalyModule, pl_module)
export_convert(
model=pl_module,
input_size=self.input_size,
Expand Down
6 changes: 3 additions & 3 deletions anomalib/utils/callbacks/visualizer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _add_images(
def on_test_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
pl_module: AnomalyModule,
outputs: Optional[STEP_OUTPUT],
_batch: Any,
_batch_idx: int,
Expand Down Expand Up @@ -150,15 +150,15 @@ def on_test_batch_end(
self._add_images(visualizer, pl_module, Path(filename))
visualizer.close()

def on_test_end(self, _trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
def on_test_end(self, _trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
"""Sync logs.
Currently only ``AnomalibWandbLogger`` is called from this method. This is because logging as a single batch
ensures that all images appear as part of the same step.
Args:
_trainer (pl.Trainer): Pytorch Lightning trainer (unused)
pl_module (pl.LightningModule): Anomaly module
pl_module (AnomalyModule): Anomaly module
"""
if pl_module.logger is not None and isinstance(pl_module.logger, AnomalibWandbLogger):
pl_module.logger.save()
10 changes: 8 additions & 2 deletions tests/nightly/models/test_model_nightly.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,19 @@ def _test_metrics(self, trainer, config, model, datamodule):
threshold = thresholds[config.model.name][config.dataset.category]
if "optimization" in config.keys() and config.optimization.nncf.apply:
threshold = threshold.nncf
if not (np.isclose(results["image_AUROC"], threshold["image_AUROC"], rtol=0.02) or (results["image_AUROC"] >= threshold["image_AUROC"])):
if not (
np.isclose(results["image_AUROC"], threshold["image_AUROC"], rtol=0.02)
or (results["image_AUROC"] >= threshold["image_AUROC"])
):
raise AssertionError(
f"results['image_AUROC']:{results['image_AUROC']} >= threshold['image_AUROC']:{threshold['image_AUROC']}"
)

if config.dataset.task == "segmentation":
if not (np.isclose(results["pixel_AUROC"] ,threshold["pixel_AUROC"], rtol=0.02) or (results["pixel_AUROC"] >= threshold["pixel_AUROC"])):
if not (
np.isclose(results["pixel_AUROC"], threshold["pixel_AUROC"], rtol=0.02)
or (results["pixel_AUROC"] >= threshold["pixel_AUROC"])
):
raise AssertionError(
f"results['pixel_AUROC']:{results['pixel_AUROC']} >= threshold['pixel_AUROC']:{threshold['pixel_AUROC']}"
)
Expand Down

0 comments on commit c16d14c

Please sign in to comment.