diff --git a/anomalib/data/__init__.py b/anomalib/data/__init__.py index 8c295a1061..f1691620f5 100644 --- a/anomalib/data/__init__.py +++ b/anomalib/data/__init__.py @@ -7,7 +7,8 @@ from typing import Union from omegaconf import DictConfig, ListConfig -from pytorch_lightning import LightningDataModule + +from anomalib.data.base import AnomalibDataModule from .btech import BTech from .folder import Folder @@ -17,7 +18,7 @@ logger = logging.getLogger(__name__) -def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule: +def get_datamodule(config: Union[DictConfig, ListConfig]) -> AnomalibDataModule: """Get Anomaly Datamodule. Args: @@ -28,7 +29,7 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule """ logger.info("Loading the datamodule") - datamodule: LightningDataModule + datamodule: AnomalibDataModule if config.dataset.format.lower() == "mvtec": datamodule = MVTec( diff --git a/anomalib/data/base.py b/anomalib/data/base.py new file mode 100644 index 0000000000..ed61a83c9c --- /dev/null +++ b/anomalib/data/base.py @@ -0,0 +1,207 @@ +"""Anomalib dataset and datamodule base classes.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from abc import ABC, abstractmethod +from typing import Dict, Optional, Tuple, Union + +import albumentations as A +import cv2 +import numpy as np +from pandas import DataFrame +from pytorch_lightning import LightningDataModule +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from torch import Tensor +from torch.utils.data import DataLoader, Dataset + +from anomalib.data.utils import read_image +from anomalib.pre_processing import PreProcessor + +logger = logging.getLogger(__name__) + + +class AnomalibDataset(Dataset): + """Anomalib dataset.""" + + def __init__(self, samples: DataFrame, task: str, split: str, pre_process: PreProcessor): + super().__init__() + self.samples = samples + self.task = 
task + self.split = split + self.pre_process = pre_process + + def __len__(self) -> int: + """Get length of the dataset.""" + return len(self.samples) + + def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]: + """Get dataset item for the index ``index``. + + Args: + index (int): Index to get the item. + + Returns: + Dict[str, Union[str, Tensor]]: Dict containing the image path, the label index and the image tensor. + For the segmentation task, the dict additionally contains the mask path and the mask tensor. + """ + image_path = self.samples.iloc[index].image_path + image = read_image(image_path) + label_index = self.samples.iloc[index].label_index + + item = dict(image_path=image_path, label=label_index) + + if self.task == "classification": + pre_processed = self.pre_process(image=image) + elif self.task == "segmentation": + mask_path = self.samples.iloc[index].mask_path + + # Only Anomalous (1) images have masks in anomaly datasets + # Therefore, create empty mask for Normal (0) images. 
+ if label_index == 0: + mask = np.zeros(shape=image.shape[:2]) + else: + mask = cv2.imread(mask_path, flags=0) / 255.0 + + pre_processed = self.pre_process(image=image, mask=mask) + + item["mask_path"] = mask_path + item["mask"] = pre_processed["mask"] + else: + raise ValueError(f"Unknown task type: {self.task}") + item["image"] = pre_processed["image"] + + return item + + +class AnomalibDataModule(LightningDataModule, ABC): + """Base Anomalib data module.""" + + def __init__( + self, + task: str, + train_batch_size: int, + test_batch_size: int, + num_workers: int, + transform_config_train: Optional[Union[str, A.Compose]] = None, + transform_config_val: Optional[Union[str, A.Compose]] = None, + image_size: Optional[Union[int, Tuple[int, int]]] = None, + create_validation_set: bool = False, + ): + super().__init__() + self.task = task + self.train_batch_size = train_batch_size + self.test_batch_size = test_batch_size + self.num_workers = num_workers + self.create_validation_set = create_validation_set + + if transform_config_train is not None and transform_config_val is None: + transform_config_val = transform_config_train + self.pre_process_train = PreProcessor(config=transform_config_train, image_size=image_size) + self.pre_process_val = PreProcessor(config=transform_config_val, image_size=image_size) + + self.train_data: Optional[AnomalibDataset] = None + self.val_data: Optional[AnomalibDataset] = None + self.test_data: Optional[AnomalibDataset] = None + + self._samples: Optional[DataFrame] = None + + @abstractmethod + def _create_samples(self) -> DataFrame: + """This method should be implemented in the subclass. + + This method should return a dataframe that contains the information needed by the dataloader to load each of + the dataset items into memory. The dataframe must at least contain the following columns: + split - The subset to which the dataset item is assigned. + image_path - Path to file system location where the image is stored. 
+ label_index - Index of the anomaly label, typically 0 for "normal" and 1 for "anomalous". + + Additionally, when the task type is segmentation, the dataframe must have the mask_path column, which contains + the path to the ground truth masks (for the anomalous images only). + + Example of a dataframe returned by calling this method from a concrete class: + |---|-------------------|-----------|-------------|------------------|-------| + | | image_path | label | label_index | mask_path | split | + |---|-------------------|-----------|-------------|------------------|-------| + | 0 | path/to/image.png | anomalous | 1 | path/to/mask.png | train | + |---|-------------------|-----------|-------------|------------------|-------| + """ + raise NotImplementedError + + def get_samples(self, split: Optional[str] = None) -> DataFrame: + """Retrieve the samples of the full dataset or one of the splits (train, val, test). + + Args: + split (str): The split for which we want to retrieve the samples ("train", "val" or "test"). When + left empty, all samples will be returned. + + Returns: + DataFrame: A dataframe containing the samples of the split or full dataset. + """ + assert self._samples is not None, "Samples have not been created yet." + if split is None: + return self._samples + samples = self._samples[self._samples.split == split] + return samples.reset_index(drop=True) + + def setup(self, stage: Optional[str] = None) -> None: + """Setup train, validation and test data. + + Args: + stage: Optional[str]: Train/Val/Test stages. 
(Default value = None) + """ + self._samples = self._create_samples() + + logger.info("Setting up train, validation, test and prediction datasets.") + if stage in (None, "fit"): + samples = self.get_samples("train") + self.train_data = AnomalibDataset( + samples=samples, + split="train", + task=self.task, + pre_process=self.pre_process_train, + ) + + if stage in (None, "fit", "validate"): + samples = self.get_samples("val") if self.create_validation_set else self.get_samples("test") + self.val_data = AnomalibDataset( + samples=samples, + split="val", + task=self.task, + pre_process=self.pre_process_val, + ) + + if stage in (None, "test"): + samples = self.get_samples("test") + self.test_data = AnomalibDataset( + samples=samples, + split="test", + task=self.task, + pre_process=self.pre_process_val, + ) + + def contains_anomalous_images(self, split: Optional[str] = None) -> bool: + """Check if the dataset or the specified subset contains any anomalous images. + + Args: + split (str): the subset of interest ("train", "val" or "test"). When left empty, the full dataset will be + checked. + + Returns: + bool: Boolean indicating if any anomalous images have been assigned to the dataset or subset. 
+ """ + samples = self.get_samples(split) + return 1 in list(samples.label_index) + + def train_dataloader(self) -> TRAIN_DATALOADERS: + """Get train dataloader.""" + return DataLoader(self.train_data, shuffle=True, batch_size=self.train_batch_size, num_workers=self.num_workers) + + def val_dataloader(self) -> EVAL_DATALOADERS: + """Get validation dataloader.""" + return DataLoader(self.val_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers) + + def test_dataloader(self) -> EVAL_DATALOADERS: + """Get test dataloader.""" + return DataLoader(self.test_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers) diff --git a/anomalib/data/btech.py b/anomalib/data/btech.py index 8b6bac792b..489841ab94 100644 --- a/anomalib/data/btech.py +++ b/anomalib/data/btech.py @@ -11,260 +11,36 @@ import logging import shutil -import warnings import zipfile from pathlib import Path -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union from urllib.request import urlretrieve import albumentations as A import cv2 -import numpy as np import pandas as pd from pandas.core.frame import DataFrame -from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.utilities.cli import DATAMODULE_REGISTRY -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS -from torch import Tensor -from torch.utils.data import DataLoader -from torch.utils.data.dataset import Dataset -from torchvision.datasets.folder import VisionDataset from tqdm import tqdm -from anomalib.data.inference import InferenceDataset -from anomalib.data.utils import DownloadProgressBar, hash_check, read_image +from anomalib.data.base import AnomalibDataModule +from anomalib.data.utils import DownloadProgressBar, hash_check from anomalib.data.utils.split import ( create_validation_set_from_test_set, split_normal_images_in_train_set, ) -from anomalib.pre_processing import PreProcessor 
logger = logging.getLogger(__name__) -def make_btech_dataset( - path: Path, - split: Optional[str] = None, - split_ratio: float = 0.1, - seed: Optional[int] = None, - create_validation_set: bool = False, -) -> DataFrame: - """Create BTech samples by parsing the BTech data file structure. - - The files are expected to follow the structure: - path/to/dataset/split/category/image_filename.png - path/to/dataset/ground_truth/category/mask_filename.png - - Args: - path (Path): Path to dataset - split (str, optional): Dataset split (ie., either train or test). Defaults to None. - split_ratio (float, optional): Ratio to split normal training images and add to the - test set in case test set doesn't contain any normal images. - Defaults to 0.1. - seed (int, optional): Random seed to ensure reproducibility when splitting. Defaults to 0. - create_validation_set (bool, optional): Boolean to create a validation set from the test set. - BTech dataset does not contain a validation set. Those wanting to create a validation set - could set this flag to ``True``. - - Example: - The following example shows how to get training samples from BTech 01 category: - - >>> root = Path('./BTech') - >>> category = '01' - >>> path = root / category - >>> path - PosixPath('BTech/01') - - >>> samples = make_btech_dataset(path, split='train', split_ratio=0.1, seed=0) - >>> samples.head() - path split label image_path mask_path label_index - 0 BTech/01 train 01 BTech/01/train/ok/105.bmp BTech/01/ground_truth/ok/105.png 0 - 1 BTech/01 train 01 BTech/01/train/ok/017.bmp BTech/01/ground_truth/ok/017.png 0 - ... 
- - Returns: - DataFrame: an output dataframe containing samples for the requested split (ie., train or test) - """ - samples_list = [ - (str(path),) + filename.parts[-3:] for filename in path.glob("**/*") if filename.suffix in (".bmp", ".png") - ] - if len(samples_list) == 0: - raise RuntimeError(f"Found 0 images in {path}") - - samples = pd.DataFrame(samples_list, columns=["path", "split", "label", "image_path"]) - samples = samples[samples.split != "ground_truth"] - - # Create mask_path column - samples["mask_path"] = ( - samples.path - + "/ground_truth/" - + samples.label - + "/" - + samples.image_path.str.rstrip("png").str.rstrip(".") - + ".png" - ) - - # Modify image_path column by converting to absolute path - samples["image_path"] = samples.path + "/" + samples.split + "/" + samples.label + "/" + samples.image_path - - # Split the normal images in training set if test set doesn't - # contain any normal images. This is needed because AUC score - # cannot be computed based on 1-class - if sum((samples.split == "test") & (samples.label == "ok")) == 0: - samples = split_normal_images_in_train_set(samples, split_ratio, seed) - - # Good images don't have mask - samples.loc[(samples.split == "test") & (samples.label == "ok"), "mask_path"] = "" - - # Create label index for normal (0) and anomalous (1) images. - samples.loc[(samples.label == "ok"), "label_index"] = 0 - samples.loc[(samples.label != "ok"), "label_index"] = 1 - samples.label_index = samples.label_index.astype(int) - - if create_validation_set: - samples = create_validation_set_from_test_set(samples, seed=seed) - - # Get the data frame for the split. 
- if split is not None and split in ["train", "val", "test"]: - samples = samples[samples.split == split] - samples = samples.reset_index(drop=True) - - return samples - - -class BTechDataset(VisionDataset): - """BTech PyTorch Dataset.""" - - def __init__( - self, - root: Union[Path, str], - category: str, - pre_process: PreProcessor, - split: str, - task: str = "segmentation", - seed: Optional[int] = None, - create_validation_set: bool = False, - ) -> None: - """Btech Dataset class. - - Args: - root: Path to the BTech dataset - category: Name of the BTech category. - pre_process: List of pre_processing object containing albumentation compose. - split: 'train', 'val' or 'test' - task: ``classification`` or ``segmentation`` - seed: seed used for the random subset splitting - create_validation_set: Create a validation subset in addition to the train and test subsets - - Examples: - >>> from anomalib.data.btech import BTechDataset - >>> from anomalib.data.transforms import PreProcessor - >>> pre_process = PreProcessor(image_size=256) - >>> dataset = BTechDataset( - ... root='./datasets/BTech', - ... category='leather', - ... pre_process=pre_process, - ... task="classification", - ... is_train=True, - ... ) - >>> dataset[0].keys() - dict_keys(['image']) - - >>> dataset.split = "test" - >>> dataset[0].keys() - dict_keys(['image', 'image_path', 'label']) - - >>> dataset.task = "segmentation" - >>> dataset.split = "train" - >>> dataset[0].keys() - dict_keys(['image']) - - >>> dataset.split = "test" - >>> dataset[0].keys() - dict_keys(['image_path', 'label', 'mask_path', 'image', 'mask']) - - >>> dataset[0]["image"].shape, dataset[0]["mask"].shape - (torch.Size([3, 256, 256]), torch.Size([256, 256])) - """ - super().__init__(root) - - if seed is None: - warnings.warn( - "seed is None." - " When seed is not set, images from the normal directory are split between training and test dir." - " This will lead to inconsistency between runs." 
- ) - - self.root = Path(root) if isinstance(root, str) else root - self.category: str = category - self.split = split - self.task = task - - self.pre_process = pre_process - - self.samples = make_btech_dataset( - path=self.root / category, - split=self.split, - seed=seed, - create_validation_set=create_validation_set, - ) - - def __len__(self) -> int: - """Get length of the dataset.""" - return len(self.samples) - - def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]: - """Get dataset item for the index ``index``. - - Args: - index (int): Index to get the item. - - Returns: - Union[Dict[str, Tensor], Dict[str, Union[str, Tensor]]]: Dict of image tensor during training. - Otherwise, Dict containing image path, target path, image tensor, label and transformed bounding box. - """ - item: Dict[str, Union[str, Tensor]] = {} - - image_path = self.samples.image_path[index] - image = read_image(image_path) - - pre_processed = self.pre_process(image=image) - item = {"image": pre_processed["image"]} - - if self.split in ["val", "test"]: - label_index = self.samples.label_index[index] - - item["image_path"] = image_path - item["label"] = label_index - - if self.task == "segmentation": - mask_path = self.samples.mask_path[index] - - # Only Anomalous (1) images has masks in BTech dataset. - # Therefore, create empty mask for Normal (0) images. - if label_index == 0: - mask = np.zeros(shape=image.shape[:2]) - else: - mask = cv2.imread(mask_path, flags=0) / 255.0 - - pre_processed = self.pre_process(image=image, mask=mask) - - item["mask_path"] = mask_path - item["image"] = pre_processed["image"] - item["mask"] = pre_processed["mask"] - - return item - - @DATAMODULE_REGISTRY -class BTech(LightningDataModule): +class BTech(AnomalibDataModule): """BTechDataModule Lightning Data Module.""" def __init__( self, root: str, category: str, - # TODO: Remove default values. 
IAAALD-211 image_size: Optional[Union[int, Tuple[int, int]]] = None, train_batch_size: int = 32, test_batch_size: int = 32, @@ -272,6 +48,7 @@ def __init__( task: str = "segmentation", transform_config_train: Optional[Union[str, A.Compose]] = None, transform_config_val: Optional[Union[str, A.Compose]] = None, + split_ratio: float = 0.2, seed: Optional[int] = None, create_validation_set: bool = False, ) -> None: @@ -316,34 +93,24 @@ def __init__( >>> data["image"].shape, data["mask"].shape (torch.Size([32, 3, 256, 256]), torch.Size([32, 256, 256])) """ - super().__init__() - self.root = root if isinstance(root, Path) else Path(root) self.category = category - self.dataset_path = self.root / self.category - self.transform_config_train = transform_config_train - self.transform_config_val = transform_config_val - self.image_size = image_size - - if self.transform_config_train is not None and self.transform_config_val is None: - self.transform_config_val = self.transform_config_train - - self.pre_process_train = PreProcessor(config=self.transform_config_train, image_size=self.image_size) - self.pre_process_val = PreProcessor(config=self.transform_config_val, image_size=self.image_size) - - self.train_batch_size = train_batch_size - self.test_batch_size = test_batch_size - self.num_workers = num_workers + self.path = self.root / self.category self.create_validation_set = create_validation_set - self.task = task self.seed = seed - - self.train_data: Dataset - self.test_data: Dataset - if create_validation_set: - self.val_data: Dataset - self.inference_data: Dataset + self.split_ratio = split_ratio + + super().__init__( + task=task, + train_batch_size=train_batch_size, + test_batch_size=test_batch_size, + num_workers=num_workers, + transform_config_train=transform_config_train, + transform_config_val=transform_config_val, + image_size=image_size, + create_validation_set=create_validation_set, + ) def prepare_data(self) -> None: """Download the dataset if not available.""" 
@@ -386,69 +153,62 @@ def prepare_data(self) -> None: logger.info("Cleaning the tar file") zip_filename.unlink() - def setup(self, stage: Optional[str] = None) -> None: - """Setup train, validation and test data. + def _create_samples(self) -> DataFrame: + """Create BTech samples by parsing the BTech data file structure. - BTech dataset uses BTech dataset structure, which is the reason for - using `anomalib.data.btech.BTech` class to get the dataset items. + The files are expected to follow the structure: + path/to/dataset/category/split/[ok|ko]/image_filename.bmp + path/to/dataset/category/ground_truth/ko/mask_filename.png - Args: - stage: Optional[str]: Train/Val/Test stages. (Default value = None) + This function creates a dataframe to store the parsed information based on the following format: + |---|---------------|-------|---------|---------------|---------------------------------------|-------------| + | | path | split | label | image_path | mask_path | label_index | + |---|---------------|-------|---------|---------------|---------------------------------------|-------------| + | 0 | datasets/name | test | ko | filename.png | ground_truth/ko/filename_mask.png | 1 | + |---|---------------|-------|---------|---------------|---------------------------------------|-------------| + Returns: + DataFrame: an output dataframe containing the samples of the dataset. 
""" - logger.info("Setting up train, validation, test and prediction datasets.") - if stage in (None, "fit"): - self.train_data = BTechDataset( - root=self.root, - category=self.category, - pre_process=self.pre_process_train, - split="train", - task=self.task, - seed=self.seed, - create_validation_set=self.create_validation_set, - ) - - if self.create_validation_set: - self.val_data = BTechDataset( - root=self.root, - category=self.category, - pre_process=self.pre_process_val, - split="val", - task=self.task, - seed=self.seed, - create_validation_set=self.create_validation_set, - ) - - self.test_data = BTechDataset( - root=self.root, - category=self.category, - pre_process=self.pre_process_val, - split="test", - task=self.task, - seed=self.seed, - create_validation_set=self.create_validation_set, + samples_list = [ + (str(self.path),) + filename.parts[-3:] + for filename in self.path.glob("**/*") + if filename.suffix in (".bmp", ".png") + ] + if len(samples_list) == 0: + raise RuntimeError(f"Found 0 images in {self.path}") + + samples = pd.DataFrame(samples_list, columns=["path", "split", "label", "image_path"]) + samples = samples[samples.split != "ground_truth"] + + # Create mask_path column + samples["mask_path"] = ( + samples.path + + "/ground_truth/" + + samples.label + + "/" + + samples.image_path.str.rstrip("bmp|png").str.rstrip(".") + + ".png" ) - if stage == "predict": - self.inference_data = InferenceDataset( - path=self.root, image_size=self.image_size, transform_config=self.transform_config_val - ) + # Modify image_path column by converting to absolute path + samples["image_path"] = samples.path + "/" + samples.split + "/" + samples.label + "/" + samples.image_path - def train_dataloader(self) -> TRAIN_DATALOADERS: - """Get train dataloader.""" - return DataLoader(self.train_data, shuffle=True, batch_size=self.train_batch_size, num_workers=self.num_workers) + # Split the normal images in training set if test set doesn't + # contain any normal images. 
This is needed because AUC score + # cannot be computed based on 1-class + if sum((samples.split == "test") & (samples.label == "ok")) == 0: + samples = split_normal_images_in_train_set(samples, self.split_ratio, self.seed) - def val_dataloader(self) -> EVAL_DATALOADERS: - """Get validation dataloader.""" - dataset = self.val_data if self.create_validation_set else self.test_data - return DataLoader(dataset=dataset, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers) + # Good images don't have mask + samples.loc[(samples.split == "test") & (samples.label == "ok"), "mask_path"] = "" - def test_dataloader(self) -> EVAL_DATALOADERS: - """Get test dataloader.""" - return DataLoader(self.test_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers) + # Create label index for normal (0) and anomalous (1) images. + samples.loc[(samples.label == "ok"), "label_index"] = 0 + samples.loc[(samples.label != "ok"), "label_index"] = 1 + samples.label_index = samples.label_index.astype(int) - def predict_dataloader(self) -> EVAL_DATALOADERS: - """Get predict dataloader.""" - return DataLoader( - self.inference_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers - ) + if self.create_validation_set: + samples = create_validation_set_from_test_set(samples, seed=self.seed) + + return samples diff --git a/anomalib/data/folder.py b/anomalib/data/folder.py index 0f3b47adbd..22f6257308 100644 --- a/anomalib/data/folder.py +++ b/anomalib/data/folder.py @@ -9,26 +9,18 @@ import logging import warnings from pathlib import Path -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union import albumentations as A -import cv2 -import numpy as np from pandas.core.frame import DataFrame -from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.utilities.cli import DATAMODULE_REGISTRY -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, 
TRAIN_DATALOADERS -from torch import Tensor -from torch.utils.data import DataLoader, Dataset from torchvision.datasets.folder import IMG_EXTENSIONS -from anomalib.data.inference import InferenceDataset -from anomalib.data.utils import read_image +from anomalib.data.base import AnomalibDataModule from anomalib.data.utils.split import ( create_validation_set_from_test_set, split_normal_images_in_train_set, ) -from anomalib.pre_processing import PreProcessor logger = logging.getLogger(__name__) @@ -77,222 +69,8 @@ def _prepare_files_labels( return filenames, labels -def make_dataset( - normal_dir: Union[str, Path], - abnormal_dir: Union[str, Path], - normal_test_dir: Optional[Union[str, Path]] = None, - mask_dir: Optional[Union[str, Path]] = None, - split: Optional[str] = None, - split_ratio: float = 0.2, - seed: Optional[int] = None, - create_validation_set: bool = True, - extensions: Optional[Tuple[str, ...]] = None, -): - """Make Folder Dataset. - - Args: - normal_dir (Union[str, Path]): Path to the directory containing normal images. - abnormal_dir (Union[str, Path]): Path to the directory containing abnormal images. - normal_test_dir (Optional[Union[str, Path]], optional): Path to the directory containing - normal images for the test dataset. Normal test images will be a split of `normal_dir` - if `None`. Defaults to None. - mask_dir (Optional[Union[str, Path]], optional): Path to the directory containing - the mask annotations. Defaults to None. - split (Optional[str], optional): Dataset split (ie., either train or test). Defaults to None. - split_ratio (float, optional): Ratio to split normal training images and add to the - test set in case test set doesn't contain any normal images. - Defaults to 0.2. - seed (int, optional): Random seed to ensure reproducibility when splitting. Defaults to 0. - create_validation_set (bool, optional):Boolean to create a validation set from the test set. 
- Those wanting to create a validation set could set this flag to ``True``. - extensions (Optional[Tuple[str, ...]], optional): Type of the image extensions to read from the - directory. - - Returns: - DataFrame: an output dataframe containing samples for the requested split (ie., train or test) - """ - - filenames = [] - labels = [] - dirs = {"normal": normal_dir, "abnormal": abnormal_dir} - - if normal_test_dir: - dirs = {**dirs, **{"normal_test": normal_test_dir}} - - for dir_type, path in dirs.items(): - filename, label = _prepare_files_labels(path, dir_type, extensions) - filenames += filename - labels += label - - samples = DataFrame({"image_path": filenames, "label": labels}) - - # Create label index for normal (0) and abnormal (1) images. - samples.loc[(samples.label == "normal") | (samples.label == "normal_test"), "label_index"] = 0 - samples.loc[(samples.label == "abnormal"), "label_index"] = 1 - samples.label_index = samples.label_index.astype(int) - - # If a path to mask is provided, add it to the sample dataframe. - if mask_dir is not None: - mask_dir = _check_and_convert_path(mask_dir) - samples["mask_path"] = "" - for index, row in samples.iterrows(): - if row.label_index == 1: - samples.loc[index, "mask_path"] = str(mask_dir / row.image_path.name) - - # Ensure the pathlib objects are converted to str. - # This is because torch dataloader doesn't like pathlib. - samples = samples.astype({"image_path": "str"}) - - # Create train/test split. - # By default, all the normal samples are assigned as train. - # and all the abnormal samples are test. - samples.loc[(samples.label == "normal"), "split"] = "train" - samples.loc[(samples.label == "abnormal") | (samples.label == "normal_test"), "split"] = "test" - - if not normal_test_dir: - samples = split_normal_images_in_train_set( - samples=samples, split_ratio=split_ratio, seed=seed, normal_label="normal" - ) - - # If `create_validation_set` is set to True, the test set is split into half. 
- if create_validation_set: - samples = create_validation_set_from_test_set(samples, seed=seed, normal_label="normal") - - # Get the data frame for the split. - if split is not None and split in ["train", "val", "test"]: - samples = samples[samples.split == split] - samples = samples.reset_index(drop=True) - - return samples - - -class FolderDataset(Dataset): - """Folder Dataset.""" - - def __init__( - self, - normal_dir: Union[Path, str], - abnormal_dir: Union[Path, str], - split: str, - pre_process: PreProcessor, - normal_test_dir: Optional[Union[Path, str]] = None, - split_ratio: float = 0.2, - mask_dir: Optional[Union[Path, str]] = None, - extensions: Optional[Tuple[str, ...]] = None, - task: Optional[str] = None, - seed: Optional[int] = None, - create_validation_set: bool = False, - ) -> None: - """Create Folder Folder Dataset. - - Args: - normal_dir (Union[str, Path]): Path to the directory containing normal images. - abnormal_dir (Union[str, Path]): Path to the directory containing abnormal images. - split (Optional[str], optional): Dataset split (ie., either train or test). Defaults to None. - pre_process (Optional[PreProcessor], optional): Image Pro-processor to apply transform. - Defaults to None. - normal_test_dir (Optional[Union[str, Path]], optional): Path to the directory containing - normal images for the test dataset. Defaults to None. - split_ratio (float, optional): Ratio to split normal training images and add to the - test set in case test set doesn't contain any normal images. - Defaults to 0.2. - mask_dir (Optional[Union[str, Path]], optional): Path to the directory containing - the mask annotations. Defaults to None. - extensions (Optional[Tuple[str, ...]], optional): Type of the image extensions to read from the - directory. - task (Optional[str], optional): Task type. (classification or segmentation) Defaults to None. - seed (int, optional): Random seed to ensure reproducibility when splitting. Defaults to 0. 
- create_validation_set (bool, optional):Boolean to create a validation set from the test set. - Those wanting to create a validation set could set this flag to ``True``. - - Raises: - ValueError: When task is set to classification and `mask_dir` is provided. When `mask_dir` is - provided, `task` should be set to `segmentation`. - - """ - self.split = split - - if task == "segmentation" and mask_dir is None: - warnings.warn( - "Segmentation task is requested, but mask directory is not provided. " - "Classification is to be chosen if mask directory is not provided." - ) - self.task = "classification" - - if task == "classification" and mask_dir: - warnings.warn( - "Classification task is requested, but mask directory is provided. " - "Segmentation task is to be chosen if mask directory is provided." - ) - self.task = "segmentation" - - if task is None or mask_dir is None: - self.task = "classification" - else: - self.task = task - - self.pre_process = pre_process - self.samples = make_dataset( - normal_dir=normal_dir, - abnormal_dir=abnormal_dir, - normal_test_dir=normal_test_dir, - mask_dir=mask_dir, - split=split, - split_ratio=split_ratio, - seed=seed, - create_validation_set=create_validation_set, - extensions=extensions, - ) - - def __len__(self) -> int: - """Get length of the dataset.""" - return len(self.samples) - - def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]: - """Get dataset item for the index ``index``. - - Args: - index (int): Index to get the item. - - Returns: - Union[Dict[str, Tensor], Dict[str, Union[str, Tensor]]]: Dict of image tensor during training. - Otherwise, Dict containing image path, target path, image tensor, label and transformed bounding box. 
- """ - item: Dict[str, Union[str, Tensor]] = {} - - image_path = self.samples.image_path[index] - image = read_image(image_path) - - pre_processed = self.pre_process(image=image) - item = {"image": pre_processed["image"]} - - if self.split in ["val", "test"]: - label_index = self.samples.label_index[index] - - item["image_path"] = image_path - item["label"] = label_index - - if self.task == "segmentation": - mask_path = self.samples.mask_path[index] - - # Only Anomalous (1) images has masks in MVTec AD dataset. - # Therefore, create empty mask for Normal (0) images. - if label_index == 0: - mask = np.zeros(shape=image.shape[:2]) - else: - mask = cv2.imread(mask_path, flags=0) / 255.0 - - pre_processed = self.pre_process(image=image, mask=mask) - - item["mask_path"] = mask_path - item["image"] = pre_processed["image"] - item["mask"] = pre_processed["mask"] - - return item - - @DATAMODULE_REGISTRY -class Folder(LightningDataModule): +class Folder(AnomalibDataModule): """Folder Lightning Data Module.""" def __init__( @@ -412,8 +190,6 @@ def __init__( torch.Size([12, 3, 256, 256]) torch.Size([12, 256, 256]) """ - super().__init__() - if seed is None and normal_test_dir is None: raise ValueError( "Both seed and normal_test_dir cannot be None." @@ -421,118 +197,103 @@ def __init__( " This will lead to inconsistency between runs." ) + if task == "segmentation" and mask_dir is None: + warnings.warn( + "Segmentation task is requested, but mask directory is not provided. " + "Classification is to be chosen if mask directory is not provided." 
+ ) + self.task = "classification" + else: + self.task = task + self.root = _check_and_convert_path(root) self.normal_dir = self.root / normal_dir - self.abnormal_dir = self.root / abnormal_dir - self.normal_test = normal_test_dir + self.abnormal_dir = self.root / abnormal_dir if abnormal_dir is not None else None + self.normal_test_dir = normal_test_dir if normal_test_dir: - self.normal_test = self.root / normal_test_dir + self.normal_test_dir = self.root / normal_test_dir self.mask_dir = mask_dir self.extensions = extensions self.split_ratio = split_ratio - if task == "classification" and mask_dir is not None: - raise ValueError( - "Classification type is set but mask_dir provided. " - "If mask_dir is provided task type must be segmentation. " - "Check your configuration." - ) - self.task = task - self.transform_config_train = transform_config_train - self.transform_config_val = transform_config_val - self.image_size = image_size - - if self.transform_config_train is not None and self.transform_config_val is None: - self.transform_config_val = self.transform_config_train - - self.pre_process_train = PreProcessor(config=self.transform_config_train, image_size=self.image_size) - self.pre_process_val = PreProcessor(config=self.transform_config_val, image_size=self.image_size) - - self.train_batch_size = train_batch_size - self.test_batch_size = test_batch_size - self.num_workers = num_workers - self.create_validation_set = create_validation_set self.seed = seed - self.train_data: Dataset - self.test_data: Dataset - if create_validation_set: - self.val_data: Dataset - self.inference_data: Dataset + super().__init__( + task=task, + train_batch_size=train_batch_size, + test_batch_size=test_batch_size, + num_workers=num_workers, + transform_config_train=transform_config_train, + transform_config_val=transform_config_val, + image_size=image_size, + create_validation_set=create_validation_set, + ) - def setup(self, stage: Optional[str] = None) -> None: - """Setup train, 
validation and test data. + def _create_samples(self): + """Create the dataframe with samples for the Folder dataset. - Args: - stage: Optional[str]: Train/Val/Test stages. (Default value = None) + The files are expected to follow the structure: + path/to/dataset/normal_folder_name/normal_image_name.png + path/to/dataset/abnormal_folder_name/abnormal_image_name.png - """ - logger.info("Setting up train, validation, test and prediction datasets.") - if stage in (None, "fit"): - self.train_data = FolderDataset( - normal_dir=self.normal_dir, - abnormal_dir=self.abnormal_dir, - normal_test_dir=self.normal_test, - split="train", - split_ratio=self.split_ratio, - mask_dir=self.mask_dir, - pre_process=self.pre_process_train, - extensions=self.extensions, - task=self.task, - seed=self.seed, - create_validation_set=self.create_validation_set, - ) - if self.create_validation_set: - self.val_data = FolderDataset( - normal_dir=self.normal_dir, - abnormal_dir=self.abnormal_dir, - normal_test_dir=self.normal_test, - split="val", - split_ratio=self.split_ratio, - mask_dir=self.mask_dir, - pre_process=self.pre_process_val, - extensions=self.extensions, - task=self.task, - seed=self.seed, - create_validation_set=self.create_validation_set, - ) + This function creates a dataframe to store the parsed information based on the following format: + |---|-------------------|--------|-------------|------------------|-------| + | | image_path | label | label_index | mask_path | split | + |---|-------------------|--------|-------------|------------------|-------| + | 0 | path/to/image.png | normal | 0 | path/to/mask.png | train | + |---|-------------------|--------|-------------|------------------|-------| - self.test_data = FolderDataset( - normal_dir=self.normal_dir, - abnormal_dir=self.abnormal_dir, - split="test", - normal_test_dir=self.normal_test, - split_ratio=self.split_ratio, - mask_dir=self.mask_dir, - pre_process=self.pre_process_val, - extensions=self.extensions, - 
task=self.task, - seed=self.seed, - create_validation_set=self.create_validation_set, - ) + Returns: + DataFrame: an output dataframe containing the samples of the dataset. + """ - if stage == "predict": - self.inference_data = InferenceDataset( - path=self.root, image_size=self.image_size, transform_config=self.transform_config_val + filenames = [] + labels = [] + dirs = {"normal": self.normal_dir, "abnormal": self.abnormal_dir} + + if self.normal_test_dir: + dirs = {**dirs, **{"normal_test": self.normal_test_dir}} + + for dir_type, path in dirs.items(): + if path is not None: + filename, label = _prepare_files_labels(path, dir_type, self.extensions) + filenames += filename + labels += label + + samples = DataFrame({"image_path": filenames, "label": labels}) + + # Create label index for normal (0) and abnormal (1) images. + samples.loc[(samples.label == "normal") | (samples.label == "normal_test"), "label_index"] = 0 + samples.loc[(samples.label == "abnormal"), "label_index"] = 1 + samples.label_index = samples.label_index.astype(int) + + # If a path to mask is provided, add it to the sample dataframe. + if self.mask_dir is not None: + self.mask_dir = _check_and_convert_path(self.mask_dir) + samples["mask_path"] = "" + for index, row in samples.iterrows(): + if row.label_index == 1: + samples.loc[index, "mask_path"] = str(self.mask_dir / row.image_path.name) + + # Ensure the pathlib objects are converted to str. + # This is because torch dataloader doesn't like pathlib. + samples = samples.astype({"image_path": "str"}) + + # Create train/test split. + # By default, all the normal samples are assigned as train. + # and all the abnormal samples are test. 
+ samples.loc[(samples.label == "normal"), "split"] = "train" + samples.loc[(samples.label == "abnormal") | (samples.label == "normal_test"), "split"] = "test" + + if not self.normal_test_dir: + samples = split_normal_images_in_train_set( + samples=samples, split_ratio=self.split_ratio, seed=self.seed, normal_label="normal" ) - def train_dataloader(self) -> TRAIN_DATALOADERS: - """Get train dataloader.""" - return DataLoader(self.train_data, shuffle=True, batch_size=self.train_batch_size, num_workers=self.num_workers) - - def val_dataloader(self) -> EVAL_DATALOADERS: - """Get validation dataloader.""" - dataset = self.val_data if self.create_validation_set else self.test_data - return DataLoader(dataset=dataset, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers) - - def test_dataloader(self) -> EVAL_DATALOADERS: - """Get test dataloader.""" - return DataLoader(self.test_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers) + # If `create_validation_set` is set to True, the test set is split into half. 
+ if self.create_validation_set: + samples = create_validation_set_from_test_set(samples, seed=self.seed, normal_label="normal") - def predict_dataloader(self) -> EVAL_DATALOADERS: - """Get predict dataloader.""" - return DataLoader( - self.inference_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers - ) + return samples diff --git a/anomalib/data/mvtec.py b/anomalib/data/mvtec.py index 9b45699d64..1772baf4f1 100644 --- a/anomalib/data/mvtec.py +++ b/anomalib/data/mvtec.py @@ -31,263 +31,32 @@ import tarfile import warnings from pathlib import Path -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple, Union from urllib.request import urlretrieve import albumentations as A -import cv2 -import numpy as np import pandas as pd from pandas.core.frame import DataFrame -from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.utilities.cli import DATAMODULE_REGISTRY -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS -from torch import Tensor -from torch.utils.data import DataLoader -from torch.utils.data.dataset import Dataset -from torchvision.datasets.folder import VisionDataset - -from anomalib.data.inference import InferenceDataset -from anomalib.data.utils import DownloadProgressBar, hash_check, read_image + +from anomalib.data.base import AnomalibDataModule +from anomalib.data.utils import DownloadProgressBar, hash_check from anomalib.data.utils.split import ( create_validation_set_from_test_set, split_normal_images_in_train_set, ) -from anomalib.pre_processing import PreProcessor logger = logging.getLogger(__name__) -def make_mvtec_dataset( - path: Path, - split: Optional[str] = None, - split_ratio: float = 0.1, - seed: Optional[int] = None, - create_validation_set: bool = False, -) -> DataFrame: - """Create MVTec AD samples by parsing the MVTec AD data file structure. 
- - The files are expected to follow the structure: - path/to/dataset/split/category/image_filename.png - path/to/dataset/ground_truth/category/mask_filename.png - - This function creates a dataframe to store the parsed information based on the following format: - |---|---------------|-------|---------|---------------|---------------------------------------|-------------| - | | path | split | label | image_path | mask_path | label_index | - |---|---------------|-------|---------|---------------|---------------------------------------|-------------| - | 0 | datasets/name | test | defect | filename.png | ground_truth/defect/filename_mask.png | 1 | - |---|---------------|-------|---------|---------------|---------------------------------------|-------------| - - Args: - path (Path): Path to dataset - split (str, optional): Dataset split (ie., either train or test). Defaults to None. - split_ratio (float, optional): Ratio to split normal training images and add to the - test set in case test set doesn't contain any normal images. - Defaults to 0.1. - seed (int, optional): Random seed to ensure reproducibility when splitting. Defaults to 0. - create_validation_set (bool, optional): Boolean to create a validation set from the test set. - MVTec AD dataset does not contain a validation set. Those wanting to create a validation set - could set this flag to ``True``. 
- - Example: - The following example shows how to get training samples from MVTec AD bottle category: - - >>> root = Path('./MVTec') - >>> category = 'bottle' - >>> path = root / category - >>> path - PosixPath('MVTec/bottle') - - >>> samples = make_mvtec_dataset(path, split='train', split_ratio=0.1, seed=0) - >>> samples.head() - path split label image_path mask_path label_index - 0 MVTec/bottle train good MVTec/bottle/train/good/105.png MVTec/bottle/ground_truth/good/105_mask.png 0 - 1 MVTec/bottle train good MVTec/bottle/train/good/017.png MVTec/bottle/ground_truth/good/017_mask.png 0 - 2 MVTec/bottle train good MVTec/bottle/train/good/137.png MVTec/bottle/ground_truth/good/137_mask.png 0 - 3 MVTec/bottle train good MVTec/bottle/train/good/152.png MVTec/bottle/ground_truth/good/152_mask.png 0 - 4 MVTec/bottle train good MVTec/bottle/train/good/109.png MVTec/bottle/ground_truth/good/109_mask.png 0 - - Returns: - DataFrame: an output dataframe containing samples for the requested split (ie., train or test) - """ - samples_list = [(str(path),) + filename.parts[-3:] for filename in path.glob("**/*.png")] - if len(samples_list) == 0: - raise RuntimeError(f"Found 0 images in {path}") - - samples = pd.DataFrame(samples_list, columns=["path", "split", "label", "image_path"]) - samples = samples[samples.split != "ground_truth"] - - # Create mask_path column - samples["mask_path"] = ( - samples.path - + "/ground_truth/" - + samples.label - + "/" - + samples.image_path.str.rstrip("png").str.rstrip(".") - + "_mask.png" - ) - - # Modify image_path column by converting to absolute path - samples["image_path"] = samples.path + "/" + samples.split + "/" + samples.label + "/" + samples.image_path - - # Split the normal images in training set if test set doesn't - # contain any normal images. 
This is needed because AUC score - # cannot be computed based on 1-class - if sum((samples.split == "test") & (samples.label == "good")) == 0: - samples = split_normal_images_in_train_set(samples, split_ratio, seed) - - # Good images don't have mask - samples.loc[(samples.split == "test") & (samples.label == "good"), "mask_path"] = "" - - # Create label index for normal (0) and anomalous (1) images. - samples.loc[(samples.label == "good"), "label_index"] = 0 - samples.loc[(samples.label != "good"), "label_index"] = 1 - samples.label_index = samples.label_index.astype(int) - - if create_validation_set: - samples = create_validation_set_from_test_set(samples, seed=seed) - - # Get the data frame for the split. - if split is not None and split in ["train", "val", "test"]: - samples = samples[samples.split == split] - samples = samples.reset_index(drop=True) - - return samples - - -class MVTecDataset(VisionDataset): - """MVTec AD PyTorch Dataset.""" - - def __init__( - self, - root: Union[Path, str], - category: str, - pre_process: PreProcessor, - split: str, - task: str = "segmentation", - seed: Optional[int] = None, - create_validation_set: bool = False, - ) -> None: - """Mvtec AD Dataset class. - - Args: - root: Path to the MVTec AD dataset - category: Name of the MVTec AD category. - pre_process: List of pre_processing object containing albumentation compose. - split: 'train', 'val' or 'test' - task: ``classification`` or ``segmentation`` - seed: seed used for the random subset splitting - create_validation_set: Create a validation subset in addition to the train and test subsets - - Examples: - >>> from anomalib.data.mvtec import MVTecDataset - >>> from anomalib.data.transforms import PreProcessor - >>> pre_process = PreProcessor(image_size=256) - >>> dataset = MVTecDataset( - ... root='./datasets/MVTec', - ... category='leather', - ... pre_process=pre_process, - ... task="classification", - ... is_train=True, - ... 
) - >>> dataset[0].keys() - dict_keys(['image']) - - >>> dataset.split = "test" - >>> dataset[0].keys() - dict_keys(['image', 'image_path', 'label']) - - >>> dataset.task = "segmentation" - >>> dataset.split = "train" - >>> dataset[0].keys() - dict_keys(['image']) - - >>> dataset.split = "test" - >>> dataset[0].keys() - dict_keys(['image_path', 'label', 'mask_path', 'image', 'mask']) - - >>> dataset[0]["image"].shape, dataset[0]["mask"].shape - (torch.Size([3, 256, 256]), torch.Size([256, 256])) - """ - super().__init__(root) - - if seed is None: - warnings.warn( - "seed is None." - " When seed is not set, images from the normal directory are split between training and test dir." - " This will lead to inconsistency between runs." - ) - - self.root = Path(root) if isinstance(root, str) else root - self.category: str = category - self.split = split - self.task = task - - self.pre_process = pre_process - - self.samples = make_mvtec_dataset( - path=self.root / category, - split=self.split, - seed=seed, - create_validation_set=create_validation_set, - ) - - def __len__(self) -> int: - """Get length of the dataset.""" - return len(self.samples) - - def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]: - """Get dataset item for the index ``index``. - - Args: - index (int): Index to get the item. - - Returns: - Union[Dict[str, Tensor], Dict[str, Union[str, Tensor]]]: Dict of image tensor during training. - Otherwise, Dict containing image path, target path, image tensor, label and transformed bounding box. 
- """ - item: Dict[str, Union[str, Tensor]] = {} - - image_path = self.samples.image_path[index] - image = read_image(image_path) - - pre_processed = self.pre_process(image=image) - item = {"image": pre_processed["image"]} - - if self.split in ["val", "test"]: - label_index = self.samples.label_index[index] - - item["image_path"] = image_path - item["label"] = label_index - - if self.task == "segmentation": - mask_path = self.samples.mask_path[index] - - # Only Anomalous (1) images has masks in MVTec AD dataset. - # Therefore, create empty mask for Normal (0) images. - if label_index == 0: - mask = np.zeros(shape=image.shape[:2]) - else: - mask = cv2.imread(mask_path, flags=0) / 255.0 - - pre_processed = self.pre_process(image=image, mask=mask) - - item["mask_path"] = mask_path - item["image"] = pre_processed["image"] - item["mask"] = pre_processed["mask"] - - return item - - @DATAMODULE_REGISTRY -class MVTec(LightningDataModule): +class MVTec(AnomalibDataModule): """MVTec AD Lightning Data Module.""" def __init__( self, root: str, category: str, - # TODO: Remove default values. 
IAAALD-211 image_size: Optional[Union[int, Tuple[int, int]]] = None, train_batch_size: int = 32, test_batch_size: int = 32, @@ -295,6 +64,7 @@ def __init__( task: str = "segmentation", transform_config_train: Optional[Union[str, A.Compose]] = None, transform_config_val: Optional[Union[str, A.Compose]] = None, + split_ratio: float = 0.2, seed: Optional[int] = None, create_validation_set: bool = False, ) -> None: @@ -339,34 +109,24 @@ def __init__( >>> data["image"].shape, data["mask"].shape (torch.Size([32, 3, 256, 256]), torch.Size([32, 256, 256])) """ - super().__init__() - self.root = root if isinstance(root, Path) else Path(root) self.category = category - self.dataset_path = self.root / self.category - self.transform_config_train = transform_config_train - self.transform_config_val = transform_config_val - self.image_size = image_size - - if self.transform_config_train is not None and self.transform_config_val is None: - self.transform_config_val = self.transform_config_train - - self.pre_process_train = PreProcessor(config=self.transform_config_train, image_size=self.image_size) - self.pre_process_val = PreProcessor(config=self.transform_config_val, image_size=self.image_size) - - self.train_batch_size = train_batch_size - self.test_batch_size = test_batch_size - self.num_workers = num_workers + self.path = self.root / self.category self.create_validation_set = create_validation_set - self.task = task self.seed = seed - - self.train_data: Dataset - self.test_data: Dataset - if create_validation_set: - self.val_data: Dataset - self.inference_data: Dataset + self.split_ratio = split_ratio + + super().__init__( + task=task, + train_batch_size=train_batch_size, + test_batch_size=test_batch_size, + num_workers=num_workers, + transform_config_train=transform_config_train, + transform_config_val=transform_config_val, + image_size=image_size, + create_validation_set=create_validation_set, + ) def prepare_data(self) -> None: """Download the dataset if not available.""" 
@@ -393,68 +153,67 @@ def prepare_data(self) -> None: tar_file.extractall(self.root) logger.info("Cleaning the tar file") - (zip_filename).unlink() + zip_filename.unlink() - def setup(self, stage: Optional[str] = None) -> None: - """Setup train, validation and test data. + def _create_samples(self) -> DataFrame: + """Create MVTec AD samples by parsing the MVTec AD data file structure. - Args: - stage: Optional[str]: Train/Val/Test stages. (Default value = None) + The files are expected to follow the structure: + path/to/dataset/split/category/image_filename.png + path/to/dataset/ground_truth/category/mask_filename.png - """ - logger.info("Setting up train, validation, test and prediction datasets.") - if stage in (None, "fit"): - self.train_data = MVTecDataset( - root=self.root, - category=self.category, - pre_process=self.pre_process_train, - split="train", - task=self.task, - seed=self.seed, - create_validation_set=self.create_validation_set, - ) + This function creates a dataframe to store the parsed information based on the following format: + |---|---------------|-------|---------|---------------|---------------------------------------|-------------| + | | path | split | label | image_path | mask_path | label_index | + |---|---------------|-------|---------|---------------|---------------------------------------|-------------| + | 0 | datasets/name | test | defect | filename.png | ground_truth/defect/filename_mask.png | 1 | + |---|---------------|-------|---------|---------------|---------------------------------------|-------------| - if self.create_validation_set: - self.val_data = MVTecDataset( - root=self.root, - category=self.category, - pre_process=self.pre_process_val, - split="val", - task=self.task, - seed=self.seed, - create_validation_set=self.create_validation_set, + Returns: + DataFrame: an output dataframe containing the samples of the dataset. + """ + if self.seed is None: + warnings.warn( + "seed is None." 
+ " When seed is not set, images from the normal directory are split between training and test dir." + " This will lead to inconsistency between runs." ) - self.test_data = MVTecDataset( - root=self.root, - category=self.category, - pre_process=self.pre_process_val, - split="test", - task=self.task, - seed=self.seed, - create_validation_set=self.create_validation_set, + samples_list = [(str(self.path),) + filename.parts[-3:] for filename in self.path.glob("**/*.png")] + if len(samples_list) == 0: + raise RuntimeError(f"Found 0 images in {self.path}") + + samples = pd.DataFrame(samples_list, columns=["path", "split", "label", "image_path"]) + samples = samples[samples.split != "ground_truth"] + + # Create mask_path column + samples["mask_path"] = ( + samples.path + + "/ground_truth/" + + samples.label + + "/" + + samples.image_path.str.rstrip("png").str.rstrip(".") + + "_mask.png" ) - if stage == "predict": - self.inference_data = InferenceDataset( - path=self.root, image_size=self.image_size, transform_config=self.transform_config_val - ) + # Modify image_path column by converting to absolute path + samples["image_path"] = samples.path + "/" + samples.split + "/" + samples.label + "/" + samples.image_path - def train_dataloader(self) -> TRAIN_DATALOADERS: - """Get train dataloader.""" - return DataLoader(self.train_data, shuffle=True, batch_size=self.train_batch_size, num_workers=self.num_workers) + # Split the normal images in training set if test set doesn't + # contain any normal images. 
This is needed because AUC score + # cannot be computed based on 1-class + if sum((samples.split == "test") & (samples.label == "good")) == 0: + samples = split_normal_images_in_train_set(samples, self.split_ratio, self.seed) - def val_dataloader(self) -> EVAL_DATALOADERS: - """Get validation dataloader.""" - dataset = self.val_data if self.create_validation_set else self.test_data - return DataLoader(dataset=dataset, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers) + # Good images don't have mask + samples.loc[(samples.split == "test") & (samples.label == "good"), "mask_path"] = "" - def test_dataloader(self) -> EVAL_DATALOADERS: - """Get test dataloader.""" - return DataLoader(self.test_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers) + # Create label index for normal (0) and anomalous (1) images. + samples.loc[(samples.label == "good"), "label_index"] = 0 + samples.loc[(samples.label != "good"), "label_index"] = 1 + samples.label_index = samples.label_index.astype(int) - def predict_dataloader(self) -> EVAL_DATALOADERS: - """Get predict dataloader.""" - return DataLoader( - self.inference_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers - ) + if self.create_validation_set: + samples = create_validation_set_from_test_set(samples, seed=self.seed) + + return samples diff --git a/anomalib/utils/metrics/adaptive_threshold.py b/anomalib/utils/metrics/adaptive_threshold.py index fd112433f1..868c6e2ad6 100644 --- a/anomalib/utils/metrics/adaptive_threshold.py +++ b/anomalib/utils/metrics/adaptive_threshold.py @@ -3,6 +3,8 @@ # Copyright (C) 2022 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import warnings + import torch from torchmetrics import PrecisionRecallCurve @@ -33,6 +35,14 @@ def compute(self) -> torch.Tensor: recall: torch.Tensor thresholds: torch.Tensor + if not any(1 in batch for batch in self.target): + warnings.warn( + "The validation set does not 
contain any anomalous images. As a result, the adaptive threshold will " + "take the value of the highest anomaly score observed in the normal validation images, which may lead " + "to poor predictions. For a more reliable adaptive threshold computation, please add some anomalous " + "images to the validation set." + ) + precision, recall, thresholds = super().compute() f1_score = (2 * precision * recall) / (precision + recall + 1e-10) if thresholds.dim() == 0: diff --git a/tools/train.py b/tools/train.py index 0e5daa3b10..33952a7e20 100644 --- a/tools/train.py +++ b/tools/train.py @@ -63,8 +63,11 @@ def train(): load_model_callback = LoadModelCallback(weights_path=trainer.checkpoint_callback.best_model_path) trainer.callbacks.insert(0, load_model_callback) - logger.info("Testing the model.") - trainer.test(model=model, datamodule=datamodule) + if datamodule.contains_anomalous_images("test"): + logger.info("Testing the model.") + trainer.test(model=model, datamodule=datamodule) + else: + logger.info("No anomalous images found in dataset. Skipping test stage.") if __name__ == "__main__":