diff --git a/README.md b/README.md
index a733a8fb0f..d1039af536 100644
--- a/README.md
+++ b/README.md
@@ -192,7 +192,6 @@ Example OpenVINO Inference:
 
 ```bash
 python tools/inference/openvino_inference.py \
-    --config src/anomalib/models/padim/config.yaml \
-    --weights results/padim/mvtec/bottle/run/openvino/model.bin \
-    --metadata results/padim/mvtec/bottle/run/openvino/metadata.json \
+    --weights results/padim/mvtec/bottle/run/weights/openvino/model.bin \
+    --metadata results/padim/mvtec/bottle/run/weights/openvino/metadata.json \
     --input datasets/MVTec/bottle/test/broken_large/000.png \
@@ -207,7 +206,6 @@ A quick example:
 
 ```bash
 python tools/inference/gradio_inference.py \
-    --config src/anomalib/models/padim/config.yaml \
-    --weights results/padim/mvtec/bottle/run/weights/model.ckpt
+    --weights results/padim/mvtec/bottle/run/weights/torch/model.pt
 ```
diff --git a/notebooks/000_getting_started/001_getting_started.ipynb b/notebooks/000_getting_started/001_getting_started.ipynb
index a09f82f585..c278e2b863 100644
--- a/notebooks/000_getting_started/001_getting_started.ipynb
+++ b/notebooks/000_getting_started/001_getting_started.ipynb
@@ -433,8 +433,8 @@
    }
   ],
   "source": [
-    "openvino_model_path = output_path / \"openvino\" / \"model.bin\"\n",
-    "metadata_path = output_path / \"openvino\" / \"metadata.json\"\n",
+    "openvino_model_path = output_path / \"weights\" / \"openvino\" / \"model.bin\"\n",
+    "metadata_path = output_path / \"weights\" / \"openvino\" / \"metadata.json\"\n",
    "print(openvino_model_path.exists(), metadata_path.exists())"
   ]
  },
diff --git a/notebooks/400_openvino/401_nncf.ipynb b/notebooks/400_openvino/401_nncf.ipynb
index 2b1e381c44..29fc26bb08 100644
--- a/notebooks/400_openvino/401_nncf.ipynb
+++ b/notebooks/400_openvino/401_nncf.ipynb
@@ -152,7 +152,7 @@
     "\n",
     "```yaml\n",
     "optimization:\n",
-    "  export_mode: null #options: onnx, openvino\n",
+    "  export_mode: null # options: torch, onnx, openvino\n",
     "  nncf:\n",
     "    apply: true\n",
     "    input_info:\n",
@@ -282,7 +282,7 @@
     "\n",
     "```yaml\n",
     "optimization:\n",
-    "  export_mode: null #options: onnx, openvino\n",
+    "  export_mode: null # options: torch, onnx, openvino\n",
     "  nncf:\n",
     "    apply: true\n",
     "    input_info:\n",
diff --git a/src/anomalib/deploy/export.py b/src/anomalib/deploy/export.py
index cd30680fbd..f0cfcff431 100644
--- a/src/anomalib/deploy/export.py
+++ b/src/anomalib/deploy/export.py
@@ -25,6 +25,7 @@ class ExportMode(str, Enum):
 
     ONNX = "onnx"
     OPENVINO = "openvino"
+    TORCH = "torch"
 
 
 def get_model_metadata(model: AnomalyModule) -> dict[str, Tensor]:
@@ -59,6 +60,7 @@ def get_metadata(task: TaskType, transform: dict[str, Any], model: AnomalyModule
         task (TaskType): Task type.
         transform (dict[str, Any]): Transform used for the model.
         model (AnomalyModule): Model to export.
+        export_mode (ExportMode): Mode to export the model. Torch, ONNX or OpenVINO.
 
     Returns:
         dict[str, Any]: Metadata for the exported model.
@@ -90,20 +92,44 @@ def export(
-        transform (dict[str, Any]): Data transforms (augmentatiions) used for the model.
+        transform (dict[str, Any]): Data transforms (augmentations) used for the model.
         input_size (tuple[int, int]): Input size of the model.
         model (AnomalyModule): Anomaly model to export.
-        export_mode (ExportMode): Mode to export the model. ONNX or OpenVINO.
-        export_root (str | Path): Path to exported ONNX/OpenVINO IR.
+        export_mode (ExportMode): Mode to export the model. Torch, ONNX or OpenVINO.
+        export_root (str | Path): Path to exported Torch, ONNX or OpenVINO IR.
     """
-    # Write metadata to json file. The file is written in the same directory as the target model.
-    export_path = Path(export_root) / export_mode.value
+    # Create export directory.
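+    # All export formats are written below <export_root>/weights/, i.e.
+    # weights/torch, weights/onnx or weights/openvino.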
+    export_path = Path(export_root) / "weights" / export_mode.value
     export_path.mkdir(parents=True, exist_ok=True)
-    with (Path(export_path) / "metadata.json").open("w", encoding="utf-8") as metadata_file:
-        metadata = get_metadata(task, transform, model)
-        json.dump(metadata, metadata_file, ensure_ascii=False, indent=4)
-
-    # Export model to onnx and convert to OpenVINO IR if export mode is set to OpenVINO.
-    onnx_path = export_to_onnx(model, input_size, export_path)
-    if export_mode == ExportMode.OPENVINO:
-        export_to_openvino(export_path, onnx_path)
+
+    # Get metadata.
+    metadata = get_metadata(task, transform, model)
+
+    if export_mode == ExportMode.TORCH:
+        export_to_torch(model, metadata, export_path)
+
+    elif export_mode in (ExportMode.ONNX, ExportMode.OPENVINO):
+        # Write metadata to json file. The file is written in the same directory as the target model.
+        with (Path(export_path) / "metadata.json").open("w", encoding="utf-8") as metadata_file:
+            json.dump(metadata, metadata_file, ensure_ascii=False, indent=4)
+
+        # Export model to onnx and convert to OpenVINO IR if export mode is set to OpenVINO.
+        onnx_path = export_to_onnx(model, input_size, export_path)
+        if export_mode == ExportMode.OPENVINO:
+            export_to_openvino(export_path, onnx_path)
+
+    else:
+        raise ValueError(f"Unknown export mode {export_mode}")
+
+
+def export_to_torch(model: AnomalyModule, metadata: dict[str, Any], export_path: Path) -> None:
+    """Export AnomalyModule to torch.
+
+    Args:
+        model (AnomalyModule): Model to export.
+        metadata (dict[str, Any]): Metadata for the exported model.
+        export_path (Path): Path to the folder storing the exported model.
+    """
+    torch.save(
+        obj={"model": model.model, "metadata": metadata},
+        f=export_path / "model.pt",
+    )
 
 
 def export_to_onnx(model: AnomalyModule, input_size: tuple[int, int], export_path: Path) -> Path:
diff --git a/src/anomalib/deploy/inferencers/torch_inferencer.py b/src/anomalib/deploy/inferencers/torch_inferencer.py
index e6da3df5a0..63d4555adc 100644
--- a/src/anomalib/deploy/inferencers/torch_inferencer.py
+++ b/src/anomalib/deploy/inferencers/torch_inferencer.py
@@ -8,19 +8,15 @@
 from pathlib import Path
 from typing import Any
 
+import albumentations as A
 import cv2
 import numpy as np
 import torch
-from omegaconf import DictConfig, ListConfig
-from torch import Tensor
+from omegaconf import DictConfig
+from torch import Tensor, nn
 
-from anomalib.config import get_configurable_parameters
 from anomalib.data import TaskType
-from anomalib.data.utils import InputNormalizationMethod, get_transforms
 from anomalib.data.utils.boxes import masks_to_boxes
-from anomalib.deploy.export import get_model_metadata
-from anomalib.models import get_model
-from anomalib.models.components import AnomalyModule
 
 from .base_inferencer import Inferencer
 
@@ -29,38 +25,21 @@ class TorchInferencer(Inferencer):
     """PyTorch implementation for the inference.
 
     Args:
-        config (str | Path | DictConfig | ListConfig): Configurable parameters that are used
-            during the training stage.
-        model_source (str | Path | AnomalyModule): Path to the model ckpt file or the Anomaly model.
-        metadata_path (str | Path, optional): Path to metadata file. If none, it tries to load the params
-            from the model state_dict. Defaults to None.
-        device (str | None, optional): Device to use for inference. Options are auto, cpu, cuda. Defaults to "auto".
+        path (str | Path): Path to Torch model weights.
+        device (str): Device to use for inference. Options are auto, cpu, cuda. Defaults to "auto".
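+
+    Example:
+        >>> # Illustrative usage; the path assumes a model exported with
+        >>> # ExportMode.TORCH under the default results directory.
+        >>> inferencer = TorchInferencer(path="results/padim/mvtec/bottle/run/weights/torch/model.pt")
+        >>> predictions = inferencer.predict(image="datasets/MVTec/bottle/test/broken_large/000.png")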
""" def __init__( self, - config: str | Path | DictConfig | ListConfig, - model_source: str | Path | AnomalyModule, - metadata_path: str | Path | None = None, + path: str | Path, device: str = "auto", ) -> None: self.device = self._get_device(device) - # Check and load the configuration - if isinstance(config, (str, Path)): - self.config = get_configurable_parameters(config_path=config) - elif isinstance(config, (DictConfig, ListConfig)): - self.config = config - else: - raise ValueError(f"Unknown config type {type(config)}") - - # Check and load the model weights. - if isinstance(model_source, AnomalyModule): - self.model = model_source - else: - self.model = self.load_model(model_source) - - self.metadata = self._load_metadata(metadata_path) + # Load the model weights. + self.model = self.load_model(path) + self.metadata = self._load_metadata(path) + self.transform = A.from_dict(self.metadata["transform"]) @staticmethod def _get_device(device: str) -> torch.device: @@ -82,24 +61,18 @@ def _get_device(device: str) -> torch.device: return torch.device(device) def _load_metadata(self, path: str | Path | None = None) -> dict | DictConfig: - """Load metadata from file or from model state dict. + """Load metadata from file. Args: - path (str | Path | None, optional): Path to metadata file. If none, it tries to load the params - from the model state_dict. Defaults to None. + path (str | Path): Path to the model pt file. Returns: dict: Dictionary containing the metadata. """ - metadata: dict[str, float | np.ndarray | Tensor] | DictConfig - if path is None: - # Torch inferencer still reads metadata from the model. - metadata = get_model_metadata(self.model) - else: - metadata = super()._load_metadata(path) + metadata = torch.load(path, map_location=self.device)["metadata"] if path else {} return metadata - def load_model(self, path: str | Path) -> AnomalyModule: + def load_model(self, path: str | Path) -> nn.Module: """Load the PyTorch model. Args: @@ -108,8 +81,8 @@ def load_model(self, path: str | Path) -> AnomalyModule: Returns: (AnomalyModule): PyTorch Lightning model. """ - model = get_model(self.config) - model.load_state_dict(torch.load(path, map_location=self.device)["state_dict"]) + + model = torch.load(path, map_location=self.device)["model"] model.eval() return model.to(self.device) @@ -122,19 +95,7 @@ def pre_process(self, image: np.ndarray) -> Tensor: Returns: Tensor: pre-processed image. 
""" - transform_config = ( - self.config.dataset.transform_config.eval if "transform_config" in self.config.dataset.keys() else None - ) - - image_size = (self.config.dataset.image_size[0], self.config.dataset.image_size[1]) - center_crop = self.config.dataset.get("center_crop") - if center_crop is not None: - center_crop = tuple(center_crop) - normalization = InputNormalizationMethod(self.config.dataset.normalization) - transform = get_transforms( - config=transform_config, image_size=image_size, center_crop=center_crop, normalization=normalization - ) - processed_image = transform(image=image)["image"] + processed_image = self.transform(image=image)["image"] if len(processed_image) == 3: processed_image = processed_image.unsqueeze(0) @@ -209,7 +170,7 @@ def post_process(self, predictions: Tensor, metadata: dict | DictConfig | None = if pred_mask is not None: pred_mask = cv2.resize(pred_mask, (image_width, image_height)) - if self.config.dataset.task == TaskType.DETECTION: + if self.metadata["task"] == TaskType.DETECTION: pred_boxes = masks_to_boxes(torch.from_numpy(pred_mask))[0][0].numpy() box_labels = np.ones(pred_boxes.shape[0]) else: diff --git a/src/anomalib/models/cfa/config.yaml b/src/anomalib/models/cfa/config.yaml index 7a4fef0c93..fbf67fcb51 100644 --- a/src/anomalib/models/cfa/config.yaml +++ b/src/anomalib/models/cfa/config.yaml @@ -62,7 +62,7 @@ logging: log_graph: false # Logs the model graph to respective logger. optimization: - export_mode: null #options: onnx, openvino + export_mode: null # options: torch, onnx, openvino # PL Trainer Args. Don't add extra parameter here. trainer: diff --git a/src/anomalib/models/cflow/config.yaml b/src/anomalib/models/cflow/config.yaml index 4e7818a1a4..e4e630777d 100644 --- a/src/anomalib/models/cflow/config.yaml +++ b/src/anomalib/models/cflow/config.yaml @@ -68,7 +68,7 @@ logging: log_graph: false # Logs the model graph to respective logger. optimization: - export_mode: null #options: onnx, openvino + export_mode: null # options: torch, onnx, openvino # PL Trainer Args. Don't add extra parameter here. 
 trainer:
diff --git a/src/anomalib/models/cflow/utils.py b/src/anomalib/models/cflow/utils.py
index 326668ce08..5856e0c268 100644
--- a/src/anomalib/models/cflow/utils.py
+++ b/src/anomalib/models/cflow/utils.py
@@ -11,9 +11,10 @@
 import numpy as np
 import torch
 from FrEIA.framework import SequenceINN
-from FrEIA.modules import AllInOneBlock
 from torch import Tensor, nn
 
+from anomalib.models.components.flow import AllInOneBlock
+
 logger = logging.getLogger(__name__)
diff --git a/src/anomalib/models/components/flow/__init__.py b/src/anomalib/models/components/flow/__init__.py
new file mode 100644
index 0000000000..dca2e7b9e6
--- /dev/null
+++ b/src/anomalib/models/components/flow/__init__.py
@@ -0,0 +1,8 @@
+"""All In One Block Layer."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .all_in_one_block import AllInOneBlock
+
+__all__ = ["AllInOneBlock"]
diff --git a/src/anomalib/models/components/flow/all_in_one_block.py b/src/anomalib/models/components/flow/all_in_one_block.py
new file mode 100644
index 0000000000..f2ab1e17c3
--- /dev/null
+++ b/src/anomalib/models/components/flow/all_in_one_block.py
@@ -0,0 +1,318 @@
+"""All In One Block Layer."""
+
+# Copyright (c) https://github.com/vislearn/FrEIA
+# SPDX-License-Identifier: MIT
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import warnings
+from typing import Callable
+
+import torch
+import torch.nn.functional as F
+from FrEIA.modules import InvertibleModule
+from scipy.stats import special_ortho_group
+from torch import Tensor, nn
+
+
+def _global_scale_sigmoid_activation(input: Tensor) -> Tensor:
+    """Global scale sigmoid activation.
+
+    Args:
+        input (Tensor): Input tensor
+
+    Returns:
+        Tensor: Sigmoid activation
+    """
+    return 10 * torch.sigmoid(input - 2.0)
+
+
+def _global_scale_softplus_activation(input: Tensor) -> Tensor:
+    """Global scale softplus activation.
+
+    Args:
+        input (Tensor): Input tensor
+
+    Returns:
+        Tensor: Softplus activation
+    """
+    softplus = nn.Softplus(beta=0.5)
+    return 0.1 * softplus(input)
+
+
+def _global_scale_exp_activation(input: Tensor) -> Tensor:
+    """Global scale exponential activation.
+
+    Args:
+        input (Tensor): Input tensor
+
+    Returns:
+        Tensor: Exponential activation
+    """
+    return torch.exp(input)
+
+
+class AllInOneBlock(InvertibleModule):
+    """Module combining the most common operations in a normalizing flow or similar model.
+
+    It combines affine coupling, permutation, and global affine transformation
+    ('ActNorm'). It can also be used as GIN coupling block, perform learned
+    householder permutations, and use an inverted pre-permutation. The affine
+    transformation includes a soft clamping mechanism, first used in Real-NVP.
+    The block as a whole performs the following computation:
+
+    .. math::
+
+        y = V\\,R \\; \\Psi(s_\\mathrm{global}) \\odot \\mathrm{Coupling}\\Big(R^{-1} V^{-1} x\\Big)+ t_\\mathrm{global}
+
+    - The inverse pre-permutation of x (i.e. :math:`R^{-1} V^{-1}`) is optional (see
+      ``reverse_permutation`` below).
+    - The learned householder reflection matrix
+      :math:`V` is also optional all together (see ``learned_householder_permutation``
+      below).
+    - For the coupling, the input is split into :math:`x_1, x_2` along
+      the channel dimension. Then the output of the coupling operation is the
+      two halves :math:`u = \\mathrm{concat}(u_1, u_2)`.
+
+    .. math::
+
+        u_1 &= x_1 \\odot \\exp \\Big( \\alpha \\; \\mathrm{tanh}\\big( s(x_2) \\big)\\Big) + t(x_2) \\\\
+        u_2 &= x_2
+
+    Because :math:`\\mathrm{tanh}(s) \\in [-1, 1]`, this clamping mechanism prevents
+    exploding values in the exponential. The hyperparameter :math:`\\alpha` can be adjusted.
+    """
+
+    def __init__(
+        self,
+        dims_in,
+        dims_c=[],
+        subnet_constructor: Callable | None = None,
+        affine_clamping: float = 2.0,
+        gin_block: bool = False,
+        global_affine_init: float = 1.0,
+        global_affine_type: str = "SOFTPLUS",
+        permute_soft: bool = False,
+        learned_householder_permutation: int = 0,
+        reverse_permutation: bool = False,
+    ):
+        """
+        Args:
+            subnet_constructor:
+                class or callable ``f``, called as ``f(channels_in, channels_out)`` and
+                should return a torch.nn.Module. Predicts coupling coefficients :math:`s, t`.
+            affine_clamping:
+                clamp the output of the multiplicative coefficients before
+                exponentiation to +/- ``affine_clamping`` (see :math:`\\alpha` above).
+            gin_block:
+                Turn the block into a GIN block from Sorrenson et al, 2019.
+                Makes it so that the coupling operations as a whole is volume preserving.
+            global_affine_init:
+                Initial value for the global affine scaling :math:`s_\mathrm{global}`.
+            global_affine_type:
+                ``'SIGMOID'``, ``'SOFTPLUS'``, or ``'EXP'``. Defines the activation to be used
+                on the beta for the global affine scaling (:math:`\\Psi` above).
+            permute_soft:
+                bool, whether to sample the permutation matrix :math:`R` from :math:`SO(N)`,
+                or to use hard permutations instead. Note, ``permute_soft=True`` is very slow
+                when working with >512 dimensions.
+            learned_householder_permutation:
+                Int, if >0, turn on the matrix :math:`V` above, that represents
+                multiple learned householder reflections. Slow if large number.
+                Dubious whether it actually helps network performance.
+            reverse_permutation:
+                Reverse the permutation before the block, as introduced by Putzky
+                et al, 2019. Turns on the :math:`R^{-1} V^{-1}` pre-multiplication above.
+        """
+        super().__init__(dims_in, dims_c)
+
+        channels = dims_in[0][0]
+        # rank of the tensors means 1d, 2d, 3d tensor etc.
+        self.input_rank = len(dims_in[0]) - 1
+        # tuple containing all dims except for batch-dim (used at various points)
+        self.sum_dims = tuple(range(1, 2 + self.input_rank))
+
+        if len(dims_c) == 0:
+            self.conditional = False
+            self.condition_channels = 0
+        else:
+            assert tuple(dims_c[0][1:]) == tuple(
+                dims_in[0][1:]
+            ), f"Dimensions of input and condition don't agree: {dims_c} vs {dims_in}."
+            self.conditional = True
+            self.condition_channels = sum(dc[0] for dc in dims_c)
+
+        split_len1 = channels - channels // 2
+        split_len2 = channels // 2
+        self.splits = [split_len1, split_len2]
+
+        try:
+            self.permute_function = {0: F.linear, 1: F.conv1d, 2: F.conv2d, 3: F.conv3d}[self.input_rank]
+        except KeyError:
+            raise ValueError(f"Data is {1 + self.input_rank}D. Must be 1D-4D.")
+
+        self.in_channels = channels
+        self.clamp = affine_clamping
+        self.GIN = gin_block
+        self.reverse_pre_permute = reverse_permutation
+        self.householder = learned_householder_permutation
+
+        if permute_soft and channels > 512:
+            warnings.warn(
+                (
+                    "Soft permutation will take a very long time to initialize "
+                    f"with {channels} feature channels. Consider using hard permutation instead."
+                )
+            )
+
+        # global_scale is used as the initial value for the global affine scale
+        # (pre-activation). It is computed such that
+        # global_scale_activation(global_scale) = global_affine_init
+        # the 'magic numbers' (specifically for sigmoid) scale the activation to
+        # a sensible range.
+        if global_affine_type == "SIGMOID":
+            global_scale = 2.0 - torch.log(torch.tensor([10.0 / global_affine_init - 1.0]))
+            self.global_scale_activation = _global_scale_sigmoid_activation
+        elif global_affine_type == "SOFTPLUS":
+            global_scale = 2.0 * torch.log(torch.exp(torch.tensor(0.5 * 10.0 * global_affine_init)) - 1)
+            self.global_scale_activation = _global_scale_softplus_activation
+        elif global_affine_type == "EXP":
+            global_scale = torch.log(torch.tensor(global_affine_init))
+            self.global_scale_activation = _global_scale_exp_activation
+        else:
+            raise ValueError('Global affine activation must be "SIGMOID", "SOFTPLUS" or "EXP"')
+
+        self.global_scale = nn.Parameter(torch.ones(1, self.in_channels, *([1] * self.input_rank)) * global_scale)
+        self.global_offset = nn.Parameter(torch.zeros(1, self.in_channels, *([1] * self.input_rank)))
+
+        if permute_soft:
+            w = special_ortho_group.rvs(channels)
+        else:
+            indices = torch.randperm(channels)
+            w = torch.zeros((channels, channels))
+            w[torch.arange(channels), indices] = 1.0
+
+        if self.householder:
+            # instead of just the permutation matrix w, the learned householder
+            # permutation keeps track of reflection vectors vk, in addition to a
+            # random initial permutation w_0.
+            self.vk_householder = nn.Parameter(0.2 * torch.randn(self.householder, channels), requires_grad=True)
+            self.w_perm = None
+            self.w_perm_inv = None
+            self.w_0 = nn.Parameter(torch.FloatTensor(w), requires_grad=False)
+        else:
+            self.w_perm = nn.Parameter(
+                torch.FloatTensor(w).view(channels, channels, *([1] * self.input_rank)), requires_grad=False
+            )
+            self.w_perm_inv = nn.Parameter(
+                torch.FloatTensor(w.T).view(channels, channels, *([1] * self.input_rank)), requires_grad=False
+            )
+
+        if subnet_constructor is None:
+            raise ValueError("Please supply a callable subnet_constructor function or object (see docstring)")
+        self.subnet = subnet_constructor(self.splits[0] + self.condition_channels, 2 * self.splits[1])
+        self.last_jac = None
+
+    def _construct_householder_permutation(self):
+        """Computes a permutation matrix from the reflection vectors that are
+        learned internally as nn.Parameters."""
+        w = self.w_0
+        for vk in self.vk_householder:
+            w = torch.mm(w, torch.eye(self.in_channels).to(w.device) - 2 * torch.ger(vk, vk) / torch.dot(vk, vk))
+
+        for i in range(self.input_rank):
+            w = w.unsqueeze(-1)
+        return w
+
+    def _permute(self, x, rev=False):
+        """Performs the permutation and scaling after the coupling operation.
+        Returns transformed outputs and the LogJacDet of the scaling operation."""
+        if self.GIN:
+            scale = 1.0
+            perm_log_jac = 0.0
+        else:
+            scale = self.global_scale_activation(self.global_scale)
+            perm_log_jac = torch.sum(torch.log(scale))
+
+        if rev:
+            return ((self.permute_function(x, self.w_perm_inv) - self.global_offset) / scale, perm_log_jac)
+        else:
+            return (self.permute_function(x * scale + self.global_offset, self.w_perm), perm_log_jac)
+
+    def _pre_permute(self, x, rev=False):
+        """Permutes before the coupling block, only used if
+        reverse_permutation is set"""
+        if rev:
+            return self.permute_function(x, self.w_perm)
+        else:
+            return self.permute_function(x, self.w_perm_inv)
+
+    def _affine(self, x, a, rev=False):
+        """Given the passive half, and the pre-activation outputs of the
+        coupling subnetwork, perform the affine coupling operation.
+        Returns both the transformed inputs and the LogJacDet."""
+
+        # the entire coupling coefficient tensor is scaled down by a
+        # factor of ten for stability and easier initialization.
+        a *= 0.1
+        ch = x.shape[1]
+
+        sub_jac = self.clamp * torch.tanh(a[:, :ch])
+        if self.GIN:
+            sub_jac -= torch.mean(sub_jac, dim=self.sum_dims, keepdim=True)
+
+        if not rev:
+            return (x * torch.exp(sub_jac) + a[:, ch:], torch.sum(sub_jac, dim=self.sum_dims))
+        else:
+            return ((x - a[:, ch:]) * torch.exp(-sub_jac), -torch.sum(sub_jac, dim=self.sum_dims))
+
+    def forward(self, x, c=[], rev=False, jac=True):
+        """See base class docstring"""
+        if self.householder:
+            self.w_perm = self._construct_householder_permutation()
+            if rev or self.reverse_pre_permute:
+                self.w_perm_inv = self.w_perm.transpose(0, 1).contiguous()
+
+        if rev:
+            x, global_scaling_jac = self._permute(x[0], rev=True)
+            x = (x,)
+        elif self.reverse_pre_permute:
+            x = (self._pre_permute(x[0], rev=False),)
+
+        x1, x2 = torch.split(x[0], self.splits, dim=1)
+
+        if self.conditional:
+            x1c = torch.cat([x1, *c], 1)
+        else:
+            x1c = x1
+
+        if not rev:
+            a1 = self.subnet(x1c)
+            x2, j2 = self._affine(x2, a1)
+        else:
+            a1 = self.subnet(x1c)
+            x2, j2 = self._affine(x2, a1, rev=True)
+
+        log_jac_det = j2
+        x_out = torch.cat((x1, x2), 1)
+
+        if not rev:
+            x_out, global_scaling_jac = self._permute(x_out, rev=False)
+        elif self.reverse_pre_permute:
+            x_out = self._pre_permute(x_out, rev=True)
+
+        # add the global scaling Jacobian to the total.
+        # trick to get the total number of non-channel dimensions:
+        # number of elements of the first channel of the first batch member
+        n_pixels = x_out[0, :1].numel()
+        log_jac_det += (-1) ** rev * n_pixels * global_scaling_jac
+
+        return (x_out,), log_jac_det
+
+    def output_dims(self, input_dims):
+        return input_dims
diff --git a/src/anomalib/models/csflow/config.yaml b/src/anomalib/models/csflow/config.yaml
index 60684ac1bd..f3691ed28f 100644
--- a/src/anomalib/models/csflow/config.yaml
+++ b/src/anomalib/models/csflow/config.yaml
@@ -70,7 +70,7 @@ logging:
   log_graph: false # Logs the model graph to respective logger.
 
 optimization:
-  export_mode: null #options: onnx, openvino
+  export_mode: null # options: torch, onnx, openvino
 
 # PL Trainer Args. Don't add extra parameter here.
 trainer:
   enable_checkpointing: true
diff --git a/src/anomalib/models/dfkde/config.yaml b/src/anomalib/models/dfkde/config.yaml
index 5cff3a35f2..04acc3bb35 100644
--- a/src/anomalib/models/dfkde/config.yaml
+++ b/src/anomalib/models/dfkde/config.yaml
@@ -56,7 +56,7 @@ logging:
   log_graph: false # Logs the model graph to respective logger.
 
 optimization:
-  export_mode: null #options: onnx, openvino
+  export_mode: null # options: torch, onnx, openvino
 
 # PL Trainer Args. Don't add extra parameter here.
 trainer:
   enable_checkpointing: true
diff --git a/src/anomalib/models/dfm/config.yaml b/src/anomalib/models/dfm/config.yaml
index b6110681eb..2430604b71 100755
--- a/src/anomalib/models/dfm/config.yaml
+++ b/src/anomalib/models/dfm/config.yaml
@@ -57,7 +57,7 @@ logging:
   log_graph: false # Logs the model graph to respective logger.
 
 optimization:
-  export_mode: null #options: onnx, openvino
+  export_mode: null # options: torch, onnx, openvino
 
 # PL Trainer Args. Don't add extra parameter here.
 trainer:
   enable_checkpointing: true
diff --git a/src/anomalib/models/draem/config.yaml b/src/anomalib/models/draem/config.yaml
index abdd9c8a1c..862fd1159b 100644
--- a/src/anomalib/models/draem/config.yaml
+++ b/src/anomalib/models/draem/config.yaml
@@ -65,7 +65,7 @@ logging:
   log_graph: false # Logs the model graph to respective logger.
 
 optimization:
-  export_mode: null #options: onnx, openvino
+  export_mode: null # options: torch, onnx, openvino
 
 # PL Trainer Args. Don't add extra parameter here.
 trainer:
   enable_checkpointing: true
diff --git a/src/anomalib/models/fastflow/config.yaml b/src/anomalib/models/fastflow/config.yaml
index b2d1e115c8..525b6a3412 100644
--- a/src/anomalib/models/fastflow/config.yaml
+++ b/src/anomalib/models/fastflow/config.yaml
@@ -68,7 +68,7 @@ logging:
   log_graph: false # Logs the model graph to respective logger.
 
 optimization:
-  export_mode: null #options: onnx, openvino
+  export_mode: null # options: torch, onnx, openvino
 
 # PL Trainer Args. Don't add extra parameter here.
 trainer:
diff --git a/src/anomalib/models/fastflow/torch_model.py b/src/anomalib/models/fastflow/torch_model.py
index 3720194686..89cea79b83 100644
--- a/src/anomalib/models/fastflow/torch_model.py
+++ b/src/anomalib/models/fastflow/torch_model.py
@@ -16,11 +16,11 @@
 import timm
 import torch
 from FrEIA.framework import SequenceINN
-from FrEIA.modules import AllInOneBlock
 from timm.models.cait import Cait
 from timm.models.vision_transformer import VisionTransformer
 from torch import Tensor, nn
 
+from anomalib.models.components.flow import AllInOneBlock
 from anomalib.models.fastflow.anomaly_map import AnomalyMapGenerator
diff --git a/src/anomalib/models/padim/config.yaml b/src/anomalib/models/padim/config.yaml
index 23aa97a216..84b0d2d421 100644
--- a/src/anomalib/models/padim/config.yaml
+++ b/src/anomalib/models/padim/config.yaml
@@ -63,7 +63,7 @@ logging:
   log_graph: false # Logs the model graph to respective logger.
 
 optimization:
-  export_mode: null #options: onnx, openvino
+  export_mode: null # options: torch, onnx, openvino
 
 # PL Trainer Args. Don't add extra parameter here.
 trainer:
diff --git a/src/anomalib/models/reverse_distillation/config.yaml b/src/anomalib/models/reverse_distillation/config.yaml
index 1cf9c9097c..8724556725 100644
--- a/src/anomalib/models/reverse_distillation/config.yaml
+++ b/src/anomalib/models/reverse_distillation/config.yaml
@@ -72,7 +72,7 @@ logging:
   log_graph: false # Logs the model graph to respective logger.
 
 optimization:
-  export_mode: null #options: onnx, openvino
+  export_mode: null # options: torch, onnx, openvino
 
 # PL Trainer Args. Don't add extra parameter here.
 trainer:
diff --git a/src/anomalib/models/rkde/config.yaml b/src/anomalib/models/rkde/config.yaml
index f5bda6c533..4fc6a1a5d8 100644
--- a/src/anomalib/models/rkde/config.yaml
+++ b/src/anomalib/models/rkde/config.yaml
@@ -66,7 +66,7 @@ logging:
   log_graph: false # Logs the model graph to respective logger.
 
 optimization:
-  export_mode: null #options: onnx, openvino
+  export_mode: null # options: torch, onnx, openvino
 
 # PL Trainer Args. Don't add extra parameter here.
 trainer:
   enable_checkpointing: true
diff --git a/src/anomalib/models/stfpm/config.yaml b/src/anomalib/models/stfpm/config.yaml
index efdd9ee14b..1f92f3a187 100644
--- a/src/anomalib/models/stfpm/config.yaml
+++ b/src/anomalib/models/stfpm/config.yaml
@@ -70,7 +70,7 @@ logging:
   log_graph: false # Logs the model graph to respective logger.
 
 optimization:
-  export_mode: null #options: onnx, openvino
+  export_mode: null # options: torch, onnx, openvino
 
 # PL Trainer Args. Don't add extra parameter here.
 trainer:
   enable_checkpointing: true
diff --git a/src/anomalib/utils/callbacks/__init__.py b/src/anomalib/utils/callbacks/__init__.py
index 5ee1bfdefc..555404c8e3 100644
--- a/src/anomalib/utils/callbacks/__init__.py
+++ b/src/anomalib/utils/callbacks/__init__.py
@@ -61,7 +61,7 @@ def get_callbacks(config: DictConfig | ListConfig) -> list[Callback]:
     monitor_mode = "max" if "early_stopping" not in config.model.keys() else config.model.early_stopping.mode
 
     checkpoint = ModelCheckpoint(
-        dirpath=os.path.join(config.project.path, "weights"),
+        dirpath=os.path.join(config.project.path, "weights", "lightning"),
         filename="model",
         monitor=monitor_metric,
         mode=monitor_mode,
diff --git a/src/anomalib/utils/sweep/helpers/inference.py b/src/anomalib/utils/sweep/helpers/inference.py
index 7357051592..20f0cde30c 100644
--- a/src/anomalib/utils/sweep/helpers/inference.py
+++ b/src/anomalib/utils/sweep/helpers/inference.py
@@ -9,32 +9,32 @@
 from pathlib import Path
 
 import torch
-from omegaconf import DictConfig, ListConfig
 from torch.utils.data import Dataset
 
 from anomalib.deploy import OpenVINOInferencer, TorchInferencer
-from anomalib.models.components import AnomalyModule
 
 
-def get_torch_throughput(config: DictConfig | ListConfig, model: AnomalyModule, test_dataset: Dataset) -> float:
-    """Tests the model on dummy data. Images are passed sequentially to make the comparision with OpenVINO model fair.
+def get_torch_throughput(model_path: str | Path, test_dataset: Dataset, device: str) -> float:
+    """Tests the model on dummy data. Images are passed sequentially to make the comparison with OpenVINO model fair.
 
     Args:
-        config (DictConfig | ListConfig): Model config.
-        model (Path): Model on which inference is called.
+        model_path (str | Path): Path to folder containing the Torch models.
         test_dataset (Dataset): The test dataset used as a reference for the mock dataset.
+        device (str): Device to use for inference. Options are auto, cpu, gpu, cuda.
 
     Returns:
         float: Inference throughput
     """
+    model_path = Path(model_path)
     torch.set_grad_enabled(False)
-    model.eval()
 
-    device = config.trainer.accelerator
     if device == "gpu":
         device = "cuda"
 
-    inferencer = TorchInferencer(config, model.to(device), device=device)
+    inferencer = TorchInferencer(
+        path=model_path / "weights" / "torch" / "model.pt",
+        device=device,
+    )
     start_time = time.time()
     for image_path in test_dataset.samples.image_path:
         inferencer.predict(image_path)
@@ -47,17 +47,22 @@ def get_torch_throughput(config: DictConfig | ListConfig, model: AnomalyModule,
     return throughput
 
 
-def get_openvino_throughput(model_path: Path, test_dataset: Dataset) -> float:
+def get_openvino_throughput(model_path: str | Path, test_dataset: Dataset) -> float:
     """Runs the generated OpenVINO model on a dummy dataset to get throughput.
 
     Args:
-        model_path (Path): Path to folder containing the OpenVINO models. It then searches `model.xml` in the folder.
+        model_path (str | Path): Path to folder containing the OpenVINO models. It then searches `model.xml` in the folder.
         test_dataset (Dataset): The test dataset used as a reference for the mock dataset.
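+
+    Example:
+        >>> # Illustrative call; assumes an OpenVINO export exists under
+        >>> # <model_path>/weights/openvino/.
+        >>> fps = get_openvino_throughput("results/padim/mvtec/bottle/run", datamodule.test_data)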
 
     Returns:
         float: Inference throughput
     """
-    inferencer = OpenVINOInferencer(model_path / "openvino" / "model.xml", model_path / "openvino" / "metadata.json")
+    model_path = Path(model_path)
+
+    inferencer = OpenVINOInferencer(
+        path=model_path / "weights" / "openvino" / "model.xml",
+        metadata_path=model_path / "weights" / "openvino" / "metadata.json",
+    )
     start_time = time.time()
     for image_path in test_dataset.samples.image_path:
         inferencer.predict(image_path)
diff --git a/tests/pre_merge/deploy/test_inferencer.py b/tests/pre_merge/deploy/test_inferencer.py
index 43114d516a..86bd89f785 100644
--- a/tests/pre_merge/deploy/test_inferencer.py
+++ b/tests/pre_merge/deploy/test_inferencer.py
@@ -101,11 +101,15 @@ def test_torch_inference(
     Args:
         model_name (str): Name of the model
     """
-    model_config, model = generate_results_dir(model_name=model_name, dataset_path=path, task=task, category=category)
-    model.eval()
+    model_config, model = generate_results_dir(
+        model_name=model_name, dataset_path=path, task=task, category=category, export_mode="torch"
+    )
 
     # Test torch inferencer
-    torch_inferencer = TorchInferencer(model_config, model, device="cpu")
+    torch_inferencer = TorchInferencer(
+        path=Path(model_config.project.path) / "weights" / "torch" / "model.pt",
+        device="cpu",
+    )
     torch_dataloader = MockImageLoader(model_config.dataset.image_size, total_count=1)
     with torch.no_grad():
         for image in torch_dataloader():
@@ -142,7 +146,9 @@ def test_openvino_inference(
     export_path = Path(model_config.project.path)
 
     # Test OpenVINO inferencer
-    openvino_inferencer = OpenVINOInferencer(export_path / "openvino/model.xml", export_path / "openvino/metadata.json")
+    openvino_inferencer = OpenVINOInferencer(
+        export_path / "weights/openvino/model.xml", export_path / "weights/openvino/metadata.json"
+    )
     openvino_dataloader = MockImageLoader(model_config.dataset.image_size, total_count=1)
     for image in openvino_dataloader():
         prediction = openvino_inferencer.predict(image)
diff --git a/tests/pre_merge/utils/callbacks/export_callback/test_export.py b/tests/pre_merge/utils/callbacks/export_callback/test_export.py
index d2c330a3dd..c3f8978827 100644
--- a/tests/pre_merge/utils/callbacks/export_callback/test_export.py
+++ b/tests/pre_merge/utils/callbacks/export_callback/test_export.py
@@ -64,8 +64,10 @@ def test_export_model_callback(dummy_datamodule: MVTec, export_mode):
     trainer.fit(model, datamodule=dummy_datamodule)
 
     if export_mode == ExportMode.OPENVINO:
-        assert os.path.exists(os.path.join(tmp_dir, "openvino/model.bin")), "Failed to generate OpenVINO model"
+        assert os.path.exists(
+            os.path.join(tmp_dir, "weights/openvino/model.bin")
+        ), "Failed to generate OpenVINO model"
     elif export_mode == ExportMode.ONNX:
-        assert os.path.exists(os.path.join(tmp_dir, "onnx/model.onnx")), "Failed to generate ONNX model"
+        assert os.path.exists(os.path.join(tmp_dir, "weights/onnx/model.onnx")), "Failed to generate ONNX model"
     else:
         raise ValueError(f"Unknown export_mode {export_mode}. Supported modes: onnx or openvino.")
diff --git a/tools/benchmarking/benchmark.py b/tools/benchmarking/benchmark.py
index 13b3d959ab..dfdec0954c 100644
--- a/tools/benchmarking/benchmark.py
+++ b/tools/benchmarking/benchmark.py
@@ -117,19 +117,15 @@ def get_single_model_metrics(model_config: DictConfig | ListConfig, openvino_met
     openvino_throughput = float("nan")
     if openvino_metrics:
-        # Create dirs for openvino model export
-        openvino_export_path = project_path / Path("exported_models")
-        openvino_export_path.mkdir(parents=True, exist_ok=True)
+        # Export the model; the export helper creates the output directories.
         export(
             task=model_config.dataset.task,
             transform=trainer.datamodule.test_data.transform.to_dict(),
             input_size=model_config.model.input_size,
             model=model,
             export_mode=ExportMode.OPENVINO,
-            export_root=openvino_export_path,
-        )
-        openvino_throughput = get_openvino_throughput(
-            model_path=openvino_export_path, test_dataset=datamodule.test_data
+            export_root=project_path,
         )
+        openvino_throughput = get_openvino_throughput(model_path=project_path, test_dataset=datamodule.test_data)
 
     # arrange the data
     data = {
diff --git a/tools/inference/gradio_inference.py b/tools/inference/gradio_inference.py
index 1600621439..048b0bbc5f 100644
--- a/tools/inference/gradio_inference.py
+++ b/tools/inference/gradio_inference.py
@@ -27,14 +27,12 @@ def get_args() -> Namespace:
 
     Example for Torch Inference.
     >>> python tools/inference/gradio_inference.py \
-    ... --config ./anomalib/models/padim/config.yaml \
-    ... --weights ./results/padim/mvtec/bottle/weights/model.ckpt
+    ... --weights ./results/padim/mvtec/bottle/weights/torch/model.pt
 
     Returns:
         Namespace: List of arguments.
     """
     parser = ArgumentParser()
-    parser.add_argument("--config", type=Path, required=False, help="Path to a config file")
     parser.add_argument("--weights", type=Path, required=True, help="Path to model weights")
     parser.add_argument("--metadata", type=Path, required=False, help="Path to a JSON file containing the metadata.")
     parser.add_argument("--share", type=bool, required=False, default=False, help="Share Gradio `share_url`")
 
     return parser.parse_args()
 
 
-def get_inferencer(config_path: Path, weight_path: Path, metadata_path: Path | None = None) -> Inferencer:
+def get_inferencer(weight_path: Path, metadata_path: Path | None = None) -> Inferencer:
     """Parse args and open inferencer.
 
     Args:
-        config_path (Path): Path to model configuration file or the name of the model.
         weight_path (Path): Path to model weights.
         metadata_path (Path | None, optional): Metadata is required for OpenVINO models. Defaults to None.
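+
+    Example:
+        >>> # Illustrative; the inferencer class is chosen from the weight file extension.
+        >>> inferencer = get_inferencer(Path("./results/padim/mvtec/bottle/weights/torch/model.pt"))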
 
@@ -62,12 +59,9 @@ def get_inferencer(weight_path: Path, metadata_path: Path | N
     extension = weight_path.suffix
     inferencer: Inferencer
     module = import_module("anomalib.deploy")
-    if extension in (".ckpt"):
-        if config_path is None:
-            raise ValueError("When using Torch Inferencer, the following arguments are required: --config")
-
+    if extension in (".pt", ".pth", ".ckpt"):
         torch_inferencer = getattr(module, "TorchInferencer")
-        inferencer = torch_inferencer(config=config_path, model_source=weight_path, metadata_path=metadata_path)
+        inferencer = torch_inferencer(path=weight_path)
 
     elif extension in (".onnx", ".bin", ".xml"):
         if metadata_path is None:
@@ -103,7 +97,7 @@ def infer(image: np.ndarray, inferencer: Inferencer) -> tuple[np.ndarray, np.nda
 
 if __name__ == "__main__":
     args = get_args()
-    gradio_inferencer = get_inferencer(args.config, args.weights, args.metadata)
+    gradio_inferencer = get_inferencer(args.weights, args.metadata)
 
     interface = gr.Interface(
         fn=lambda image: infer(image, gradio_inferencer),
diff --git a/tools/inference/torch_inference.py b/tools/inference/torch_inference.py
index e61c69dd9d..3848408728 100644
--- a/tools/inference/torch_inference.py
+++ b/tools/inference/torch_inference.py
@@ -1,6 +1,6 @@
 """Anomalib Torch Inferencer Script.
 
-This script performs torch inference by reading model config files and weights
-from command line, and show the visualization results.
+This script performs torch inference by reading the model weights
+from the command line and showing the visualization results.
 """
 
@@ -29,7 +29,6 @@ def get_args() -> Namespace:
         Namespace: List of arguments.
     """
     parser = ArgumentParser()
-    parser.add_argument("--config", type=Path, required=True, help="Path to a config file")
     parser.add_argument("--weights", type=Path, required=True, help="Path to model weights")
     parser.add_argument("--input", type=Path, required=True, help="Path to an image to infer.")
     parser.add_argument("--output", type=Path, required=False, help="Path to save the output image.")
@@ -74,15 +73,13 @@ def infer() -> None:
     Show/save the output if path is to an image. If the path is a directory,
     go over each image in the directory.
     """
-    # Get the command line arguments, and config from the config.yaml file.
-    # This config file is also used for training and contains all the relevant
-    # information regarding the data, model, train and inference details.
+    # Get the command line arguments.
    args = get_args()
    torch.set_grad_enabled(False)
 
     # Create the inferencer and visualizer.
-    inferencer = TorchInferencer(config=args.config, model_source=args.weights, device=args.device)
+    inferencer = TorchInferencer(path=args.weights, device=args.device)
     visualizer = Visualizer(mode=args.visualization_mode, task=args.task)
 
     filenames = get_image_filenames(path=args.input)