Skip to content

Commit

Permalink
Add support for Detection task type (#732)
Browse files Browse the repository at this point in the history
* add basic support for detection task

* use enum for task type

* formatting

* small bugfix

* add unit tests for bounding box conversion

* update error message

* use as_tensor

* typing and docstring

* explicit keyword arguments

* simplify bbox handling in video dataset

* docstring consistency

* add missing licenses

* add whitespace for readability

* add missing license

* Update anomalib/data/utils/boxes.py

Co-authored-by: Samet Akcay <[email protected]>

* Revert "Update anomalib/data/utils/boxes.py"

This reverts commit cec6138.

* add test case for custom collate function

* docstring

* add integration tests for detection dataloading

* extend and clean up datamodules tests

* add detection task type to visualizer tests

* only show pred_boxes during inference

* add detection support for torch inference

* add detection support for openvino inference

* test inference for all task types

* pylint

Co-authored-by: Samet Akcay <[email protected]>
  • Loading branch information
djdameln and samet-akcay authored Dec 6, 2022
1 parent cb06714 commit 045d77f
Show file tree
Hide file tree
Showing 35 changed files with 948 additions and 264 deletions.
2 changes: 2 additions & 0 deletions anomalib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .folder import Folder
from .inference import InferenceDataset
from .mvtec import MVTec
from .task_type import TaskType
from .ucsd_ped import UCSDped

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -131,4 +132,5 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> AnomalibDataModule:
"MVTec",
"Avenue",
"UCSDped",
"TaskType",
]
9 changes: 5 additions & 4 deletions anomalib/data/avenue.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from torch import Tensor

from anomalib.data.base import AnomalibDataModule, VideoAnomalibDataset
from anomalib.data.task_type import TaskType
from anomalib.data.utils import DownloadProgressBar, Split, ValSplitMode, hash_check
from anomalib.data.utils.video import ClipsIndexer
from anomalib.pre_processing import PreProcessor
Expand Down Expand Up @@ -124,7 +125,7 @@ class AvenueDataset(VideoAnomalibDataset):
"""Avenue Dataset class.
Args:
task (str): Task type, either 'classification' or 'segmentation'
task (TaskType): Task type, 'classification', 'detection' or 'segmentation'
root (str): Path to the root of the dataset
gt_dir (str): Path to the ground truth files
pre_process (PreProcessor): Pre-processor object
Expand All @@ -135,7 +136,7 @@ class AvenueDataset(VideoAnomalibDataset):

def __init__(
self,
task: str,
task: TaskType,
root: Union[Path, str],
gt_dir: str,
pre_process: PreProcessor,
Expand Down Expand Up @@ -163,7 +164,7 @@ class Avenue(AnomalibDataModule):
gt_dir (str): Path to the ground truth files
clip_length_in_frames (int, optional): Number of video frames in each clip.
frames_between_clips (int, optional): Number of frames between each consecutive video clip.
task (str): Task type, either 'classification' or 'segmentation'
task (TaskType): Task type, 'classification', 'detection' or 'segmentation'
image_size (Optional[Union[int, Tuple[int, int]]], optional): Size of the input image.
Defaults to None.
train_batch_size (int, optional): Training batch size. Defaults to 32.
Expand All @@ -184,7 +185,7 @@ def __init__(
gt_dir: str,
clip_length_in_frames: int = 1,
frames_between_clips: int = 1,
task: str = "segmentation",
task: TaskType = TaskType.SEGMENTATION,
image_size: Optional[Union[int, Tuple[int, int]]] = None,
train_batch_size: int = 32,
eval_batch_size: int = 32,
Expand Down
47 changes: 42 additions & 5 deletions anomalib/data/base/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,42 @@

import logging
from abc import ABC
from typing import Optional
from typing import Any, Dict, List, Optional

from pandas import DataFrame
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, default_collate

from anomalib.data.base.dataset import AnomalibDataset
from anomalib.data.utils import ValSplitMode, random_split

logger = logging.getLogger(__name__)


def collate_fn(batch: List) -> Dict[str, Any]:
    """Custom collate function that collates bounding boxes as lists.

    Bounding boxes are collated as a list of tensors, while the default collate function is used for all other
    entries. The items in ``batch`` are not modified, so sample dictionaries that are kept or reused by the
    dataset still contain their ``boxes`` entry after collation.

    Args:
        batch (List): list of items in the batch where len(batch) is equal to the batch size.

    Returns:
        Dict[str, Any]: Dictionary containing the collated batch information. When the items are not
            dictionaries, the result of the default collate function is returned instead.
    """
    elem = batch[0]  # sample an element from the batch to check the type.
    if not isinstance(elem, dict):
        return default_collate(batch)

    out_dict: Dict[str, Any] = {}
    if "boxes" in elem:
        # collate boxes as list; read instead of pop so the caller's dictionaries stay intact
        out_dict["boxes"] = [item["boxes"] for item in batch]
    # collate other data normally; boxes are skipped because their first dimension may differ per item
    out_dict.update({key: default_collate([item[key] for item in batch]) for key in elem if key != "boxes"})
    return out_dict


class AnomalibDataModule(LightningDataModule, ABC):
"""Base Anomalib data module.
Expand Down Expand Up @@ -101,12 +124,26 @@ def is_setup(self):

