diff --git a/anomalib/core/callbacks/cdf_normalization.py b/anomalib/core/callbacks/cdf_normalization.py index ac6fab98bb..1786591139 100644 --- a/anomalib/core/callbacks/cdf_normalization.py +++ b/anomalib/core/callbacks/cdf_normalization.py @@ -22,10 +22,8 @@ def on_test_start(self, _trainer: pl.Trainer, pl_module: pl.LightningModule) -> pl_module.image_metrics.F1.threshold = 0.5 pl_module.pixel_metrics.F1.threshold = 0.5 - def on_train_epoch_end( - self, trainer: pl.Trainer, pl_module: pl.LightningModule, _unused: Optional[Any] = None - ) -> None: - """Called when the train epoch ends. + def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Called when the validation starts after training. Use the current model to compute the anomaly score distributions of the normal training data. This is needed after every epoch, because the statistics must be diff --git a/anomalib/core/callbacks/nncf_callback.py b/anomalib/core/callbacks/nncf_callback.py index fa193c492f..8d0013fce8 100644 --- a/anomalib/core/callbacks/nncf_callback.py +++ b/anomalib/core/callbacks/nncf_callback.py @@ -13,8 +13,6 @@ from pytorch_lightning import Callback from torch.utils.data.dataloader import DataLoader -from anomalib.data import get_datamodule - def criterion_fn(outputs, criterion): """Calls the criterion function on outputs.""" @@ -76,21 +74,18 @@ def __init__(self, config: Union[ListConfig, DictConfig], dirpath: str, filename self.dirpath = dirpath self.filename = filename - # we need to create a datamodule here to obtain the init loader - datamodule = get_datamodule(config) - datamodule.setup() - self.train_loader = datamodule.train_dataloader() - self.comp_ctrl: Optional[CompressionAlgorithmController] = None self.compression_scheduler: CompressionScheduler - def setup(self, _: pl.Trainer, pl_module: pl.LightningModule, __: Optional[str] = None) -> None: + def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None: """Call when fit or test begins. Takes the pytorch model and wraps it using the compression controller so that it is ready for nncf fine-tuning. """ if self.comp_ctrl is None: - init_loader = InitLoader(self.train_loader) + # NOTE: trainer.datamodule returns the following error + # "Trainer" has no attribute "datamodule" [attr-defined] + init_loader = InitLoader(trainer.datamodule.train_dataloader()) # type: ignore nncf_config = register_default_init_args( self.nncf_config, init_loader, pl_module.model.loss, criterion_fn=criterion_fn ) @@ -99,7 +94,12 @@ def setup(self, _: pl.Trainer, pl_module: pl.LightningModule, __: Optional[str] self.compression_scheduler = self.comp_ctrl.scheduler def on_train_batch_start( - self, trainer, _pl_module: pl.LightningModule, _batch: Any, _batch_idx: int, _dataloader_idx: int + self, + trainer: pl.Trainer, + _pl_module: pl.LightningModule, + _batch: Any, + _batch_idx: int, + _unused: Optional[int] = 0, ) -> None: """Call when the train batch begins. @@ -109,7 +109,7 @@ def on_train_batch_start( if self.comp_ctrl is not None: trainer.model.loss_val = self.comp_ctrl.loss() - def on_train_end(self, _trainer, _pl_module: pl.LightningModule) -> None: + def on_train_end(self, _trainer: pl.Trainer, _pl_module: pl.LightningModule) -> None: """Call when the train ends. Exports onnx model and if compression controller is not None, uses the onnx model to generate the OpenVINO IR. 
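Note: the NNCF callback now pulls its initialization loader from the datamodule already attached to the Trainer instead of building a second one via `get_datamodule`. A minimal sketch of that access pattern follows; the callback class and names are hypothetical, only the `trainer.datamodule.train_dataloader()` access (and the mypy caveat noted in the hunk above) come from this diff.

```python
import pytorch_lightning as pl


class DatamoduleAwareCallback(pl.Callback):
    """Hypothetical callback showing that the attached datamodule is reachable from a hook."""

    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage=None) -> None:
        # Lightning attaches the datamodule passed to Trainer.fit()/test(), so there is no
        # need to rebuild one inside the callback. As noted in the NNCF callback above,
        # mypy may still require a "type: ignore" on this attribute access.
        train_loader = trainer.datamodule.train_dataloader()  # type: ignore[attr-defined]
        print(f"initialization loader yields {len(train_loader)} batches")
```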
diff --git a/anomalib/core/model/anomaly_module.py b/anomalib/core/model/anomaly_module.py index f120fb48d6..8cea7d2ec7 100644 --- a/anomalib/core/model/anomaly_module.py +++ b/anomalib/core/model/anomaly_module.py @@ -14,7 +14,8 @@ # See the License for the specific language governing permissions # and limitations under the License. -from typing import List, Union +from abc import ABC +from typing import Any, List, Optional, Union import pytorch_lightning as pl from omegaconf import DictConfig, ListConfig @@ -30,7 +31,7 @@ ) -class AnomalyModule(pl.LightningModule): +class AnomalyModule(pl.LightningModule, ABC): """AnomalyModule to train, validate, predict and test images. Acts as a base class for all the Anomaly Modules in the library. @@ -77,7 +78,7 @@ def validation_step(self, batch, batch_idx) -> dict: # type: ignore # pylint: """To be implemented in the subclasses.""" raise NotImplementedError - def predict_step(self, batch, batch_idx, _): # pylint: disable=arguments-differ, signature-differs + def predict_step(self, batch: Any, batch_idx: int, _dataloader_idx: Optional[int] = None) -> Any: """Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. By default, it calls :meth:`~pytorch_lightning.core.lightning.LightningModule.forward`. @@ -86,7 +87,7 @@ def predict_step(self, batch, batch_idx, _): # pylint: disable=arguments-differ Args: batch (Tensor): Current batch batch_idx (int): Index of current batch - dataloader_idx (int): Index of the current dataloader + _dataloader_idx (int): Index of the current dataloader Return: Predicted output diff --git a/anomalib/data/__init__.py b/anomalib/data/__init__.py index a5f5d2d6ec..302475d3e8 100644 --- a/anomalib/data/__init__.py +++ b/anomalib/data/__init__.py @@ -19,6 +19,7 @@ from omegaconf import DictConfig, ListConfig from pytorch_lightning import LightningDataModule +from .inference import InferenceDataset from .mvtec import MVTecDataModule @@ -48,3 +49,6 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule raise ValueError("Unknown dataset!") return datamodule + + +__all__ = ["get_datamodule", "InferenceDataset"] diff --git a/anomalib/data/inference.py b/anomalib/data/inference.py new file mode 100644 index 0000000000..094d4d72bc --- /dev/null +++ b/anomalib/data/inference.py @@ -0,0 +1,67 @@ +"""Inference Dataset.""" + +# Copyright (C) 2020 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. 
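Note: `predict_step` now mirrors the base LightningModule hook, with the dataloader index optional. Per the docstring it falls back to `forward`; the sketch below only illustrates that documented default, the class and body are hypothetical.

```python
from typing import Any, Optional

import pytorch_lightning as pl


class PredictStepExample(pl.LightningModule):
    """Hypothetical module illustrating the documented default behaviour of predict_step."""

    def predict_step(self, batch: Any, batch_idx: int, _dataloader_idx: Optional[int] = None) -> Any:
        # Trainer.predict() calls this once per batch of the predict dataloader;
        # by default the prediction is simply the forward-pass output.
        return self(batch)
```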
+ +from pathlib import Path +from typing import Any, Optional, Tuple, Union + +import albumentations as A +from torch.utils.data.dataset import Dataset + +from anomalib.data.transforms import PreProcessor +from anomalib.data.utils import get_image_filenames, read_image + + +class InferenceDataset(Dataset): + """Inference Dataset to perform prediction.""" + + def __init__( + self, + path: Union[str, Path], + pre_process: Optional[PreProcessor] = None, + image_size: Optional[Union[int, Tuple[int, int]]] = None, + transform_config: Optional[Union[str, A.Compose]] = None, + ) -> None: + """Inference Dataset to perform prediction. + + Args: + path (Union[str, Path]): Path to an image or image-folder. + pre_process (Optional[PreProcessor], optional): Pre-Processing transforms to + pre-process the input dataset. Defaults to None. + image_size (Optional[Union[int, Tuple[int, int]]], optional): Target image size + to resize the original image. Defaults to None. + transform_config (Optional[Union[str, A.Compose]], optional): Configuration file + parse the albumentation transforms. Defaults to None. + """ + super().__init__() + + self.image_filenames = get_image_filenames(path) + + if pre_process is None: + self.pre_process = PreProcessor(transform_config, image_size) + else: + self.pre_process = pre_process + + def __len__(self) -> int: + """Get the number of images in the given path.""" + return len(self.image_filenames) + + def __getitem__(self, index: int) -> Any: + """Get the image based on the `index`.""" + image_filename = self.image_filenames[index] + image = read_image(path=image_filename) + pre_processed = self.pre_process(image=image) + + return pre_processed diff --git a/anomalib/data/mvtec.py b/anomalib/data/mvtec.py index 8f44558aa7..ccf8ca9d2f 100644 --- a/anomalib/data/mvtec.py +++ b/anomalib/data/mvtec.py @@ -34,11 +34,13 @@ import pandas as pd from pandas.core.frame import DataFrame from pytorch_lightning.core.datamodule import LightningDataModule +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch import Tensor from torch.utils.data import DataLoader from torch.utils.data.dataset import Dataset from torchvision.datasets.folder import VisionDataset +from anomalib.data.inference import InferenceDataset from anomalib.data.transforms import PreProcessor from anomalib.data.utils import read_image from anomalib.utils.download_progress_bar import DownloadProgressBar @@ -417,8 +419,10 @@ def __init__( self.root = root if isinstance(root, Path) else Path(root) self.category = category self.dataset_path = self.root / self.category + self.transform_config = transform_config + self.image_size = image_size - self.pre_process = PreProcessor(config=transform_config, image_size=image_size) + self.pre_process = PreProcessor(config=self.transform_config, image_size=self.image_size) self.train_batch_size = train_batch_size self.test_batch_size = test_batch_size @@ -431,13 +435,25 @@ def __init__( self.test_data: Dataset if create_validation_set: self.val_data: Dataset + self.inference_data: Dataset def setup(self, stage: Optional[str] = None) -> None: """Setup train, validation and test data. Args: stage: Optional[str]: Train/Val/Test stages. 
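Note: usage sketch for the new `InferenceDataset`, now exported from `anomalib.data`; the path, image size and batch size below are placeholders.

```python
from torch.utils.data import DataLoader

from anomalib.data import InferenceDataset

# path may point to a single image or to a folder; nested folders are globbed recursively
dataset = InferenceDataset(path="./datasets/MVTec/bottle/test", image_size=256)
loader = DataLoader(dataset, batch_size=8, shuffle=False)

batch = next(iter(loader))  # pre-processed batch (as returned by the PreProcessor transforms)
```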
(Default value = None) + """ + if stage in (None, "fit"): + self.train_data = MVTec( + root=self.root, + category=self.category, + pre_process=self.pre_process, + split="train", + seed=self.seed, + create_validation_set=self.create_validation_set, + ) + if self.create_validation_set: self.val_data = MVTec( root=self.root, @@ -447,6 +463,7 @@ def setup(self, stage: Optional[str] = None) -> None: seed=self.seed, create_validation_set=self.create_validation_set, ) + self.test_data = MVTec( root=self.root, category=self.category, @@ -455,25 +472,27 @@ def setup(self, stage: Optional[str] = None) -> None: seed=self.seed, create_validation_set=self.create_validation_set, ) - if stage in (None, "fit"): - self.train_data = MVTec( - root=self.root, - category=self.category, - pre_process=self.pre_process, - split="train", - seed=self.seed, - create_validation_set=self.create_validation_set, + + if stage == "predict": + self.inference_data = InferenceDataset( + path=self.root, image_size=self.image_size, transform_config=self.transform_config ) - def train_dataloader(self) -> DataLoader: + def train_dataloader(self) -> TRAIN_DATALOADERS: """Get train dataloader.""" return DataLoader(self.train_data, shuffle=True, batch_size=self.train_batch_size, num_workers=self.num_workers) - def val_dataloader(self) -> DataLoader: + def val_dataloader(self) -> EVAL_DATALOADERS: """Get validation dataloader.""" dataset = self.val_data if self.create_validation_set else self.test_data return DataLoader(dataset=dataset, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers) - def test_dataloader(self) -> DataLoader: + def test_dataloader(self) -> EVAL_DATALOADERS: """Get test dataloader.""" return DataLoader(self.test_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers) + + def predict_dataloader(self) -> EVAL_DATALOADERS: + """Get predict dataloader.""" + return DataLoader( + self.inference_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers + ) diff --git a/anomalib/data/utils.py b/anomalib/data/utils.py index 28cf64388e..e403d787ae 100644 --- a/anomalib/data/utils.py +++ b/anomalib/data/utils.py @@ -15,10 +15,38 @@ # and limitations under the License. from pathlib import Path -from typing import Union +from typing import List, Union import cv2 import numpy as np +from torchvision.datasets.folder import IMG_EXTENSIONS + + +def get_image_filenames(path: Union[str, Path]) -> List[str]: + """Get image filenames. + + Args: + path (Union[str, Path]): Path to image or image-folder. 
+ + Returns: + List[str]: List of image filenames + + """ + image_filenames: List[str] + + if isinstance(path, str): + path = Path(path) + + if path.is_file() and path.suffix in IMG_EXTENSIONS: + image_filenames = [str(path)] + + if path.is_dir(): + image_filenames = [str(p) for p in path.glob("**/*") if p.suffix in IMG_EXTENSIONS] + + if len(image_filenames) == 0: + raise ValueError(f"Found 0 images in {path}") + + return image_filenames def read_image(path: Union[str, Path]) -> np.ndarray: diff --git a/anomalib/models/cflow/config.yaml b/anomalib/models/cflow/config.yaml index c893da7fac..d1d8c26bda 100644 --- a/anomalib/models/cflow/config.yaml +++ b/anomalib/models/cflow/config.yaml @@ -3,7 +3,7 @@ dataset: format: mvtec path: ./datasets/MVTec url: ftp://guest:GU.205dldo@ftp.softronics.ch/mvtec_anomaly_detection/mvtec_anomaly_detection.tar.xz - category: leather + category: bottle task: segmentation label_format: None image_size: 256 @@ -47,7 +47,6 @@ trainer: accelerator: null accumulate_grad_batches: 1 amp_backend: native - amp_level: O2 auto_lr_find: false auto_scale_batch_size: false auto_select_gpus: false @@ -55,10 +54,8 @@ trainer: check_val_every_n_epoch: 1 checkpoint_callback: true default_root_dir: null - deterministic: true - distributed_backend: null + deterministic: false fast_dev_run: false - flush_logs_every_n_steps: 100 gpus: 1 gradient_clip_val: 0 limit_predict_batches: 1.0 @@ -68,7 +65,7 @@ trainer: log_every_n_steps: 50 log_gpu_memory: null max_epochs: 50 - max_steps: null + max_steps: -1 min_epochs: null min_steps: null move_metrics_to_cpu: false @@ -83,14 +80,12 @@ trainer: process_position: 0 profiler: null progress_bar_refresh_rate: null - reload_dataloaders_every_epoch: false replace_sampler_ddp: true stochastic_weight_avg: false sync_batchnorm: false terminate_on_nan: false tpu_cores: null track_grad_norm: -1 - truncated_bptt_steps: null val_check_interval: 1.0 weights_save_path: null weights_summary: top diff --git a/anomalib/models/dfkde/config.yaml b/anomalib/models/dfkde/config.yaml index aa87af9964..e73b636208 100644 --- a/anomalib/models/dfkde/config.yaml +++ b/anomalib/models/dfkde/config.yaml @@ -3,7 +3,7 @@ dataset: format: mvtec path: ./datasets/MVTec url: ftp://guest:GU.205dldo@ftp.softronics.ch/mvtec_anomaly_detection/mvtec_anomaly_detection.tar.xz - category: leather + category: bottle task: classification label_format: None image_size: 256 @@ -35,18 +35,15 @@ trainer: accelerator: null accumulate_grad_batches: 1 amp_backend: native - amp_level: O2 auto_lr_find: false auto_scale_batch_size: false auto_select_gpus: false benchmark: false - check_val_every_n_epoch: 1 + check_val_every_n_epoch: 1 # Don't validate before extracting features. checkpoint_callback: true default_root_dir: null deterministic: false - distributed_backend: null fast_dev_run: false - flush_logs_every_n_steps: 100 gpus: 1 gradient_clip_val: 0 limit_predict_batches: 1.0 @@ -56,7 +53,7 @@ trainer: log_every_n_steps: 50 log_gpu_memory: null max_epochs: 1 - max_steps: null + max_steps: -1 min_epochs: null min_steps: null move_metrics_to_cpu: false @@ -71,14 +68,12 @@ trainer: process_position: 0 profiler: null progress_bar_refresh_rate: null - reload_dataloaders_every_epoch: false replace_sampler_ddp: true stochastic_weight_avg: false sync_batchnorm: false terminate_on_nan: false tpu_cores: null track_grad_norm: -1 - truncated_bptt_steps: null - val_check_interval: 1.0 + val_check_interval: 1.0 # Don't validate before extracting features. 
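Note: with the new `predict` stage and `predict_dataloader`, prediction on raw images can go through the datamodule itself. A sketch follows; the constructor arguments are assumed from the attributes used in `__init__` above and the paths are placeholders.

```python
from anomalib.data.mvtec import MVTecDataModule

# keyword arguments assumed from the datamodule attributes shown in the hunks above
datamodule = MVTecDataModule(
    root="./datasets/MVTec",
    category="bottle",
    image_size=256,
    train_batch_size=32,
    test_batch_size=32,
    num_workers=8,
)

datamodule.setup(stage="predict")            # builds self.inference_data from InferenceDataset
predict_loader = datamodule.predict_dataloader()
```

With this in place, `Trainer.predict(model, datamodule=datamodule)` uses the same loader.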
weights_save_path: null weights_summary: top diff --git a/anomalib/models/dfkde/model.py b/anomalib/models/dfkde/model.py index 66cf71e3a8..e14b33c800 100644 --- a/anomalib/models/dfkde/model.py +++ b/anomalib/models/dfkde/model.py @@ -14,12 +14,13 @@ # See the License for the specific language governing permissions # and limitations under the License. -from typing import Any, Dict, List, Union +from typing import List, Union import torch import torchvision from omegaconf.dictconfig import DictConfig from omegaconf.listconfig import ListConfig +from torch import Tensor from anomalib.core.model import AnomalyModule from anomalib.core.model.feature_extractor import FeatureExtractor @@ -47,17 +48,19 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]): threshold_offset=self.threshold_offset, ) self.automatic_optimization = False + self.embeddings: List[Tensor] = [] @staticmethod def configure_optimizers(): """DFKDE doesn't require optimization, therefore returns no optimizers.""" return None - def training_step(self, batch, _): # pylint: disable=arguments-differ + def training_step(self, batch, _batch_idx): # pylint: disable=arguments-differ """Training Step of DFKDE. For each batch, features are extracted from the CNN. Args: - batch (Tensor): Input batch + batch (Dict[str, Any]): Batch containing image filename, image, label and mask + _batch_idx: Index of the batch. Returns: Deep CNN features. @@ -65,21 +68,21 @@ def training_step(self, batch, _): # pylint: disable=arguments-differ self.feature_extractor.eval() layer_outputs = self.feature_extractor(batch["image"]) - feature_vector = torch.hstack(list(layer_outputs.values())).detach().squeeze() - return {"feature_vector": feature_vector} - - def training_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: - """Fit a KDE model on deep CNN features. - - Args: - outputs (List[Dict[str, Any]]): Batch of outputs from the training step - - Returns: - None - """ - - feature_stack = torch.vstack([output["feature_vector"] for output in outputs]) - self.normality_model.fit(feature_stack) + embedding = torch.hstack(list(layer_outputs.values())).detach().squeeze() + + # NOTE: `self.embedding` appends each batch embedding to + # store the training set embedding. We manually append these + # values mainly due to the new order of hooks introduced after PL v1.4.0 + # https://github.com/PyTorchLightning/pytorch-lightning/pull/7357 + self.embeddings.append(embedding) + + def on_validation_start(self) -> None: + """Fit a KDE Model to the embedding collected from the training set.""" + # NOTE: Previous anomalib versions fit Gaussian at the end of the epoch. + # This is not possible anymore with PyTorch Lightning v1.4.0 since validation + # is run within train epoch. + embeddings = torch.vstack(self.embeddings) + self.normality_model.fit(embeddings) def validation_step(self, batch, _): # pylint: disable=arguments-differ """Validation Step of DFKDE. 
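Note: DFKDE (and, below, DFM, PaDiM and PatchCore) switch from returning features out of `training_step` and aggregating them in `training_epoch_end` to appending them on the module and fitting in `on_validation_start`, because validation now runs within the train epoch. A generic sketch of that pattern, with `extract_features` and `fit_normality_model` as placeholders for the model-specific pieces:

```python
from typing import List

import pytorch_lightning as pl
import torch
from torch import Tensor


class CollectThenFit(pl.LightningModule):
    """Sketch of the collect-embeddings-then-fit pattern used by the feature-based models."""

    def __init__(self) -> None:
        super().__init__()
        self.automatic_optimization = False  # no gradient-based optimisation needed
        self.embeddings: List[Tensor] = []

    @staticmethod
    def configure_optimizers() -> None:
        return None

    def training_step(self, batch, _batch_idx) -> None:
        embedding = self.extract_features(batch["image"])  # placeholder for the real extractor
        self.embeddings.append(embedding.cpu())            # accumulate per-batch embeddings

    def on_validation_start(self) -> None:
        # validation runs inside the train epoch, so the normality model is fitted here
        embeddings = torch.vstack(self.embeddings)
        self.fit_normality_model(embeddings)               # placeholder: KDE / Gaussian / coreset
```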
diff --git a/anomalib/models/dfm/config.yaml b/anomalib/models/dfm/config.yaml index 72f409f002..9ab2b04161 100755 --- a/anomalib/models/dfm/config.yaml +++ b/anomalib/models/dfm/config.yaml @@ -3,7 +3,7 @@ dataset: format: mvtec path: ./datasets/MVTec url: ftp://guest:GU.205dldo@ftp.softronics.ch/mvtec_anomaly_detection/mvtec_anomaly_detection.tar.xz - category: leather + category: bottle task: classification label_format: None image_size: 256 @@ -25,7 +25,7 @@ model: project: seed: 42 path: ./results - log_images_to: [local] + log_images_to: [] logger: false save_to_csv: false @@ -34,18 +34,15 @@ trainer: accelerator: null accumulate_grad_batches: 1 amp_backend: native - amp_level: O2 auto_lr_find: false auto_scale_batch_size: false auto_select_gpus: false benchmark: false - check_val_every_n_epoch: 1 + check_val_every_n_epoch: 1 # Don't validate before extracting features. checkpoint_callback: true default_root_dir: null deterministic: false - distributed_backend: null fast_dev_run: false - flush_logs_every_n_steps: 100 gpus: 1 gradient_clip_val: 0 limit_predict_batches: 1.0 @@ -55,7 +52,7 @@ trainer: log_every_n_steps: 50 log_gpu_memory: null max_epochs: 1 - max_steps: null + max_steps: -1 min_epochs: null min_steps: null move_metrics_to_cpu: false @@ -70,14 +67,12 @@ trainer: process_position: 0 profiler: null progress_bar_refresh_rate: null - reload_dataloaders_every_epoch: false replace_sampler_ddp: true stochastic_weight_avg: false sync_batchnorm: false terminate_on_nan: false tpu_cores: null track_grad_norm: -1 - truncated_bptt_steps: null - val_check_interval: 1.0 + val_check_interval: 1.0 # Don't validate before extracting features. weights_save_path: null weights_summary: top diff --git a/anomalib/models/dfm/model.py b/anomalib/models/dfm/model.py index 423830aa56..7a4a9f1683 100644 --- a/anomalib/models/dfm/model.py +++ b/anomalib/models/dfm/model.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions # and limitations under the License. -from typing import Dict, List, Union +from typing import List, Union import torch import torchvision @@ -37,6 +37,7 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]): self.dfm_model = DFMModel(n_comps=hparams.model.pca_level, score_type=hparams.model.score_type) self.automatic_optimization = False + self.embeddings: List[Tensor] = [] @staticmethod def configure_optimizers() -> None: @@ -58,21 +59,21 @@ def training_step(self, batch, _): # pylint: disable=arguments-differ self.feature_extractor.eval() layer_outputs = self.feature_extractor(batch["image"]) - feature_vector = torch.hstack(list(layer_outputs.values())).detach().squeeze() - return {"feature_vector": feature_vector} - - def training_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> None: - """Fit a KDE model on deep CNN features. - - Args: - outputs (List[Dict[str, Tensor]]): Batch of outputs from the training step - - Returns: - None - """ - - feature_stack = torch.vstack([output["feature_vector"] for output in outputs]) - self.dfm_model.fit(feature_stack) + embedding = torch.hstack(list(layer_outputs.values())).detach().squeeze() + + # NOTE: `self.embedding` appends each batch embedding to + # store the training set embedding. 
We manually append these + # values mainly due to the new order of hooks introduced after PL v1.4.0 + # https://github.com/PyTorchLightning/pytorch-lightning/pull/7357 + self.embeddings.append(embedding) + + def on_validation_start(self) -> None: + """Fit a KDE Model to the embedding collected from the training set.""" + # NOTE: Previous anomalib versions fit Gaussian at the end of the epoch. + # This is not possible anymore with PyTorch Lightning v1.4.0 since validation + # is run within train epoch. + embeddings = torch.vstack(self.embeddings) + self.dfm_model.fit(embeddings) def validation_step(self, batch, _): # pylint: disable=arguments-differ """Validation Step of DFM. diff --git a/anomalib/models/ganomaly/config.yaml b/anomalib/models/ganomaly/config.yaml index f6d9a316e7..2837bf13d1 100644 --- a/anomalib/models/ganomaly/config.yaml +++ b/anomalib/models/ganomaly/config.yaml @@ -71,7 +71,6 @@ trainer: accelerator: null accumulate_grad_batches: 1 amp_backend: native - amp_level: O2 auto_lr_find: false auto_scale_batch_size: false auto_select_gpus: false @@ -79,10 +78,8 @@ trainer: check_val_every_n_epoch: 2 checkpoint_callback: true default_root_dir: null - deterministic: true - distributed_backend: null + deterministic: false fast_dev_run: false - flush_logs_every_n_steps: 100 gpus: 1 gradient_clip_val: 0 limit_predict_batches: 1.0 @@ -107,14 +104,12 @@ trainer: process_position: 0 profiler: null progress_bar_refresh_rate: null - reload_dataloaders_every_epoch: false replace_sampler_ddp: true stochastic_weight_avg: false sync_batchnorm: false terminate_on_nan: false tpu_cores: null track_grad_norm: -1 - truncated_bptt_steps: null val_check_interval: 1.0 weights_save_path: null weights_summary: top diff --git a/anomalib/models/padim/config.yaml b/anomalib/models/padim/config.yaml index e3b988486e..8b5df3cc37 100644 --- a/anomalib/models/padim/config.yaml +++ b/anomalib/models/padim/config.yaml @@ -35,7 +35,7 @@ model: project: seed: 42 path: ./results - log_images_to: [] + log_images_to: ["local"] logger: false save_to_csv: false @@ -60,18 +60,15 @@ trainer: accelerator: null accumulate_grad_batches: 1 amp_backend: native - amp_level: O2 auto_lr_find: false auto_scale_batch_size: false auto_select_gpus: false benchmark: false - check_val_every_n_epoch: 1 + check_val_every_n_epoch: 1 # Don't validate before extracting features. checkpoint_callback: true default_root_dir: null - deterministic: true - distributed_backend: null + deterministic: false fast_dev_run: false - flush_logs_every_n_steps: 100 gpus: 1 gradient_clip_val: 0 limit_predict_batches: 1.0 @@ -81,7 +78,7 @@ trainer: log_every_n_steps: 50 log_gpu_memory: null max_epochs: 1 - max_steps: null + max_steps: -1 min_epochs: null min_steps: null move_metrics_to_cpu: false @@ -96,14 +93,12 @@ trainer: process_position: 0 profiler: null progress_bar_refresh_rate: null - reload_dataloaders_every_epoch: false replace_sampler_ddp: true stochastic_weight_avg: false sync_batchnorm: false terminate_on_nan: false tpu_cores: null track_grad_norm: -1 - truncated_bptt_steps: null - val_check_interval: 1.0 + val_check_interval: 1.0 # Don't validate before extracting features. 
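Note: for DFKDE and DFM the per-batch embedding is the horizontal stack of the backbone layer outputs, and the training-set matrix is the vertical stack of those per-batch matrices. A toy shape check; the layer names and dimensions are made up for illustration.

```python
import torch

# one batch: hstack concatenates the (squeezed) outputs of each backbone layer along the channel dim
layer_outputs = {"layer3": torch.randn(8, 512, 1, 1), "layer4": torch.randn(8, 1024, 1, 1)}
embedding = torch.hstack(list(layer_outputs.values())).detach().squeeze()  # shape (8, 1536)

# whole training set: vstack stacks the per-batch matrices collected in self.embeddings
embeddings = torch.vstack([embedding, embedding])                          # shape (16, 1536)
print(embedding.shape, embeddings.shape)
```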
weights_save_path: null weights_summary: top diff --git a/anomalib/models/padim/model.py b/anomalib/models/padim/model.py index d1edf24679..ca0343a5c2 100644 --- a/anomalib/models/padim/model.py +++ b/anomalib/models/padim/model.py @@ -281,7 +281,7 @@ class PadimLightning(AnomalyModule): def __init__(self, hparams: Union[DictConfig, ListConfig]): super().__init__(hparams) self.layers = hparams.model.layers - self.model = PadimModel( + self.model: PadimModel = PadimModel( layers=hparams.model.layers, input_size=hparams.model.input_size, tile_size=hparams.dataset.tiling.tile_size, @@ -292,38 +292,38 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]): self.stats: List[Tensor] = [] self.automatic_optimization = False + self.embeddings: List[Tensor] = [] @staticmethod def configure_optimizers(): """PADIM doesn't require optimization, therefore returns no optimizers.""" return None - def training_step(self, batch, _): # pylint: disable=arguments-differ + def training_step(self, batch, _batch_idx): # pylint: disable=arguments-differ """Training Step of PADIM. For each batch, hierarchical features are extracted from the CNN. Args: - batch (Dict[str,Tensor]): Input batch - _: Index of the batch. + batch (Dict[str, Any]): Batch containing image filename, image, label and mask + _batch_idx: Index of the batch. Returns: Hierarchical feature map """ - self.model.feature_extractor.eval() - embeddings = self.model(batch["image"]) - return {"embeddings": embeddings.cpu()} - - def training_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> None: - """Fit a multivariate gaussian model on an embedding extracted from deep hierarchical CNN features. - - Args: - outputs (List[Dict[str, Tensor]]): Batch of outputs from the training step - - Returns: - None - """ - - embeddings = torch.vstack([x["embeddings"] for x in outputs]) + embedding = self.model(batch["image"]) + + # NOTE: `self.embedding` appends each batch embedding to + # store the training set embedding. We manually append these + # values mainly due to the new order of hooks introduced after PL v1.4.0 + # https://github.com/PyTorchLightning/pytorch-lightning/pull/7357 + self.embeddings.append(embedding.cpu()) + + def on_validation_start(self) -> None: + """Fit a Gaussian to the embedding collected from the training set.""" + # NOTE: Previous anomalib versions fit Gaussian at the end of the epoch. + # This is not possible anymore with PyTorch Lightning v1.4.0 since validation + # is run within train epoch. 
+ embeddings = torch.vstack(self.embeddings) self.stats = self.model.gaussian.fit(embeddings) def validation_step(self, batch, _): # pylint: disable=arguments-differ @@ -341,5 +341,4 @@ def validation_step(self, batch, _): # pylint: disable=arguments-differ """ batch["anomaly_maps"] = self.model(batch["image"]) - return batch diff --git a/anomalib/models/patchcore/config.yaml b/anomalib/models/patchcore/config.yaml index bccc702d58..ebb1755993 100644 --- a/anomalib/models/patchcore/config.yaml +++ b/anomalib/models/patchcore/config.yaml @@ -4,7 +4,7 @@ dataset: path: ./datasets/MVTec url: ftp://guest:GU.205dldo@ftp.softronics.ch/mvtec_anomaly_detection/mvtec_anomaly_detection.tar.xz task: segmentation - category: carpet + category: bottle label_format: None tiling: apply: false @@ -46,18 +46,15 @@ trainer: accelerator: null accumulate_grad_batches: 1 amp_backend: native - amp_level: O2 auto_lr_find: false auto_scale_batch_size: false auto_select_gpus: false benchmark: false - check_val_every_n_epoch: 1 + check_val_every_n_epoch: 1 # Don't validate before extracting features. checkpoint_callback: true default_root_dir: null - deterministic: true - distributed_backend: null + deterministic: false fast_dev_run: false - flush_logs_every_n_steps: 100 gpus: 1 gradient_clip_val: 0 limit_predict_batches: 1.0 @@ -67,7 +64,7 @@ trainer: log_every_n_steps: 50 log_gpu_memory: null max_epochs: 1 - max_steps: null + max_steps: -1 min_epochs: null min_steps: null move_metrics_to_cpu: false @@ -82,14 +79,12 @@ trainer: process_position: 0 profiler: null progress_bar_refresh_rate: null - reload_dataloaders_every_epoch: false replace_sampler_ddp: true stochastic_weight_avg: false sync_batchnorm: false terminate_on_nan: false tpu_cores: null track_grad_norm: -1 - truncated_bptt_steps: null - val_check_interval: 1.0 + val_check_interval: 1.0 # Don't validate before extracting features. weights_save_path: null weights_summary: top diff --git a/anomalib/models/patchcore/model.py b/anomalib/models/patchcore/model.py index e3dbe9c549..b497b831ec 100644 --- a/anomalib/models/patchcore/model.py +++ b/anomalib/models/patchcore/model.py @@ -216,8 +216,11 @@ def subsample_embedding(self, embedding: torch.Tensor, sampling_ratio: float) -> """ # Coreset Subsampling + print("Creating CoreSet Sampler via k-Center Greedy") sampler = KCenterGreedy(embedding=embedding, sampling_ratio=sampling_ratio) + print("Getting the coreset from the main embedding.") coreset = sampler.sample_coreset() + print("Assigning the coreset as the memory bank.") self.memory_bank = coreset def nearest_neighbors(self, embedding: Tensor, n_neighbors: int = 9) -> Tensor: @@ -250,7 +253,7 @@ class PatchcoreLightning(AnomalyModule): def __init__(self, hparams) -> None: super().__init__(hparams) - self.model = PatchcoreModel( + self.model: PatchcoreModel = PatchcoreModel( layers=hparams.model.layers, input_size=hparams.model.input_size, tile_size=hparams.dataset.tiling.tile_size, @@ -259,6 +262,7 @@ def __init__(self, hparams) -> None: apply_tiling=hparams.dataset.tiling.apply, ) self.automatic_optimization = False + self.embeddings: List[Tensor] = [] def configure_optimizers(self) -> None: """Configure optimizers. @@ -268,13 +272,12 @@ def configure_optimizers(self) -> None: """ return None - def training_step(self, batch, _): # pylint: disable=arguments-differ + def training_step(self, batch, _batch_idx): # pylint: disable=arguments-differ """Generate feature embedding of the batch. 
Args: - batch (Dict[str, Any]): Batch containing image filename, - image, label and mask - _ (int): Batch Index + batch (Dict[str, Any]): Batch containing image filename, image, label and mask + _batch_idx (int): Batch Index Returns: Dict[str, np.ndarray]: Embedding Vector @@ -282,20 +285,22 @@ def training_step(self, batch, _): # pylint: disable=arguments-differ self.model.feature_extractor.eval() embedding = self.model(batch["image"]) - return {"embedding": embedding} - - def training_epoch_end(self, outputs): - """Concatenate batch embeddings to generate normal embedding. + # NOTE: `self.embedding` appends each batch embedding to + # store the training set embedding. We manually append these + # values mainly due to the new order of hooks introduced after PL v1.4.0 + # https://github.com/PyTorchLightning/pytorch-lightning/pull/7357 + self.embeddings.append(embedding) - Apply coreset subsampling to the embedding set for dimensionality reduction. + def on_validation_start(self) -> None: + """Apply subsampling to the embedding collected from the training set.""" + # NOTE: Previous anomalib versions fit subsampling at the end of the epoch. + # This is not possible anymore with PyTorch Lightning v1.4.0 since validation + # is run within train epoch. + print("Aggregating the embedding extracted from the training set.") + embeddings = torch.vstack(self.embeddings) - Args: - outputs (List[Dict[str, np.ndarray]]): List of embedding vectors - """ - embedding = torch.vstack([output["embedding"] for output in outputs]) sampling_ratio = self.hparams.model.coreset_sampling_ratio - - self.model.subsample_embedding(embedding, sampling_ratio) + self.model.subsample_embedding(embeddings, sampling_ratio) def validation_step(self, batch, _): # pylint: disable=arguments-differ """Get batch of anomaly maps from input image batch. 
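Note: PatchCore's `subsample_embedding` reduces the stacked embedding with `KCenterGreedy` before assigning it as the memory bank. The sampler itself is not part of this diff; the sketch below is only a simplified illustration of the greedy k-center idea it refers to, and the repository's implementation may differ (e.g. by projecting the features before computing distances).

```python
import torch


def k_center_greedy(embedding: torch.Tensor, sampling_ratio: float) -> torch.Tensor:
    """Toy greedy k-center coreset selection: repeatedly pick the point farthest from the chosen set."""
    n_select = max(1, int(embedding.shape[0] * sampling_ratio))
    selected = [int(torch.randint(embedding.shape[0], (1,)))]          # random first centre
    min_dist = torch.cdist(embedding, embedding[selected]).squeeze(1)  # distance to nearest centre

    for _ in range(n_select - 1):
        idx = int(torch.argmax(min_dist))                              # farthest remaining point
        selected.append(idx)
        new_dist = torch.cdist(embedding, embedding[idx : idx + 1]).squeeze(1)
        min_dist = torch.minimum(min_dist, new_dist)

    return embedding[selected]                                         # coreset / memory bank


coreset = k_center_greedy(torch.randn(1000, 1536), sampling_ratio=0.1)  # -> (100, 1536)
```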
diff --git a/anomalib/models/stfpm/config.yaml b/anomalib/models/stfpm/config.yaml index 3d96c0f712..1f476e9f41 100644 --- a/anomalib/models/stfpm/config.yaml +++ b/anomalib/models/stfpm/config.yaml @@ -71,18 +71,15 @@ trainer: accelerator: null accumulate_grad_batches: 1 amp_backend: native - amp_level: O2 auto_lr_find: false auto_scale_batch_size: false auto_select_gpus: false benchmark: false - check_val_every_n_epoch: 2 + check_val_every_n_epoch: 1 checkpoint_callback: true default_root_dir: null - deterministic: true - distributed_backend: null + deterministic: false fast_dev_run: false - flush_logs_every_n_steps: 100 gpus: 1 gradient_clip_val: 0 limit_predict_batches: 1.0 @@ -92,7 +89,7 @@ trainer: log_every_n_steps: 50 log_gpu_memory: null max_epochs: 100 - max_steps: null + max_steps: -1 min_epochs: null min_steps: null move_metrics_to_cpu: false @@ -107,14 +104,12 @@ trainer: process_position: 0 profiler: null progress_bar_refresh_rate: null - reload_dataloaders_every_epoch: false replace_sampler_ddp: true stochastic_weight_avg: false sync_batchnorm: false terminate_on_nan: false tpu_cores: null track_grad_norm: -1 - truncated_bptt_steps: null val_check_interval: 1.0 weights_save_path: null weights_summary: top diff --git a/requirements/base.txt b/requirements/base.txt index 3d17b910f8..a800fb288e 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -10,8 +10,8 @@ networkx~=2.5 nncf==2.0.0 numpy~=1.19.5 omegaconf==2.1.1 -pillow==8.3.2 -pytorch-lightning==1.3.6 +pillow==9.0.0 +pytorch-lightning==1.5.9 torch==1.8.1 torchvision==0.9.1 scikit-image>=0.17.2 diff --git a/tests/core/callbacks/normalization_callback/test_normalization_callback.py b/tests/core/callbacks/normalization_callback/test_normalization_callback.py index 14237edfda..bfbe924b68 100644 --- a/tests/core/callbacks/normalization_callback/test_normalization_callback.py +++ b/tests/core/callbacks/normalization_callback/test_normalization_callback.py @@ -11,6 +11,7 @@ def run_train_test(config): model = get_model(config) datamodule = get_datamodule(config) callbacks = get_callbacks(config) + trainer = Trainer(**config.trainer, callbacks=callbacks) trainer.fit(model=model, datamodule=datamodule) results = trainer.test(model=model, datamodule=datamodule) @@ -21,6 +22,7 @@ def test_normalizer(): config = get_configurable_parameters(model_config_path="anomalib/models/padim/config.yaml") config.dataset.path = get_dataset_path(config.dataset.path) config.model.threshold.adaptive = True + config.project.log_images_to = [] # run without normalization config.model.normalization_method = "none" diff --git a/tests/models/test_model.py b/tests/models/test_model.py index edd825d844..2de089c305 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -85,6 +85,7 @@ def _setup(self, model_name, use_mvtec, dataset_path, project_path, nncf, catego config.dataset.category = category config.dataset.path = dataset_path config.model.weight_file = "weights/model.ckpt" # add model weights to the config + config.project.log_images_to = [] if not use_mvtec: config.dataset.category = "shapes"
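Note: the requirements bump to pytorch-lightning==1.5.9 lines up with the trainer-config trimming above (keys such as amp_level, flush_logs_every_n_steps and truncated_bptt_steps dropped; max_steps switched from null to -1). Since the configs are unpacked directly into the Trainer, as the test above does, a quick way to check any config against the installed Lightning version:

```python
import inspect

from omegaconf import OmegaConf
from pytorch_lightning import Trainer

config = OmegaConf.load("anomalib/models/padim/config.yaml")

config_keys = set(config.trainer.keys())
accepted_keys = set(inspect.signature(Trainer.__init__).parameters) - {"self"}
print("keys the installed Trainer does not accept:", config_keys - accepted_keys)

trainer = Trainer(**config.trainer)  # raises TypeError if any unknown key remains
```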