diff --git a/src/anomalib/dataclasses/__init__.py b/src/anomalib/dataclasses/__init__.py index 2054dea22c..e6d3112a92 100644 --- a/src/anomalib/dataclasses/__init__.py +++ b/src/anomalib/dataclasses/__init__.py @@ -1,4 +1,35 @@ -"""Anomalib dataclasses.""" +"""Anomalib dataclasses. + +This module provides a collection of dataclasses used throughout the Anomalib library +for representing and managing various types of data related to anomaly detection tasks. + +The dataclasses are organized into two main categories: +1. Numpy-based dataclasses for handling numpy array data. +2. Torch-based dataclasses for handling PyTorch tensor data. + +Key components: + +Numpy Dataclasses: + ``NumpyImageItem``: Represents a single image item as numpy arrays. + ``NumpyImageBatch``: Represents a batch of image data as numpy arrays. + ``NumpyVideoItem``: Represents a single video item as numpy arrays. + ``NumpyVideoBatch``: Represents a batch of video data as numpy arrays. + +Torch Dataclasses: + ``Batch``: Base class for torch-based batch data. + ``DatasetItem``: Base class for torch-based dataset items. + ``DepthItem``: Represents a single depth data item. + ``DepthBatch``: Represents a batch of depth data. + ``ImageItem``: Represents a single image item as torch tensors. + ``ImageBatch``: Represents a batch of image data as torch tensors. + ``VideoItem``: Represents a single video item as torch tensors. + ``VideoBatch``: Represents a batch of video data as torch tensors. + ``InferenceBatch``: Specialized batch class for inference results. + +These dataclasses provide a structured way to handle various types of data +in anomaly detection tasks, ensuring type consistency and easy data manipulation +across different components of the Anomalib library. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 diff --git a/src/anomalib/dataclasses/generic.py b/src/anomalib/dataclasses/generic.py index cf5258dc01..7d6ea72777 100644 --- a/src/anomalib/dataclasses/generic.py +++ b/src/anomalib/dataclasses/generic.py @@ -1,4 +1,10 @@ -"""Generic dataclasses that can be implemented for different data types.""" +"""Generic dataclasses that can be implemented for different data types. + +This module provides a set of generic dataclasses and mixins that can be used +to define and validate various types of data fields used in Anomalib. +The dataclasses are designed to be flexible and extensible, allowing for easy +customization and validation of input and output data. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -24,13 +30,19 @@ Value = TypeVar("Value") -class FieldDescriptor( - Generic[Value], -): +class FieldDescriptor(Generic[Value]): """Descriptor for Anomalib's dataclass fields. - Using a descriptor ensures that the values of dataclass fields can be validated before being set. - This allows validation of the input data not only when it is first set, but also when it is updated. + Using a descriptor ensures that the values of dataclass fields can be + validated before being set. This allows validation of the input data not + only when it is first set, but also when it is updated. + + Attributes: + validator_name (str | None): The name of the validator method to be + called when setting the value. + Defaults to ``None``. + default (Value | None): The default value for the field. + Defaults to ``None``. """ def __init__(self, validator_name: str | None = None, default: Value | None = None) -> None: @@ -56,7 +68,7 @@ def __get__(self, instance: Instance | None, owner: type[Instance]) -> Value | N raise AttributeError(msg) return instance.__dict__[self.name] - def __set__(self, instance: Instance, value: Value) -> None: + def __set__(self, instance: object, value: Value) -> None: """Set the value of the descriptor. First calls the validator method if available, then sets the value of the attribute. @@ -82,7 +94,39 @@ def is_optional(self, owner: type[Instance]) -> bool: @dataclass class _InputFields(Generic[T, ImageT, MaskT, PathT], ABC): - """Generic dataclass that defines the standard input fields.""" + """Generic dataclass that defines the standard input fields for Anomalib. + + This abstract base class provides a structure for input data used in Anomalib, + a library for anomaly detection in images and videos. It defines common fields + used across various anomaly detection tasks and data types in Anomalib. + + Subclasses must implement the abstract validation methods to define the + specific validation logic for each field based on the requirements of different + Anomalib models and data processing pipelines. + + Examples: + Assuming a concrete implementation `DummyInput`: + + >>> class DummyInput(_InputFields[int, Image, Mask, str]): + ... # Implement actual validation + + >>> # Create an input instance + >>> input_item = DummyInput( + ... image=torch.rand(3, 224, 224), + ... gt_label=1, + ... gt_mask=torch.rand(224, 224) > 0.5, + ... mask_path="path/to/mask.png" + ... ) + + >>> # Access fields + >>> image = input_item.image + >>> label = input_item.gt_label + + Note: + This is an abstract base class and is not intended to be instantiated + directly. Concrete subclasses should implement all required validation + methods. + """ image: FieldDescriptor[ImageT] = FieldDescriptor(validator_name="_validate_image") gt_label: FieldDescriptor[T | None] = FieldDescriptor(validator_name="_validate_gt_label") @@ -111,11 +155,40 @@ def _validate_gt_label(self, gt_label: T) -> T | None: @dataclass -class _ImageInputFields( - Generic[PathT], - ABC, -): - """Generic dataclass that defines the image input fields.""" +class _ImageInputFields(Generic[PathT], ABC): + """Generic dataclass for image-specific input fields in Anomalib. + + This class extends standard input fields with an ``image_path`` attribute for + image-based anomaly detection tasks. It allows Anomalib to work efficiently + with disk-stored image datasets, facilitating custom data loading strategies. + + The ``image_path`` field uses a ``FieldDescriptor`` with a validation method. + Subclasses must implement ``_validate_image_path`` to ensure path validity + according to specific Anomalib model or dataset requirements. + + This class is designed to complement ``_InputFields`` for comprehensive + image-based anomaly detection input in Anomalib. + + Examples: + Assuming a concrete implementation ``DummyImageInput``: + >>> class DummyImageInput(_ImageInputFields): + ... def _validate_image_path(self, image_path): + ... return image_path # Implement actual validation + ... # Implement other required methods + + >>> # Create an image input instance + >>> image_input = DummyImageInput( + ... image_path="path/to/image.jpg" + ... ) + + >>> # Access image-specific field + >>> path = image_input.image_path + + Note: + This is an abstract base class and is not intended to be instantiated + directly. Concrete subclasses should implement all required validation + methods. + """ image_path: FieldDescriptor[PathT | None] = FieldDescriptor(validator_name="_validate_image_path") @@ -126,11 +199,49 @@ def _validate_image_path(self, image_path: PathT) -> PathT | None: @dataclass -class _VideoInputFields( - Generic[T, ImageT, MaskT, PathT], - ABC, -): - """Generic dataclass that defines the video input fields.""" +class _VideoInputFields(Generic[T, ImageT, MaskT, PathT], ABC): + """Generic dataclass that defines the video input fields for Anomalib. + + This class extends standard input fields with attributes specific to video-based + anomaly detection tasks. It includes fields for original images, video paths, + target frames, frame sequences, and last frames. + + Each field uses a ``FieldDescriptor`` with a corresponding validation method. + Subclasses must implement these abstract validation methods to ensure data + consistency with Anomalib's video processing requirements. + + This class is designed to work alongside other input field classes to provide + comprehensive support for video-based anomaly detection in Anomalib. + + Examples: + Assuming a concrete implementation ``DummyVideoInput``: + + >>> class DummyVideoInput(_VideoInputFields): + ... def _validate_original_image(self, original_image): + ... return original_image # Implement actual validation + ... # Implement other required methods + + >>> # Create a video input instance + >>> video_input = DummyVideoInput( + ... original_image=torch.rand(3, 224, 224), + ... video_path="path/to/video.mp4", + ... target_frame=10, + ... frames=torch.rand(3, 224, 224), + ... last_frame=torch.rand(3, 224, 224) + ... ) + + >>> # Access video-specific fields + >>> original_image = video_input.original_image + >>> path = video_input.video_path + >>> target_frame = video_input.target_frame + >>> frames = video_input.frames + >>> last_frame = video_input.last_frame + + Note: + This is an abstract base class and is not intended to be instantiated + directly. Concrete subclasses should implement all required validation + methods. + """ original_image: FieldDescriptor[ImageT | None] = FieldDescriptor(validator_name="_validate_original_image") video_path: FieldDescriptor[PathT | None] = FieldDescriptor(validator_name="_validate_video_path") @@ -165,12 +276,45 @@ def _validate_last_frame(self, last_frame: T) -> T | None: @dataclass -class _DepthInputFields( - Generic[T, PathT], - _ImageInputFields[PathT], - ABC, -): - """Generic dataclass that defines the depth input fields.""" +class _DepthInputFields(Generic[T, PathT], _ImageInputFields[PathT], ABC): + """Generic dataclass that defines the depth input fields for Anomalib. + + This class extends the standard input fields with a ``depth_map`` and + ``depth_path`` attribute for depth-based anomaly detection tasks. It allows + Anomalib to work efficiently with depth-based anomaly detection tasks, + facilitating custom data loading strategies. + + The ``depth_map`` and ``depth_path`` fields use a ``FieldDescriptor`` with + corresponding validation methods. Subclasses must implement these abstract + validation methods to ensure data consistency with Anomalib's depth processing + requirements. + + Examples: + Assuming a concrete implementation ``DummyDepthInput``: + + >>> class DummyDepthInput(_DepthInputFields): + ... def _validate_depth_map(self, depth_map): + ... return depth_map # Implement actual validation + ... def _validate_depth_path(self, depth_path): + ... return depth_path # Implement actual validation + ... # Implement other required methods + + >>> # Create a depth input instance + >>> depth_input = DummyDepthInput( + ... image_path="path/to/image.jpg", + ... depth_map=torch.rand(224, 224), + ... depth_path="path/to/depth.png" + ... ) + + >>> # Access depth-specific fields + >>> depth_map = depth_input.depth_map + >>> depth_path = depth_input.depth_path + + Note: + This is an abstract base class and is not intended to be instantiated + directly. Concrete subclasses should implement all required validation + methods. + """ depth_map: FieldDescriptor[T | None] = FieldDescriptor(validator_name="_validate_depth_map") depth_path: FieldDescriptor[PathT | None] = FieldDescriptor(validator_name="_validate_depth_path") @@ -188,7 +332,47 @@ def _validate_depth_path(self, depth_path: PathT) -> PathT | None: @dataclass class _OutputFields(Generic[T, MaskT], ABC): - """Generic dataclass that defines the standard output fields.""" + """Generic dataclass that defines the standard output fields for Anomalib. + + This class defines the standard output fields used in Anomalib, including + anomaly maps, predicted scores, predicted masks, and predicted labels. + + Each field uses a ``FieldDescriptor`` with a corresponding validation method. + Subclasses must implement these abstract validation methods to ensure data + consistency with Anomalib's anomaly detection tasks. + + Examples: + Assuming a concrete implementation ``DummyOutput``: + + >>> class DummyOutput(_OutputFields): + ... def _validate_anomaly_map(self, anomaly_map): + ... return anomaly_map # Implement actual validation + ... def _validate_pred_score(self, pred_score): + ... return pred_score # Implement actual validation + ... def _validate_pred_mask(self, pred_mask): + ... return pred_mask # Implement actual validation + ... def _validate_pred_label(self, pred_label): + ... return pred_label # Implement actual validation + + >>> # Create an output instance with predictions + >>> output = DummyOutput( + ... anomaly_map=torch.rand(224, 224), + ... pred_score=0.7, + ... pred_mask=torch.rand(224, 224) > 0.5, + ... pred_label=1 + ... ) + + >>> # Access individual fields + >>> anomaly_map = output.anomaly_map + >>> score = output.pred_score + >>> mask = output.pred_mask + >>> label = output.pred_label + + Note: + This is an abstract base class and is not intended to be instantiated + directly. Concrete subclasses should implement all required validation + methods. + """ anomaly_map: FieldDescriptor[MaskT | None] = FieldDescriptor(validator_name="_validate_anomaly_map") pred_score: FieldDescriptor[T | None] = FieldDescriptor(validator_name="_validate_pred_score") @@ -218,7 +402,34 @@ def _validate_pred_label(self, pred_label: T) -> T | None: @dataclass class UpdateMixin: - """Mixin class for dataclasses that allows for in-place replacement of attributes.""" + """Mixin class for dataclasses that allows for in-place replacement of attributes. + + This mixin class provides a method for updating dataclass instances in place or + by creating a new instance. It ensures that the updated instance is reinitialized + by calling the ``__post_init__`` method if it exists. + + Examples: + Assuming a dataclass `DummyItem` that uses UpdateMixin: + + >>> item = DummyItem(image=torch.rand(3, 224, 224), label=0) + + >>> # In-place update + >>> item.update(label=1, pred_score=0.9) + >>> print(item.label, item.pred_score) + 1 0.9 + + >>> # Create a new instance with updates + >>> new_item = item.update(in_place=False, image=torch.rand(3, 224, 224)) + >>> print(id(item) != id(new_item)) + True + + >>> # Update with multiple fields + >>> item.update(label=2, pred_score=0.8, anomaly_map=torch.rand(224, 224)) + + The `update` method can be used to modify single or multiple fields, either + in-place or by creating a new instance. This flexibility is particularly useful + in data processing pipelines and when working with model predictions in Anomalib. + """ def update(self, in_place: bool = True, **changes) -> Any: # noqa: ANN401 """Replace fields in place and call __post_init__ to reinitialize the instance. @@ -247,7 +458,46 @@ class _GenericItem( _OutputFields[T, MaskT], _InputFields[T, ImageT, MaskT, PathT], ): - """Generic dataclass for a dataset item.""" + """Generic dataclass for a single item in Anomalib datasets. + + This class combines input and output fields for anomaly detection tasks, + providing a comprehensive representation of a single data item. It inherits + from ``_InputFields`` for standard input data and ``_OutputFields`` for + prediction results. + + The class also includes the ``UpdateMixin``, allowing for easy updates of + field values. This is particularly useful during data processing pipelines + and when working with model predictions. + + By using generic types, this class can accommodate various data types used + in different Anomalib models and datasets, ensuring flexibility and + reusability across the library. + + Examples: + Assuming a concrete implementation ``DummyItem``: + + >>> class DummyItem(_GenericItem): + ... def _validate_image(self, image): + ... return image # Implement actual validation + ... # Implement other required methods + + >>> # Create a generic item instance + >>> item = DummyItem( + ... image=torch.rand(3, 224, 224), + ... gt_label=0, + ... pred_score=0.3, + ... anomaly_map=torch.rand(224, 224) + ... ) + + >>> # Access and update fields + >>> image = item.image + >>> item.update(pred_score=0.8, pred_label=1) + + Note: + This is an abstract base class and is not intended to be instantiated + directly. Concrete subclasses should implement all required validation + methods. + """ @dataclass @@ -257,7 +507,47 @@ class _GenericBatch( _OutputFields[T, MaskT], _InputFields[T, ImageT, MaskT, PathT], ): - """Generic dataclass for a batch.""" + """Generic dataclass for a batch of items in Anomalib datasets. + + This class represents a batch of data items, combining both input and output + fields for anomaly detection tasks. It inherits from ``_InputFields`` for + input data and ``_OutputFields`` for prediction results, allowing it to + handle both training data and model outputs. + + The class includes the ``UpdateMixin``, enabling easy updates of field values + across the entire batch. This is particularly useful for in-place modifications + during data processing or when updating predictions. + + Examples: + Assuming a concrete implementation ``DummyBatch``: + + >>> class DummyBatch(_GenericBatch): + ... def _validate_image(self, image): + ... return image # Implement actual validation + ... # Implement other required methods + + >>> # Create a batch with input data + >>> batch = DummyBatch( + ... image=torch.rand(32, 3, 224, 224), + ... gt_label=torch.randint(0, 2, (32,)) + ... ) + + >>> # Update the entire batch with new predictions + >>> batch.update( + ... pred_score=torch.rand(32), + ... anomaly_map=torch.rand(32, 224, 224) + ... ) + + >>> # Access individual fields + >>> images = batch.image + >>> labels = batch.gt_label + >>> predictions = batch.pred_score + + Note: + This is an abstract base class and is not intended to be instantiated + directly. Concrete subclasses should implement all required validation + methods. + """ ItemT = TypeVar("ItemT", bound="_GenericItem") @@ -265,7 +555,41 @@ class _GenericBatch( @dataclass class BatchIterateMixin(Generic[ItemT]): - """Generic dataclass for a batch.""" + """Mixin class for iterating over batches of items in Anomalib datasets. + + This class provides functionality to iterate over individual items within a + batch, convert batches to lists of items, and determine batch sizes. It's + designed to work with Anomalib's batch processing pipelines. + + The mixin requires subclasses to define an ``item_class`` attribute, which + specifies the class used for individual items in the batch. This ensures + type consistency when iterating or converting batches. + + Key features include: + - Iteration over batch items + - Conversion of batches to lists of individual items + - Batch size determination + - A class method for collating individual items into a batch + + Examples: + Assuming a subclass `DummyBatch` with `DummyItem` as its item_class: + + >>> batch = DummyBatch(images=[...], labels=[...]) + >>> for item in batch: + ... process_item(item) # Iterate over items + + >>> item_list = batch.items # Convert batch to list of items + >>> type(item_list[0]) + + + >>> batch_size = len(batch) # Get batch size + + >>> items = [DummyItem(...) for _ in range(5)] + >>> new_batch = DummyBatch.collate(items) # Collate items into a batch + + This mixin enhances batch handling capabilities in Anomalib, facilitating + efficient data processing and model interactions. + """ item_class: ClassVar[Callable] diff --git a/src/anomalib/dataclasses/numpy.py b/src/anomalib/dataclasses/numpy.py index eb265d92d9..6e1ea1a21a 100644 --- a/src/anomalib/dataclasses/numpy.py +++ b/src/anomalib/dataclasses/numpy.py @@ -1,4 +1,18 @@ -"""Dataclasses for numpy data.""" +"""Numpy-based dataclasses for Anomalib. + +This module provides numpy-based implementations of the generic dataclasses +used in Anomalib. These classes are designed to work with numpy arrays for +efficient data handling and processing in anomaly detection tasks. + +The module includes the following main classes: + +- NumpyItem: Represents a single item in Anomalib datasets using numpy arrays. +- NumpyBatch: Represents a batch of items in Anomalib datasets using numpy arrays. +- NumpyImageItem: Represents a single image item with additional image-specific fields. +- NumpyImageBatch: Represents a batch of image items with batch operations. +- NumpyVideoItem: Represents a single video item with video-specific fields. +- NumpyVideoBatch: Represents a batch of video items with video-specific operations. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -12,21 +26,49 @@ @dataclass class NumpyItem(_GenericItem[np.ndarray, np.ndarray, np.ndarray, str]): - """Dataclass for numpy item.""" + """Dataclass for a single item in Anomalib datasets using numpy arrays. + + This class extends _GenericItem for numpy-based data representation. It includes + both input data (e.g., images, labels) and output data (e.g., predictions, + anomaly maps) as numpy arrays. It is suitable for numpy-based processing + pipelines in Anomalib. + """ @dataclass class NumpyBatch(_GenericBatch[np.ndarray, np.ndarray, np.ndarray, list[str]]): - """Dataclass for numpy batch.""" + """Dataclass for a batch of items in Anomalib datasets using numpy arrays. + + This class extends _GenericBatch for batches of numpy-based data. It represents + multiple data points for batch processing in anomaly detection tasks. It includes + an additional dimension for batch size in all tensor-like fields. + """ -# torch image outputs @dataclass -class NumpyImageItem( - _ImageInputFields[str], - NumpyItem, -): - """Dataclass for numpy image output item.""" +class NumpyImageItem(_ImageInputFields[str], NumpyItem): + """Dataclass for a single image item in Anomalib datasets using numpy arrays. + + This class combines _ImageInputFields and NumpyItem for image-based anomaly detection. + It includes image-specific fields and validation methods to ensure proper formatting + for Anomalib's image-based models. + + Examples: + >>> item = NumpyImageItem( + ... image=np.random.rand(224, 224, 3), + ... gt_label=np.array(1), + ... gt_mask=np.random.rand(224, 224) > 0.5, + ... anomaly_map=np.random.rand(224, 224), + ... pred_score=np.array(0.7), + ... pred_label=np.array(1), + ... image_path="path/to/image.jpg" + ... ) + + >>> # Access fields + >>> image = item.image + >>> label = item.gt_label + >>> path = item.image_path + """ def _validate_image(self, image: np.ndarray) -> np.ndarray: assert image.ndim == 3, f"Expected 3D image, got {image.ndim}D image." @@ -77,12 +119,33 @@ def _validate_image_path(self, image_path: str) -> str: @dataclass -class NumpyImageBatch( - BatchIterateMixin[NumpyImageItem], - _ImageInputFields[list[str]], - NumpyBatch, -): - """Dataclass for numpy image output batch.""" +class NumpyImageBatch(BatchIterateMixin[NumpyImageItem], _ImageInputFields[list[str]], NumpyBatch): + """Dataclass for a batch of image items in Anomalib datasets using numpy arrays. + + This class combines BatchIterateMixin, _ImageInputFields, and NumpyBatch for batches + of image data. It supports batch operations and iteration over individual NumpyImageItems. + It ensures proper formatting for Anomalib's image-based models. + + Examples: + >>> batch = NumpyImageBatch( + ... image=np.random.rand(32, 224, 224, 3), + ... gt_label=np.random.randint(0, 2, (32,)), + ... gt_mask=np.random.rand(32, 224, 224) > 0.5, + ... anomaly_map=np.random.rand(32, 224, 224), + ... pred_score=np.random.rand(32), + ... pred_label=np.random.randint(0, 2, (32,)), + ... image_path=["path/to/image_{}.jpg".format(i) for i in range(32)] + ... ) + + >>> # Access batch fields + >>> images = batch.image + >>> labels = batch.gt_label + >>> paths = batch.image_path + + >>> # Iterate over items in the batch + >>> for item in batch: + ... process_item(item) + """ item_class = NumpyImageItem @@ -114,13 +177,14 @@ def _validate_image_path(self, image_path: list[str]) -> list[str]: return image_path -# torch video outputs @dataclass -class NumpyVideoItem( - _VideoInputFields[np.ndarray, np.ndarray, np.ndarray, str], - NumpyItem, -): - """Dataclass for numpy video output item.""" +class NumpyVideoItem(_VideoInputFields[np.ndarray, np.ndarray, np.ndarray, str], NumpyItem): + """Dataclass for a single video item in Anomalib datasets using numpy arrays. + + This class combines _VideoInputFields and NumpyItem for video-based anomaly detection. + It includes video-specific fields and validation methods to ensure proper formatting + for Anomalib's video-based models. + """ def _validate_image(self, image: np.ndarray) -> np.ndarray: return image @@ -141,7 +205,12 @@ class NumpyVideoBatch( _VideoInputFields[np.ndarray, np.ndarray, np.ndarray, list[str]], NumpyBatch, ): - """Dataclass for numpy video output batch.""" + """Dataclass for a batch of video items in Anomalib datasets using numpy arrays. + + This class combines BatchIterateMixin, _VideoInputFields, and NumpyBatch for batches + of video data. It supports batch operations and iteration over individual NumpyVideoItems. + It ensures proper formatting for Anomalib's video-based models. + """ item_class = NumpyVideoItem @@ -159,12 +228,3 @@ def _validate_mask_path(self, mask_path: list[str]) -> list[str]: def _validate_anomaly_map(self, anomaly_map: np.ndarray) -> np.ndarray: return anomaly_map - - def _validate_pred_score(self, pred_score: np.ndarray) -> np.ndarray: - return pred_score - - def _validate_pred_mask(self, pred_mask: np.ndarray) -> np.ndarray: - return pred_mask - - def _validate_pred_label(self, pred_label: np.ndarray) -> np.ndarray: - return pred_label diff --git a/src/anomalib/dataclasses/torch.py b/src/anomalib/dataclasses/torch.py index dab9cac066..7bc4e93c0c 100644 --- a/src/anomalib/dataclasses/torch.py +++ b/src/anomalib/dataclasses/torch.py @@ -1,4 +1,18 @@ -"""Dataclasses for torch inputs and outputs.""" +"""Torch-based dataclasses for Anomalib. + +This module provides PyTorch-based implementations of the generic dataclasses +used in Anomalib. These classes are designed to work with PyTorch tensors for +efficient data handling and processing in anomaly detection tasks. + +These classes extend the generic dataclasses defined in the Anomalib framework, +providing concrete implementations that use PyTorch tensors for tensor-like data. +They include methods for data validation and support operations specific to +image, video, and depth data in the context of anomaly detection. + +Note: + When using these classes, ensure that the input data is in the correct + format (PyTorch tensors with appropriate shapes) to avoid validation errors. +""" # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 @@ -36,10 +50,26 @@ class InferenceBatch(NamedTuple): @dataclass -class ToNumpyMixin( - Generic[NumpyT], -): - """Mixin for converting torch-based dataclasses to numpy.""" +class ToNumpyMixin(Generic[NumpyT]): + """Mixin for converting torch-based dataclasses to numpy. + + This mixin provides functionality to convert PyTorch tensor data to numpy arrays. + It requires the subclass to define a 'numpy_class' attribute specifying the + corresponding numpy-based class. + + Examples: + >>> from anomalib.dataclasses.numpy import NumpyImageItem + >>> @dataclass + ... class TorchImageItem(ToNumpyMixin[NumpyImageItem]): + ... numpy_class = NumpyImageItem + ... image: torch.Tensor + ... gt_label: torch.Tensor + + >>> torch_item = TorchImageItem(image=torch.rand(3, 224, 224), gt_label=torch.tensor(1)) + >>> numpy_item = torch_item.to_numpy() + >>> isinstance(numpy_item, NumpyImageItem) + True + """ numpy_class: ClassVar[Callable] @@ -63,22 +93,83 @@ def to_numpy(self) -> NumpyT: @dataclass class DatasetItem(Generic[ImageT], _GenericItem[torch.Tensor, ImageT, Mask, str]): - """Dataclass for torch item.""" + """Base dataclass for individual items in Anomalib datasets using PyTorch tensors. + + This class extends the generic _GenericItem class to provide a PyTorch-specific + implementation for single data items in Anomalib datasets. It is designed to + handle various types of data (e.g., images, labels, masks) represented as + PyTorch tensors. + + The class uses generic types to allow flexibility in the image representation, + which can vary depending on the specific use case (e.g., standard images, video clips). + + Attributes: + Inherited from _GenericItem, with PyTorch tensor and Mask types. + + Note: + This class is typically subclassed to create more specific item types + (e.g., ImageItem, VideoItem) with additional fields and methods. + """ @dataclass class Batch(Generic[ImageT], _GenericBatch[torch.Tensor, ImageT, Mask, list[str]]): - """Dataclass for torch batch.""" + """Base dataclass for batches of items in Anomalib datasets using PyTorch tensors. + + This class extends the generic _GenericBatch class to provide a PyTorch-specific + implementation for batches of data in Anomalib datasets. It is designed to + handle collections of data items (e.g., multiple images, labels, masks) + represented as PyTorch tensors. + + The class uses generic types to allow flexibility in the image representation, + which can vary depending on the specific use case (e.g., standard images, video clips). + + Attributes: + Inherited from _GenericBatch, with PyTorch tensor and Mask types. + + Note: + This class is typically subclassed to create more specific batch types + (e.g., ImageBatch, VideoBatch) with additional fields and methods. + """ -# torch image outputs @dataclass class ImageItem( ToNumpyMixin[NumpyImageItem], _ImageInputFields[str], DatasetItem[Image], ): - """Dataclass for torch image output item.""" + """Dataclass for individual image items in Anomalib datasets using PyTorch tensors. + + This class combines the functionality of ToNumpyMixin, _ImageInputFields, and + DatasetItem to represent single image data points in Anomalib. It includes + image-specific fields and provides methods for data validation and conversion + to numpy format. + + The class is designed to work with PyTorch tensors and includes fields for + the image data, ground truth labels and masks, anomaly maps, and related metadata. + + Attributes: + Inherited from _ImageInputFields and DatasetItem. + + Methods: + Inherited from ToNumpyMixin, including to_numpy() for conversion to numpy format. + + Examples: + >>> item = ImageItem( + ... image=torch.rand(3, 224, 224), + ... gt_label=torch.tensor(1), + ... gt_mask=torch.rand(224, 224) > 0.5, + ... image_path="path/to/image.jpg" + ... ) + + >>> print(item.image.shape) + torch.Size([3, 224, 224]) + + >>> numpy_item = item.to_numpy() + >>> print(type(numpy_item)) + + """ numpy_class = NumpyImageItem @@ -186,7 +277,35 @@ class ImageBatch( _ImageInputFields[list[str]], Batch[Image], ): - """Dataclass for torch image output batch.""" + """Dataclass for batches of image items in Anomalib datasets using PyTorch tensors. + + This class combines the functionality of ``ToNumpyMixin``, ``BatchIterateMixin``, + ``_ImageInputFields``, and ``Batch`` to represent collections of image data points in Anomalib. + It includes image-specific fields and provides methods for batch operations, + iteration over individual items, and conversion to numpy format. + + The class is designed to work with PyTorch tensors and includes fields for + batches of image data, ground truth labels and masks, anomaly maps, and related metadata. + + Examples: + >>> batch = ImageBatch( + ... image=torch.rand(32, 3, 224, 224), + ... gt_label=torch.randint(0, 2, (32,)), + ... gt_mask=torch.rand(32, 224, 224) > 0.5, + ... image_path=["path/to/image_{}.jpg".format(i) for i in range(32)] + ... ) + + >>> print(batch.image.shape) + torch.Size([32, 3, 224, 224]) + + >>> for item in batch: + ... print(item.image.shape) + torch.Size([3, 224, 224]) + + >>> numpy_batch = batch.to_numpy() + >>> print(type(numpy_batch)) + + """ item_class = ImageItem numpy_class = NumpyImageBatch @@ -296,7 +415,27 @@ class VideoItem( _VideoInputFields[torch.Tensor, Video, Mask, str], DatasetItem[Video], ): - """Dataclass for torch video output item.""" + """Dataclass for individual video items in Anomalib datasets using PyTorch tensors. + + This class represents a single video item in Anomalib datasets using PyTorch tensors. + It combines the functionality of ToNumpyMixin, _VideoInputFields, and DatasetItem + to handle video data, including frames, labels, masks, and metadata. + + Examples: + >>> item = VideoItem( + ... image=torch.rand(10, 3, 224, 224), # 10 frames + ... gt_label=torch.tensor(1), + ... gt_mask=torch.rand(10, 224, 224) > 0.5, + ... video_path="path/to/video.mp4" + ... ) + + >>> print(item.image.shape) + torch.Size([10, 3, 224, 224]) + + >>> numpy_item = item.to_numpy() + >>> print(type(numpy_item)) + + """ numpy_class = NumpyVideoItem @@ -352,7 +491,31 @@ class VideoBatch( _VideoInputFields[torch.Tensor, Video, Mask, list[str]], Batch[Video], ): - """Dataclass for torch video output batch.""" + """Dataclass for batches of video items in Anomalib datasets using PyTorch tensors. + + This class represents a batch of video items in Anomalib datasets using PyTorch tensors. + It combines the functionality of ToNumpyMixin, BatchIterateMixin, _VideoInputFields, + and Batch to handle batches of video data, including frames, labels, masks, and metadata. + + Examples: + >>> batch = VideoBatch( + ... image=torch.rand(32, 10, 3, 224, 224), # 32 videos, 10 frames each + ... gt_label=torch.randint(0, 2, (32,)), + ... gt_mask=torch.rand(32, 10, 224, 224) > 0.5, + ... video_path=["path/to/video_{}.mp4".format(i) for i in range(32)] + ... ) + + >>> print(batch.image.shape) + torch.Size([32, 10, 3, 224, 224]) + + >>> for item in batch: + ... print(item.image.shape) + torch.Size([10, 3, 224, 224]) + + >>> numpy_batch = batch.to_numpy() + >>> print(type(numpy_batch)) + + """ item_class = VideoItem numpy_class = NumpyVideoBatch @@ -404,7 +567,24 @@ class DepthItem( _DepthInputFields[torch.Tensor, str], DatasetItem[Image], ): - """Dataclass for torch depth output item.""" + """Dataclass for individual depth items in Anomalib datasets using PyTorch tensors. + + This class represents a single depth item in Anomalib datasets using PyTorch tensors. + It combines the functionality of ToNumpyMixin, _DepthInputFields, and DatasetItem + to handle depth data, including depth maps, labels, and metadata. + + Examples: + >>> item = DepthItem( + ... image=torch.rand(3, 224, 224), + ... gt_label=torch.tensor(1), + ... depth_map=torch.rand(224, 224), + ... image_path="path/to/image.jpg", + ... depth_path="path/to/depth.png" + ... ) + + >>> print(item.image.shape, item.depth_map.shape) + torch.Size([3, 224, 224]) torch.Size([224, 224]) + """ numpy_class = NumpyImageItem @@ -448,7 +628,28 @@ class DepthBatch( _DepthInputFields[torch.Tensor, list[str]], Batch[Image], ): - """Dataclass for torch depth output batch.""" + """Dataclass for batches of depth items in Anomalib datasets using PyTorch tensors. + + This class represents a batch of depth items in Anomalib datasets using PyTorch tensors. + It combines the functionality of BatchIterateMixin, _DepthInputFields, and Batch + to handle batches of depth data, including depth maps, labels, and metadata. + + Examples: + >>> batch = DepthBatch( + ... image=torch.rand(32, 3, 224, 224), + ... gt_label=torch.randint(0, 2, (32,)), + ... depth_map=torch.rand(32, 224, 224), + ... image_path=["path/to/image_{}.jpg".format(i) for i in range(32)], + ... depth_path=["path/to/depth_{}.png".format(i) for i in range(32)] + ... ) + + >>> print(batch.image.shape, batch.depth_map.shape) + torch.Size([32, 3, 224, 224]) torch.Size([32, 224, 224]) + + >>> for item in batch: + ... print(item.image.shape, item.depth_map.shape) + torch.Size([3, 224, 224]) torch.Size([224, 224]) + """ item_class = DepthItem