def train_dataloader(self) -> TRAIN_DATALOADERS:
"""Get train dataloader."""
return DataLoader(self.train_data, shuffle=True, batch_size=self.train_batch_size, num_workers=self.num_workers)
return DataLoader(
dataset=self.train_data, shuffle=True, batch_size=self.train_batch_size, num_workers=self.num_workers
)

def val_dataloader(self) -> EVAL_DATALOADERS:
"""Get validation dataloader."""
return DataLoader(self.val_data, shuffle=False, batch_size=self.eval_batch_size, num_workers=self.num_workers)
return DataLoader(
dataset=self.val_data,
shuffle=False,
batch_size=self.eval_batch_size,
num_workers=self.num_workers,
collate_fn=collate_fn,
)

def test_dataloader(self) -> EVAL_DATALOADERS:
"""Get test dataloader."""
return DataLoader(self.test_data, shuffle=False, batch_size=self.eval_batch_size, num_workers=self.num_workers)
return DataLoader(
dataset=self.test_data,
shuffle=False,
batch_size=self.eval_batch_size,
num_workers=self.num_workers,
collate_fn=collate_fn,
)
22 changes: 14 additions & 8 deletions anomalib/data/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@
from torch import Tensor
from torch.utils.data import Dataset

from anomalib.data.utils import read_image
from anomalib.data.task_type import TaskType
from anomalib.data.utils import masks_to_boxes, read_image
from anomalib.pre_processing import PreProcessor

_EXPECTED_COLS_CLASSIFICATION = ["image_path", "split"]
_EXPECTED_COLS_SEGMENTATION = _EXPECTED_COLS_CLASSIFICATION + ["mask_path"]
_EXPECTED_COLS_PERTASK = {
"classification": _EXPECTED_COLS_CLASSIFICATION,
"segmentation": _EXPECTED_COLS_SEGMENTATION,
"detection": _EXPECTED_COLS_SEGMENTATION,
}

logger = logging.getLogger(__name__)
Expand All @@ -34,7 +36,7 @@
class AnomalibDataset(Dataset, ABC):
"""Anomalib dataset."""

