Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor data modules #558

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions anomalib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from typing import Union

from omegaconf import DictConfig, ListConfig
from pytorch_lightning import LightningDataModule

from anomalib.data.base import AnomalibDataModule

from .btech import BTech
from .folder import Folder
Expand All @@ -17,7 +18,7 @@
logger = logging.getLogger(__name__)


def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule:
def get_datamodule(config: Union[DictConfig, ListConfig]) -> AnomalibDataModule:
"""Get Anomaly Datamodule.

Args:
Expand All @@ -28,7 +29,7 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule
"""
logger.info("Loading the datamodule")

datamodule: LightningDataModule
datamodule: AnomalibDataModule

if config.dataset.format.lower() == "mvtec":
datamodule = MVTec(
Expand Down
166 changes: 166 additions & 0 deletions anomalib/data/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""Anomalib dataset and datamodule base classes."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from abc import ABC, abstractmethod
from typing import Dict, Optional, Tuple, Union

import albumentations as A
import cv2
import numpy as np
from pandas import DataFrame
from pytorch_lightning import LightningDataModule
from torch import Tensor
from torch.utils.data import Dataset

from anomalib.data.utils import read_image
from anomalib.pre_processing import PreProcessor

logger = logging.getLogger(__name__)


class AnomalibDataset(Dataset):
"""Anomalib dataset."""

def __init__(self, samples: DataFrame, task: str, split: str, pre_process: PreProcessor):
super().__init__()
self.samples = samples
self.task = task
self.split = split
self.pre_process = pre_process

def contains_anomalous_images(self):
"""Check if the dataset contains any anomalous images."""
return "anomalous" in list(self.samples.label)

def __len__(self) -> int:
"""Get length of the dataset."""
return len(self.samples)

def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
"""Get dataset item for the index ``index``.

Args:
index (int): Index to get the item.

Returns:
Union[Dict[str, Tensor], Dict[str, Union[str, Tensor]]]: Dict of image tensor during training.
Otherwise, Dict containing image path, target path, image tensor, label and transformed bounding box.
"""
image_path = self.samples.image_path[index]
djdameln marked this conversation as resolved.
Show resolved Hide resolved
image = read_image(image_path)

pre_processed = self.pre_process(image=image)
item = {"image": pre_processed["image"]}

if self.split in ["val", "test"]:
label_index = self.samples.label_index[index]

item["image_path"] = image_path
item["label"] = label_index

if self.task == "segmentation":
mask_path = self.samples.mask_path[index]

# Only Anomalous (1) images have masks in anomaly datasets
# Therefore, create empty mask for Normal (0) images.
if label_index == 0:
mask = np.zeros(shape=image.shape[:2])
else:
mask = cv2.imread(mask_path, flags=0) / 255.0

pre_processed = self.pre_process(image=image, mask=mask)
djdameln marked this conversation as resolved.
Show resolved Hide resolved

item["mask_path"] = mask_path
item["image"] = pre_processed["image"]
item["mask"] = pre_processed["mask"]

return item


class AnomalibDataModule(LightningDataModule, ABC):
"""Base Anomalib data module."""

def __init__(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[minor] val and test are assumed to be each other in a confusing way

(1) X_batch_size with X \in {train, test}, and test_batch_size is used both for test and val DataLoader

(2) transform_config_Y with Y in {train, val}, and transform_config_val is used both for test and val AnomalibDataset

i.e.

batch size: config is available for test and assumed for test and val
transform: config is available for val and assumed for test and val

self,
task: str,
transform_config_train: Optional[Union[str, A.Compose]] = None,
transform_config_val: Optional[Union[str, A.Compose]] = None,
image_size: Optional[Union[int, Tuple[int, int]]] = None,
create_validation_set: bool = False,
):
super().__init__()
self.task = task
self.create_validation_set = create_validation_set

if transform_config_train is not None and transform_config_val is None:
Copy link
Contributor

@jpcbertoldo jpcbertoldo Sep 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[minor] Is there a reason for assuming this?

It could make sense that transform_config_train could have "weak" data augmentations (e.g. tiny brightness changes) but that should not be repeated in the validation set.


if this is to be kept, a warning wouldn't be harmful :)

transform_config_val = transform_config_train
self.pre_process_train = PreProcessor(config=transform_config_train, image_size=image_size)
self.pre_process_val = PreProcessor(config=transform_config_val, image_size=image_size)

self.train_data: Optional[AnomalibDataset] = None
self.val_data: Optional[AnomalibDataset] = None
self.test_data: Optional[AnomalibDataset] = None

@abstractmethod
def _create_samples(self) -> DataFrame:
"""This method should be implemented in the subclass.

This method should return a dataframe that contains the information needed by the dataloader to load each of
the dataset items into memory. The dataframe must at least contain the following columns:
split - The subset to which the dataset item is assigned.
image_path - Path to file system location where the image is stored.
label_index - Index of the anomaly label, typically 0 for "normal" and 1 for "anomalous".

Additionally, when the task type is segmentation, the dataframe must have the mask_path column, which contains
the path the ground truth masks (for the anomalous images only).

Example of a dataframe returned by calling this method from a concrete class:
|---|-------------------|-----------|-------------|------------------|-------|
| | image_path | label | label_index | mask_path | split |
|---|-------------------|-----------|-------------|------------------|-------|
| 0 | path/to/image.png | anomalous | 0 | path/to/mask.png | train |
|---|-------------------|-----------|-------------|------------------|-------|
"""
ashwinvaidya17 marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError

def setup(self, stage: Optional[str] = None) -> None:
"""Setup train, validation and test data.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
If `stage` is `None`, all splits are setup.

Args:
stage: Optional[str]: Train/Val/Test stages. (Default value = None)

"""
samples = self._create_samples()

logger.info("Setting up train, validation, test and prediction datasets.")
if stage in (None, "fit"):
train_samples = samples[samples.split == "train"]
train_samples = train_samples.reset_index(drop=True)
self.train_data = AnomalibDataset(
samples=train_samples,
split="train",
task=self.task,
pre_process=self.pre_process_train,
)

if self.create_validation_set:
val_samples = samples[samples.split == "val"]
val_samples = val_samples.reset_index(drop=True)
self.val_data = AnomalibDataset(
samples=val_samples,
split="val",
task=self.task,
pre_process=self.pre_process_val,
)

test_samples = samples[samples.split == "test"]
djdameln marked this conversation as resolved.
Show resolved Hide resolved
test_samples = test_samples.reset_index(drop=True)
self.test_data = AnomalibDataset(
samples=test_samples,
split="test",
task=self.task,
pre_process=self.pre_process_val,
)
Loading