def __init__(self, task: str, pre_process: PreProcessor):
def __init__(self, task: TaskType, pre_process: PreProcessor):
super().__init__()
self.task = task
self.pre_process = pre_process
Expand Down Expand Up @@ -107,16 +109,16 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
"""

image_path = self._samples.iloc[index].image_path
image = read_image(image_path)
mask_path = self._samples.iloc[index].mask_path
label_index = self._samples.iloc[index].label_index

image = read_image(image_path)
item = dict(image_path=image_path, label=label_index)

if self.task == "classification":
if self.task == TaskType.CLASSIFICATION:
pre_processed = self.pre_process(image=image)
elif self.task == "segmentation":
mask_path = self._samples.iloc[index].mask_path

item["image"] = pre_processed["image"]
elif self.task in [TaskType.DETECTION, TaskType.SEGMENTATION]:
# Only Anomalous (1) images have masks in anomaly datasets
# Therefore, create empty mask for Normal (0) images.
if label_index == 0:
Expand All @@ -126,11 +128,15 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:

pre_processed = self.pre_process(image=image, mask=mask)

item["image"] = pre_processed["image"]
item["mask_path"] = mask_path
item["mask"] = pre_processed["mask"]

if self.task == TaskType.DETECTION:
# create boxes from masks for detection task
item["boxes"] = masks_to_boxes(item["mask"])[0]
else:
raise ValueError(f"Unknown task type: {self.task}")
item["image"] = pre_processed["image"]

return item

Expand Down
11 changes: 9 additions & 2 deletions anomalib/data/base/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from torch import Tensor

from anomalib.data.base.dataset import AnomalibDataset
from anomalib.data.task_type import TaskType
from anomalib.data.utils import masks_to_boxes
from anomalib.data.utils.video import ClipsIndexer
from anomalib.pre_processing import PreProcessor

Expand All @@ -21,7 +23,9 @@ class VideoAnomalibDataset(AnomalibDataset, ABC):
frames_between_clips (int): Number of frames between each consecutive video clip.
"""

def __init__(self, task: str, pre_process: PreProcessor, clip_length_in_frames: int, frames_between_clips: int):
def __init__(
self, task: TaskType, pre_process: PreProcessor, clip_length_in_frames: int, frames_between_clips: int
):
super().__init__(task, pre_process)

self.clip_length_in_frames = clip_length_in_frames
Expand Down Expand Up @@ -74,9 +78,12 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
self.pre_process(image=frame.numpy(), mask=mask) for frame, mask in zip(item["image"], item["mask"])
]
item["image"] = torch.stack([item["image"] for item in processed_frames]).squeeze(0)
mask = item["mask"]
mask = torch.as_tensor(item["mask"])
item["mask"] = torch.stack([item["mask"] for item in processed_frames]).squeeze(0)
item["label"] = Tensor([1 in frame for frame in mask]).int().squeeze(0)
if self.task == TaskType.DETECTION:
item["boxes"] = masks_to_boxes(item["mask"])
item["boxes"] = item["boxes"][0] if len(item["boxes"]) == 1 else item["boxes"]
else:
item["image"] = torch.stack(
[self.pre_process(image=frame.numpy())["image"] for frame in item["image"]]
Expand Down
9 changes: 5 additions & 4 deletions anomalib/data/btech.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tqdm import tqdm

from anomalib.data.base import AnomalibDataModule, AnomalibDataset
from anomalib.data.task_type import TaskType
from anomalib.data.utils import DownloadProgressBar, Split, ValSplitMode, hash_check
from anomalib.pre_processing import PreProcessor

Expand Down Expand Up @@ -114,7 +115,7 @@ def __init__(
category: str,
pre_process: PreProcessor,
split: Optional[Union[Split, str]] = None,
task: str = "segmentation",
task: TaskType = TaskType.SEGMENTATION,
) -> None:
"""Btech Dataset class.
Expand All @@ -123,7 +124,7 @@ def __init__(
category: Name of the BTech category.
pre_process: List of pre_processing object containing albumentation compose.
split: 'train', 'val' or 'test'
task: ``classification`` or ``segmentation``
task: ``classification``, ``detection`` or ``segmentation``
create_validation_set: Create a validation subset in addition to the train and test subsets
Examples:
Expand Down Expand Up @@ -177,7 +178,7 @@ def __init__(
train_batch_size: int = 32,
eval_batch_size: int = 32,
num_workers: int = 8,
task: str = "segmentation",
task: TaskType = TaskType.SEGMENTATION,
transform_config_train: Optional[Union[str, A.Compose]] = None,
transform_config_eval: Optional[Union[str, A.Compose]] = None,
val_split_mode: ValSplitMode = ValSplitMode.SAME_AS_TEST,
Expand All @@ -193,7 +194,7 @@ def __init__(
train_batch_size: Training batch size.
eval_batch_size: Evaluation batch size.
num_workers: Number of workers.
task: ``classification`` or ``segmentation``
task: ``classification``, ``detection`` or ``segmentation``
transform_config_train: Config for pre-processing during training.
transform_config_val: Config for pre-processing during validation.
create_validation_set: Create a validation subset in addition to the train and test subsets
Expand Down
11 changes: 6 additions & 5 deletions anomalib/data/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torchvision.datasets.folder import IMG_EXTENSIONS

from anomalib.data.base import AnomalibDataModule, AnomalibDataset
from anomalib.data.task_type import TaskType
from anomalib.data.utils import Split, ValSplitMode, random_split
from anomalib.pre_processing.pre_process import PreProcessor

Expand Down Expand Up @@ -147,7 +148,7 @@ class FolderDataset(AnomalibDataset):
"""Folder dataset.
Args:
task (str): Task type. (classification or segmentation).
task (TaskType): Task type. (``classification``, ``detection`` or ``segmentation``).
pre_process (PreProcessor): Image Pre-processor to apply transform.
split (Optional[Union[Split, str]]): Fixed subset split that follows from folder structure on file system.
Choose from [Split.FULL, Split.TRAIN, Split.TEST]
Expand All @@ -171,7 +172,7 @@ class FolderDataset(AnomalibDataset):

def __init__(
self,
task: str,
task: TaskType,
pre_process: PreProcessor,
root: Union[str, Path],
normal_dir: Union[str, Path],
Expand Down Expand Up @@ -228,8 +229,8 @@ class Folder(AnomalibDataModule):
train_batch_size (int, optional): Training batch size. Defaults to 32.
eval_batch_size (int, optional): Evaluation batch size. Defaults to 32.
num_workers (int, optional): Number of workers. Defaults to 8.
task (str, optional): Task type. Could be either classification or segmentation.
Defaults to "classification".
task (TaskType, optional): Task type. Could be ``classification``, ``detection`` or ``segmentation``.
Defaults to segmentation.
transform_config_train (Optional[Union[str, A.Compose]], optional): Config for pre-processing
during training.
Defaults to None.
Expand All @@ -254,7 +255,7 @@ def __init__(
train_batch_size: int = 32,
eval_batch_size: int = 32,
num_workers: int = 8,
task: str = "segmentation",
task: TaskType = TaskType.SEGMENTATION,
transform_config_train: Optional[Union[str, A.Compose]] = None,
transform_config_eval: Optional[Union[str, A.Compose]] = None,
val_split_mode: ValSplitMode = ValSplitMode.FROM_TEST,
Expand Down
7 changes: 4 additions & 3 deletions anomalib/data/mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pandas import DataFrame

from anomalib.data.base import AnomalibDataModule, AnomalibDataset
from anomalib.data.task_type import TaskType
from anomalib.data.utils import DownloadProgressBar, Split, ValSplitMode, hash_check
from anomalib.pre_processing import PreProcessor

Expand Down Expand Up @@ -123,7 +124,7 @@ class MVTecDataset(AnomalibDataset):
"""MVTec dataset class.
Args:
task (str): Task type, either 'classification' or 'segmentation'
task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``
pre_process (PreProcessor): Pre-processor object
split (Optional[Union[Split, str]]): Split of the dataset, usually Split.TRAIN or Split.TEST
root (str): Path to the root of the dataset
Expand All @@ -132,7 +133,7 @@ class MVTecDataset(AnomalibDataset):

def __init__(
self,
task: str,
task: TaskType,
pre_process: PreProcessor,
root: str,
category: str,
Expand All @@ -158,7 +159,7 @@ def __init__(
train_batch_size: int = 32,
eval_batch_size: int = 32,
num_workers: int = 8,
task: str = "segmentation",
task: TaskType = TaskType.SEGMENTATION,
transform_config_train: Optional[Union[str, A.Compose]] = None,
transform_config_eval: Optional[Union[str, A.Compose]] = None,
val_split_mode: ValSplitMode = ValSplitMode.SAME_AS_TEST,
Expand Down
14 changes: 14 additions & 0 deletions anomalib/data/task_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Task type enum."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from enum import Enum


class TaskType(str, Enum):
    """Enumeration of the task types used when generating predictions on the dataset.

    The enum mixes in ``str``, so every member compares equal to its plain string value.
    """

    CLASSIFICATION = "classification"
    DETECTION = "detection"
    SEGMENTATION = "segmentation"
Loading

0 comments on commit 045d77f

Please sign in to comment.