diff --git a/README.md b/README.md
index f2c70aa53c..3a10027d5e 100644
--- a/README.md
+++ b/README.md
@@ -114,6 +114,7 @@ where the currently available models are:
 - [PatchCore](src/anomalib/models/patchcore)
 - [Reverse Distillation](src/anomalib/models/reverse_distillation)
 - [STFPM](src/anomalib/models/stfpm)
+- [UFlow](src/anomalib/models/uflow)

 ## Feature extraction & (pre-trained) backbones
diff --git a/docs/source/images/uflow/diagram.png b/docs/source/images/uflow/diagram.png
new file mode 100644
index 0000000000..a824d2129a
Binary files /dev/null and b/docs/source/images/uflow/diagram.png differ
diff --git a/docs/source/images/uflow/iou.png b/docs/source/images/uflow/iou.png
new file mode 100644
index 0000000000..53acb80d43
Binary files /dev/null and b/docs/source/images/uflow/iou.png differ
diff --git a/docs/source/images/uflow/more-results.png b/docs/source/images/uflow/more-results.png
new file mode 100644
index 0000000000..f16d78fa97
Binary files /dev/null and b/docs/source/images/uflow/more-results.png differ
diff --git a/docs/source/images/uflow/pixel-aupro.png b/docs/source/images/uflow/pixel-aupro.png
new file mode 100644
index 0000000000..0fb28d07ef
Binary files /dev/null and b/docs/source/images/uflow/pixel-aupro.png differ
diff --git a/docs/source/images/uflow/pixel-auroc.png b/docs/source/images/uflow/pixel-auroc.png
new file mode 100644
index 0000000000..6a6247853a
Binary files /dev/null and b/docs/source/images/uflow/pixel-auroc.png differ
diff --git a/docs/source/images/uflow/results-mvtec-anomalies.jpg b/docs/source/images/uflow/results-mvtec-anomalies.jpg
new file mode 100644
index 0000000000..fcc1e39787
Binary files /dev/null and b/docs/source/images/uflow/results-mvtec-anomalies.jpg differ
diff --git a/docs/source/images/uflow/results-mvtec-good.jpg b/docs/source/images/uflow/results-mvtec-good.jpg
new file mode 100644
index 0000000000..b2db925942
Binary files /dev/null and b/docs/source/images/uflow/results-mvtec-good.jpg differ
diff --git a/docs/source/images/uflow/results-others-anomalies.jpg b/docs/source/images/uflow/results-others-anomalies.jpg
new file mode 100644
index 0000000000..cebe9720da
Binary files /dev/null and b/docs/source/images/uflow/results-others-anomalies.jpg differ
diff --git a/docs/source/images/uflow/results-others-good.jpg b/docs/source/images/uflow/results-others-good.jpg
new file mode 100644
index 0000000000..e31a0020c3
Binary files /dev/null and b/docs/source/images/uflow/results-others-good.jpg differ
diff --git a/docs/source/images/uflow/teaser.jpg b/docs/source/images/uflow/teaser.jpg
new file mode 100644
index 0000000000..973c4f2ced
Binary files /dev/null and b/docs/source/images/uflow/teaser.jpg differ
diff --git a/docs/source/reference_guide/algorithms/index.rst b/docs/source/reference_guide/algorithms/index.rst
index 47287ff254..08b3f430e2 100644
--- a/docs/source/reference_guide/algorithms/index.rst
+++ b/docs/source/reference_guide/algorithms/index.rst
@@ -18,6 +18,7 @@ Algorithms
    patchcore
    reverse_distillation
    stfpm
+   uflow

 Feature extraction & (pre-trained) backbones
diff --git a/docs/source/reference_guide/algorithms/uflow.rst b/docs/source/reference_guide/algorithms/uflow.rst
new file mode 100644
index 0000000000..659f55446c
--- /dev/null
+++ b/docs/source/reference_guide/algorithms/uflow.rst
@@ -0,0 +1,44 @@
+U-Flow
+---------
+
+This is the implementation of the `U-Flow <https://arxiv.org/abs/2211.12353>`_ paper.
+
+Model Type: Segmentation
+
+Description
+***********
+
+U-Flow is a U-shaped normalizing flow-based probability distribution estimator.
+The method consists of three phases.
+(1) Multi-scale feature extraction: a rich multi-scale representation is obtained with MSCaiT, by combining pre-trained image Transformers acting at different image scales. Any other feature extractor, such as ResNet, can also be used.
+(2) U-shaped Normalizing Flow: by adapting the widely used U-like architecture to NFs, a fully invertible architecture is designed. This architecture is capable of merging the information from different scales while ensuring independence both intra- and inter-scales. To make it fully invertible, split and invertible up-sampling operations are used.
+(3) Anomaly score and segmentation computation: besides generating the anomaly map based on the likelihood of test data, we also propose to adapt the a contrario framework to obtain an automatic threshold by controlling the allowed number of false alarms.
+
+Architecture
+************
+
+.. image:: ../../images/uflow/diagram.png
+    :alt: U-Flow Architecture
+
+Usage
+*****
+
+.. code-block:: bash
+
+    $ python tools/train.py --model uflow
+
+
+.. automodule:: anomalib.models.uflow.torch_model
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+.. automodule:: anomalib.models.uflow.lightning_model
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
+.. automodule:: anomalib.models.uflow.anomaly_map
+    :members:
+    :undoc-members:
+    :show-inheritance:
diff --git a/src/anomalib/models/__init__.py b/src/anomalib/models/__init__.py
index a613e18a78..ee05d4d0a9 100644
--- a/src/anomalib/models/__init__.py
+++ b/src/anomalib/models/__init__.py
@@ -30,6 +30,7 @@
 from anomalib.models.reverse_distillation import ReverseDistillation
 from anomalib.models.rkde import Rkde
 from anomalib.models.stfpm import Stfpm
+from anomalib.models.uflow import Uflow

 __all__ = [
     "Cfa",
@@ -46,6 +47,7 @@
     "ReverseDistillation",
     "Rkde",
     "Stfpm",
+    "Uflow",
     "AiVad",
     "EfficientAd",
 ]
diff --git a/src/anomalib/models/uflow/README.md b/src/anomalib/models/uflow/README.md
new file mode 100644
index 0000000000..1cafe529e3
--- /dev/null
+++ b/src/anomalib/models/uflow/README.md
@@ -0,0 +1,128 @@
+# U-Flow: A U-shaped Normalizing Flow for Anomaly Detection with Unsupervised Threshold
+
+[//]: # "This is the implementation of the [U-Flow](https://arxiv.org/abs/2211.12353) paper, based on the [original code](https://www.github.com/mtailanian/uflow)"
+
+This is the implementation of the [U-Flow](https://www.researchsquare.com/article/rs-3367286/latest) paper, based on the [original code](https://www.github.com/mtailanian/uflow).
+
+![U-Flow Architecture](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/diagram.png "U-Flow Architecture")
+
+## Abstract
+
+_In this work we propose a one-class self-supervised method for anomaly segmentation in images, that benefits both from a modern machine learning approach and a more classic statistical detection theory.
+The method consists of three phases. First, features are extracted using a multi-scale image Transformer architecture. Then, these features are fed into a U-shaped Normalizing Flow that lays the theoretical foundations for the last phase, which computes a pixel-level anomaly map and performs a segmentation based on the a contrario framework.
+This multiple-hypothesis testing strategy permits the derivation of robust automatic detection thresholds, which are crucial in real-world applications where an operational point is needed.
+The segmentation results are evaluated using the Intersection over Union (IoU) metric, and for assessing the generated anomaly maps we report the area under the Receiver Operating Characteristic curve (AUROC), and the area under the per-region-overlap curve (AUPRO).
+Extensive experimentation in various datasets shows that the proposed approach produces state-of-the-art results for all metrics and all datasets, ranking first in most MvTec-AD categories, with a mean pixel-level AUROC of 98.74%._
+
+![Teaser image](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/teaser.jpg)
+
+## Localization results
+
+### Pixel AUROC over MVTec-AD Dataset
+
+![Pixel-AUROC results](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/pixel-auroc.png "Pixel-AUROC results")
+
+### Pixel AUPRO over MVTec-AD Dataset
+
+![Pixel-AUPRO results](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/pixel-aupro.png "Pixel-AUPRO results")
+
+## Segmentation results (IoU) with threshold log(NFA)=0
+
+This paper also proposes a method to compute the threshold automatically, using the a contrario framework. All results below are obtained with the threshold log(NFA)=0.
+In the default code here, for the sake of comparison with all the other methods of the library, the segmentation is instead done by computing the threshold over the anomaly map at train time.
+Nevertheless, the code for computing the segmentation mask with the NFA criterion is included in `src/anomalib/models/uflow/anomaly_map.py`.
+
+![IoU results](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/iou.png "IoU results")
+
+## Results over other datasets
+
+![Results over other datasets](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/more-results.png "Results over other datasets")
+
+## Benchmarking
+
+Note that the proposed method uses the MCait Feature Extractor, which has an input size of 448x448. In the benchmarking, a size of 256x256 is used for all methods, and therefore the results may differ from those reported. To reproduce all results exactly, the reader can refer to the [original code](https://www.github.com/mtailanian/uflow), where the configs used and even the trained checkpoints can be downloaded from [this release](https://github.com/mtailanian/uflow/releases/tag/trained-mvtec-models).
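+
+For reference, a minimal Python sketch of launching the same training programmatically is shown below. It assumes the `get_configurable_parameters`, `get_datamodule`, and `get_model` helpers that `tools/train.py` uses, and it omits the extra callbacks (normalization, thresholding, visualization) that the script wires in, so it is a sketch rather than a drop-in replacement:
+
+```python
+# Minimal training sketch; mirrors what tools/train.py does, without its extra callbacks.
+from pytorch_lightning import Trainer
+
+from anomalib.config import get_configurable_parameters
+from anomalib.data import get_datamodule
+from anomalib.models import get_model
+
+# Loads src/anomalib/models/uflow/config.yaml (448x448 input, mcait backbone).
+config = get_configurable_parameters(model_name="uflow")
+config.dataset.category = "bottle"  # pick any MVTec category
+
+datamodule = get_datamodule(config)  # builds the MVTec datamodule from the config
+model = get_model(config)  # instantiates the UflowLightning module from config.model
+
+trainer = Trainer(**config.trainer)
+trainer.fit(model=model, datamodule=datamodule)
+```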
+
+## Reproducing paper's results
+
+Using the default parameters of the config file (`src/anomalib/models/uflow/config.yaml`), the results obtained are very close to the ones reported in the paper:
+
+bottle: 97.98, cable: 98.17, capsule: 98.95, carpet: 99.45, grid: 98.19, hazelnut: 99.01, leather: 99.41, metal_nut: 98.19, pill: 99.15, screw: 99.25, tile: 96.93, toothbrush: 98.97, transistor: 96.70, wood: 96.87, zipper: 97.92
+
+To obtain exactly the same results, the architecture parameters always stay the same, but a per-category learning rate and batch size should be used. These values are listed as comments in the config file, and in the [original code](https://www.github.com/mtailanian/uflow): the configs used are available in the [`configs` folder](https://github.com/mtailanian/uflow/tree/main/configs), and trained checkpoints are available in [this release](https://github.com/mtailanian/uflow/releases/tag/trained-mvtec-models).
+
+## Usage
+
+`python tools/train.py --model uflow`
+
+## Download data
+
+### MVTec
+
+https://www.mvtec.com/company/research/datasets/mvtec-ad
+
+### Bean Tech
+
+https://paperswithcode.com/dataset/btad
+
+### LGG MRI
+
+https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation
+
+### ShanghaiTech Campus
+
+https://svip-lab.github.io/dataset/campus_dataset.html
+
+## [Optional] Download pre-trained models
+
+Pre-trained models can be found in [this release](https://github.com/mtailanian/uflow/releases/tag/trained-mvtec-models), or can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1W1rE0mu4Lv3uWHA5GZigmvVNlBVHqTv_?usp=sharing).
+
+For an easier way of downloading them, please refer to the `README.md` of the [original code](https://www.github.com/mtailanian/uflow).
+
+To reproduce the exact results from the paper, a different learning rate and batch size must be used for each category. You can find the exact values in the `configs` folder of the original repository, or in the [Google Drive folder](https://drive.google.com/drive/folders/1W1rE0mu4Lv3uWHA5GZigmvVNlBVHqTv_?usp=sharing) linked above.
+
+## A note on sizes at different points
+
+Input
+
+```text
+- Scale 1: [3, 448, 448]
+- Scale 2: [3, 224, 224]
+```
+
+MS-Cait outputs
+
+```text
+- Scale 1: [768, 28, 28]
+- Scale 2: [384, 14, 14]
+```
+
+Normalizing Flow outputs
+
+```text
+- Scale 1: [816, 28, 28] --> 816 = 768 + 384 / 2 / 4
+- Scale 2: [192, 14, 14] --> 192 = 384 / 2
+```
+
+`/ 2` corresponds to the split, and `/ 4` to the invertible upsample.
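+
+As a quick sanity check, the bookkeeping above can be replayed in a few lines of standalone Python (not part of the library code):
+
+```python
+# Channel bookkeeping for the U-shaped flow at input size 448.
+cait_channels = [768, 384]  # MS-Cait output channels at scales 1 and 2
+
+scale2_out = cait_channels[1] // 2  # split: half of scale 2 goes straight to the output -> 192
+upsampled = scale2_out // 4  # the other half is invertibly upsampled: H, W double, channels / 4 -> 48
+scale1_out = cait_channels[0] + upsampled  # concatenated with the scale-1 features -> 816
+
+assert scale2_out == 192  # Normalizing Flow output, scale 2: [192, 14, 14]
+assert scale1_out == 816  # Normalizing Flow output, scale 1: [816, 28, 28]
+```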
+
+## Example results
+
+### Anomalies
+
+#### MVTec
+
+![MVTec results - anomalies](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/results-mvtec-anomalies.jpg "MVTec results - anomalies")
+
+#### BeanTech, LGG MRI, STC
+
+![BeanTech, LGG MRI, STC results - anomalies](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/results-others-anomalies.jpg "BeanTech, LGG MRI, STC results - anomalies")
+
+### Normal images
+
+#### MVTec
+
+![MVTec results - normal](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/results-mvtec-good.jpg "MVTec results - normal")
+
+#### BeanTech, LGG MRI, STC
+
+![BeanTech, LGG MRI, STC results - normal](https://raw.githubusercontent.com/openvinotoolkit/anomalib/main/docs/source/images/uflow/results-others-good.jpg "BeanTech, LGG MRI, STC results - normal")
diff --git a/src/anomalib/models/uflow/__init__.py b/src/anomalib/models/uflow/__init__.py
new file mode 100644
index 0000000000..df5786900d
--- /dev/null
+++ b/src/anomalib/models/uflow/__init__.py
@@ -0,0 +1,8 @@
+"""U-Flow: A U-shaped Normalizing Flow for Anomaly Detection with Unsupervised Threshold."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .lightning_model import Uflow, UflowLightning
+
+__all__ = ["Uflow", "UflowLightning"]
diff --git a/src/anomalib/models/uflow/anomaly_map.py b/src/anomalib/models/uflow/anomaly_map.py
new file mode 100644
index 0000000000..9cc6586505
--- /dev/null
+++ b/src/anomalib/models/uflow/anomaly_map.py
@@ -0,0 +1,166 @@
+"""UFlow Anomaly Map Generator Implementation."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import List
+
+import numpy as np
+import scipy.stats as st
+import torch
+import torch.nn.functional as F
+from mpmath import binomial, mp
+from omegaconf import ListConfig
+from scipy import integrate
+from torch import Tensor, nn
+
+mp.dps = 15  # Set precision for NFA computation (used when high_precision=True)
+
+
+class AnomalyMapGenerator(nn.Module):
+    """Generate the anomaly heatmap and the NFA-based segmentation."""
+
+    def __init__(self, input_size: ListConfig | tuple) -> None:
+        super().__init__()
+        self.input_size = input_size if isinstance(input_size, tuple) else tuple(input_size)
+
+    def forward(self, latent_variables: list[Tensor]) -> Tensor:
+        return self.compute_anomaly_map(latent_variables)
+
+    def compute_anomaly_map(self, latent_variables: list[Tensor]) -> Tensor:
+        """Generate a likelihood-based anomaly map from the latent variables.
+
+        Args:
+            latent_variables: List of latent variables from the UFlow model. Each element is a tensor of shape
+                (N, Cl, Hl, Wl), where N is the batch size, Cl is the number of channels, and Hl and Wl are the
+                height and width of the latent variables, respectively, for each scale l.
+
+        Returns:
+            Final anomaly map. Tensor of shape (N, 1, H, W), where N is the batch size, and H and W are the
+            height and width of the input image, respectively.
+        """
+        likelihoods = []
+        for z in latent_variables:
+            # Mean log-prob by scale. The exact log-likelihood uses a sum over channels; the mean is used here
+            # to avoid numerical issues.
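+            # (Under the flow's standard normal prior, -0.5 * mean(z**2, dim=1) is, up to an additive
+            # constant, the per-pixel Gaussian log-density averaged over channels.)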
+            # Also, this way all scales have the same weight, and the map does not depend on the number of channels.
+            log_prob_i = -torch.mean(z**2, dim=1, keepdim=True) * 0.5
+            prob_i = torch.exp(log_prob_i)
+            likelihoods.append(
+                F.interpolate(
+                    prob_i,
+                    size=self.input_size,
+                    mode="bilinear",
+                    align_corners=False,
+                )
+            )
+        anomaly_map = 1 - torch.mean(torch.stack(likelihoods, dim=-1), dim=-1)
+        return anomaly_map
+
+    def compute_anomaly_mask(
+        self,
+        z: List[torch.Tensor],
+        win_size: int = 7,
+        binomial_probability_thr: float = 0.5,
+        high_precision: bool = False,
+    ):
+        """Generate an NFA-based anomaly mask from the latent variables.
+
+        This method is not used in the basic training and testing functionality. It is a bit slow, so it is left
+        as an option for the user. It is included because it is part of the U-Flow paper, and it can be called
+        separately if an unsupervised anomaly segmentation is needed.
+
+        The mask is based on the NFA (Number of False Alarms) method, a statistical approach for detecting
+        anomalies. The NFA is computed from the probability of the null hypothesis, which is that all pixels are
+        normal. First, we compute a list of candidate pixels with suspiciously high values of z^2, by applying a
+        binomial test to a window around each pixel. Then, to compute the (log-)NFA values, we evaluate how
+        probable it is that a pixel belongs to the normal distribution. Under the null hypothesis (normality),
+        candidate pixels are uniformly distributed, so the detection is based on the concentration of candidate
+        pixels.
+
+        Args:
+            z: List of latent variables from the UFlow model. Each element is a tensor of shape (N, Cl, Hl, Wl),
+                where N is the batch size, Cl is the number of channels, and Hl and Wl are the height and width
+                of the latent variables, respectively, for each scale l.
+            win_size: Window size for the binomial test.
+            binomial_probability_thr: Probability threshold for the binomial test.
+            high_precision: Whether to use high precision for the binomial test.
+
+        Returns:
+            Anomaly mask. Tensor of shape (N, 1, H, W), where N is the batch size, and H and W are the height
+            and width of the input image, respectively.
+        """
+        log_prob_l = [
+            self.binomial_test(zi, win_size / (2**scale), binomial_probability_thr, high_precision)
+            for scale, zi in enumerate(z)
+        ]
+
+        log_prob_l_up = torch.cat(
+            [F.interpolate(lpl, size=self.input_size, mode="bicubic", align_corners=True) for lpl in log_prob_l], dim=1
+        )
+
+        log_prob = torch.sum(log_prob_l_up, dim=1, keepdim=True)
+
+        log_number_of_tests = torch.log10(torch.sum(torch.tensor([zi.shape[-2] * zi.shape[-1] for zi in z])))
+        log_nfa = log_number_of_tests + log_prob
+
+        # A pixel is detected as anomalous when its log10(NFA) falls below zero, i.e. when fewer than one false
+        # alarm is expected under the null hypothesis.
+        anomaly_score = -log_nfa
+        anomaly_mask = anomaly_score > 0
+
+        return anomaly_mask
+
+    @staticmethod
+    def binomial_test(z: torch.Tensor, win, probability_thr: float, high_precision: bool = False) -> torch.Tensor:
+        """Apply the binomial test to validate or reject the null hypothesis that the pixel is normal.
+
+        The alternative hypothesis is that the pixel is anomalous. The test is applied to a window around each
+        pixel: the number of candidate (high z^2) pixels observed in the window is compared to the number
+        expected under the null hypothesis.
+
+        Args:
+            z: Latent variable from the UFlow model. Tensor of shape (N, Cl, Hl, Wl), where N is the batch size,
+                Cl is the number of channels, and Hl and Wl are the height and width of the latent variables,
+                respectively.
+            win: Window size for the binomial test.
+            probability_thr: Probability threshold for the binomial test.
+            high_precision: Whether to use high precision for the binomial test.
+
+        Returns:
+            Log10 of the probability of the null hypothesis.
+        """
+        tau = st.chi2.ppf(probability_thr, 1)
+        half_win = np.max([int(win // 2), 1])
+
+        n_chann = z.shape[1]
+
+        # Candidates: positions whose squared latent value exceeds the chi-squared quantile tau
+        z2 = F.pad(z**2, tuple(4 * [half_win]), "reflect").detach().cpu()
+        z2_unfold_h = z2.unfold(-2, 2 * half_win + 1, 1)
+        z2_unfold_hw = z2_unfold_h.unfold(-2, 2 * half_win + 1, 1).numpy()
+        observed_candidates_k = np.sum(z2_unfold_hw >= tau, axis=(-2, -1))
+
+        # All volume together
+        observed_candidates = np.sum(observed_candidates_k, axis=1, keepdims=True)
+        x = observed_candidates / n_chann
+        n = int((2 * half_win + 1) ** 2)
+
+        # Low precision
+        if not high_precision:
+            log_prob = torch.tensor(st.binom.logsf(x, n, 1 - probability_thr) / np.log(10))
+        # High precision - accurate but slow
+        else:
+            to_mp = np.frompyfunc(mp.mpf, 1, 1)
+            mpn = mp.mpf(n)
+            mpp = probability_thr
+
+            def binomial_density(k):
+                return binomial(mpn, to_mp(k)) * (1 - mpp) ** k * mpp ** (mpn - k)
+
+            def integral(xx):
+                return integrate.quad(binomial_density, xx, n)[0]
+
+            integral_array = np.vectorize(integral)
+            prob = integral_array(x)
+            log_prob = torch.tensor(np.log10(prob))
+
+        return log_prob
diff --git a/src/anomalib/models/uflow/config.yaml b/src/anomalib/models/uflow/config.yaml
new file mode 100644
index 0000000000..5d1d2157d3
--- /dev/null
+++ b/src/anomalib/models/uflow/config.yaml
@@ -0,0 +1,108 @@
+dataset:
+  name: mvtec
+  format: mvtec
+  path: ./datasets/MVTec
+  category: bottle
+  task: segmentation
+  train_batch_size: 14 # values used in paper: bottle: 23, cable: 14, capsule: 14, carpet: 13, grid: 12, hazelnut: 21, leather: 15, metal_nut: 11, pill: 11, screw: 11, tile: 30, toothbrush: 21, transistor: 20, wood: 22, zipper: 24
+  eval_batch_size: 16
+  inference_batch_size: 16
+  num_workers: 8
+  image_size: 448 # 448 for mcait or 256 for ResNet extractors # dimensions to which images are resized (mandatory)
+  center_crop: null # dimensions to which images are center-cropped after resizing (optional)
+  normalization: imagenet # data distribution to which the images will be normalized: [none, imagenet]
+  transform_config:
+    train: null
+    eval: null
+  test_split_mode: from_dir # options: [from_dir, synthetic]
+  test_split_ratio: 0.2 # fraction of train images held out for testing (usage depends on test_split_mode)
+  val_split_mode: same_as_test # options: [same_as_test, from_test, synthetic]
+  val_split_ratio: 0.5 # fraction of train/test images held out for validation (usage depends on val_split_mode)
+
+model:
+  name: uflow
+  flow_steps: 4
+  permute_soft: false
+  affine_clamp: 2.0
+  affine_subnet_channels_ratio: 1.0
+  backbone: mcait # official: mcait; other extractors tested: resnet18, wide_resnet50_2. Others could work as well.
+  lr: 1e-3 # values used in paper: bottle: 0.0001128999, cable: 0.0016160391, capsule: 0.0012118892, carpet: 0.0012118892, grid: 0.0000362248, hazelnut: 0.0013268899, leather: 0.0006124724, metal_nut: 0.0008148858, pill: 0.0010756100, screw: 0.0004155987, tile: 0.0060457548, toothbrush: 0.0001287313, transistor: 0.0011212904, wood: 0.0002466546, zipper: 0.0000455247
+  weight_decay: 1e-5
+  early_stopping:
+    patience: 20
+    metric: pixel_AUROC
+    mode: max
+  normalization_method: min_max # options: [null, min_max, cdf]
+
+metrics:
+  image:
+    - F1Score
+    - AUROC
+  pixel:
+    - F1Score
+    - AUROC
+  threshold: # TODO: add NFA
+    method: adaptive # options: [adaptive, manual]
+    manual_image: null
+    manual_pixel: null
+
+visualization:
+  show_images: False # show images on the screen
+  save_images: True # save images to the file system
+  log_images: True # log images to the available loggers (if any)
+  image_save_path: null # path to which images will be saved
+  mode: full # options: ["full", "simple"]
+
+project:
+  seed: 0
+  path: ./results
+
+logging:
+  logger: [] # options: [comet, tensorboard, wandb, csv] or combinations.
+  log_graph: false # Logs the model graph to respective logger.
+
+optimization:
+  export_mode: null # options: torch, onnx, openvino
+
+# PL Trainer Args. Don't add extra parameter here.
+trainer:
+  enable_checkpointing: true
+  default_root_dir: null
+  gradient_clip_val: 0
+  gradient_clip_algorithm: norm
+  num_nodes: 1
+  devices: 1
+  enable_progress_bar: true
+  overfit_batches: 0.0
+  track_grad_norm: -1
+  check_val_every_n_epoch: 1
+  fast_dev_run: false
+  accumulate_grad_batches: 1
+  max_epochs: 200
+  min_epochs: null
+  max_steps: -1
+  min_steps: null
+  max_time: null
+  limit_train_batches: 1.0
+  limit_val_batches: 1.0
+  limit_test_batches: 1.0
+  limit_predict_batches: 1.0
+  val_check_interval: 1.0
+  log_every_n_steps: 50
+  accelerator: auto # <"cpu", "gpu", "tpu", "ipu", "hpu", "auto">
+  strategy: null
+  sync_batchnorm: false
+  precision: 32
+  enable_model_summary: true
+  num_sanity_val_steps: 0
+  profiler: null
+  benchmark: false
+  deterministic: false
+  reload_dataloaders_every_n_epochs: 0
+  auto_lr_find: false
+  replace_sampler_ddp: true
+  detect_anomaly: false
+  auto_scale_batch_size: false
+  plugins: null
+  move_metrics_to_cpu: false
+  multiple_trainloader_mode: max_size_cycle
diff --git a/src/anomalib/models/uflow/feature_extraction.py b/src/anomalib/models/uflow/feature_extraction.py
new file mode 100644
index 0000000000..b66f805df0
--- /dev/null
+++ b/src/anomalib/models/uflow/feature_extraction.py
@@ -0,0 +1,126 @@
+"""Feature extractors for the UFlow model."""
+
+from typing import Tuple
+
+import timm
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from anomalib.models.components.feature_extractors import TimmFeatureExtractor
+
+AVAILABLE_EXTRACTORS = ["mcait", "resnet18", "wide_resnet50_2"]
+
+
+def get_feature_extractor(backbone: str, input_size: Tuple[int, int] = (256, 256)):
+    """Get the feature extractor. Currently restricted to AVAILABLE_EXTRACTORS.
+
+    Args:
+        backbone (str): Backbone name.
+        input_size (tuple[int, int]): Input size.
+
+    Raises:
+        ValueError: If the backbone is not one of AVAILABLE_EXTRACTORS.
+
+    Returns:
+        nn.Module: Feature extractor.
+    """
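+    # Dispatch on the backbone name: the CNN extractors reuse anomalib's Timm wrapper defined below,
+    # while "mcait" builds the paper's two-scale CaiT extractor (MCaitFeatureExtractor).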
+ if backbone in ["resnet18", "wide_resnet50_2"]: + return FeatureExtractor(backbone, input_size, layers=["layer1", "layer2", "layer3"]) + elif backbone == "mcait": + return MCaitFeatureExtractor() + raise ValueError( + "`backbone` must be one of `[mcait, resnet18, wide_resnet50_2]`. These are the only feature extractors tested. " + "It does not mean that other feature extractors will not work." + ) + + +class FeatureExtractor(TimmFeatureExtractor): + """Feature extractor based on ResNet (or others) backbones.""" + + def __init__(self, backbone, input_size, layers=("layer1", "layer2", "layer3"), **kwargs): + super(FeatureExtractor, self).__init__(backbone, layers, pre_trained=True, requires_grad=False) + self.channels = self.feature_extractor.feature_info.channels() + self.scale_factors = self.feature_extractor.feature_info.reduction() + self.scales = range(len(self.scale_factors)) + + self.feature_normalizations = nn.ModuleList() + for in_channels, scale in zip(self.channels, self.scale_factors): + self.feature_normalizations.append( + nn.LayerNorm( + [in_channels, int(input_size[0] / scale), int(input_size[1] / scale)], elementwise_affine=True + ) + ) + + for param in self.feature_extractor.parameters(): + param.requires_grad = False + + def forward(self, img, **kwargs): + features = self.extract_features(img) + normalized_features = self.normalize_features(features, **kwargs) + return normalized_features + + def extract_features(self, img, **kwargs): + self.feature_extractor.eval() + return self.feature_extractor(img) + + def normalize_features(self, features, **kwargs): + return [self.feature_normalizations[i](feature) for i, feature in enumerate(features)] + + +class MCaitFeatureExtractor(nn.Module): + """ + Feature extractor based on MCait backbone. This is the proposed feature extractor in the paper. It uses two + independently trained Cait models, at different scales, with input sizes 448 and 224, respectively. + It also includes a normalization layer for each scale. + """ + + def __init__(self): + super(MCaitFeatureExtractor, self).__init__() + self.input_size = 448 + self.extractor1 = timm.create_model("cait_m48_448", pretrained=True) + self.extractor2 = timm.create_model("cait_s24_224", pretrained=True) + self.channels = [768, 384] + self.scale_factors = [16, 32] + self.scales = range(len(self.scale_factors)) + + for param in self.extractor1.parameters(): + param.requires_grad = False + for param in self.extractor2.parameters(): + param.requires_grad = False + + def forward(self, img, training=True): + features = self.extract_features(img) + normalized_features = self.normalize_features(features, training=training) + return normalized_features + + def extract_features(self, img, **kwargs): + self.extractor1.eval() + self.extractor2.eval() + + # Scale 1 --> Extractor 1 + x1 = self.extractor1.patch_embed(img) + x1 = x1 + self.extractor1.pos_embed + x1 = self.extractor1.pos_drop(x1) + for i in range(41): # paper Table 6. 
Block Index = 40 + x1 = self.extractor1.blocks[i](x1) + + # Scale 2 --> Extractor 2 + img_sub = F.interpolate(torch.Tensor(img), size=(224, 224), mode="bicubic", align_corners=True) + x2 = self.extractor2.patch_embed(img_sub) + x2 = x2 + self.extractor2.pos_embed + x2 = self.extractor2.pos_drop(x2) + for i in range(21): + x2 = self.extractor2.blocks[i](x2) + + features = [x1, x2] + return features + + def normalize_features(self, features, **kwargs): + normalized_features = [] + for i, extractor in enumerate([self.extractor1, self.extractor2]): + batch, _, channels = features[i].shape + scale_factor = self.scale_factors[i] + + x = extractor.norm(features[i].contiguous()) + x = x.permute(0, 2, 1) + x = x.reshape(batch, channels, self.input_size // scale_factor, self.input_size // scale_factor) + normalized_features.append(x) + + return normalized_features diff --git a/src/anomalib/models/uflow/lightning_model.py b/src/anomalib/models/uflow/lightning_model.py new file mode 100644 index 0000000000..b3636bae28 --- /dev/null +++ b/src/anomalib/models/uflow/lightning_model.py @@ -0,0 +1,118 @@ +"""U-Flow: A U-shaped Normalizing Flow for Anomaly Detection with Unsupervised Threshold. + +https://arxiv.org/pdf/2211.12353.pdf +""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import torch +from omegaconf import DictConfig, ListConfig +from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.utilities.types import STEP_OUTPUT +from torch import Tensor + +from anomalib.models.components import AnomalyModule +from anomalib.models.uflow.loss import UFlowLoss +from anomalib.models.uflow.torch_model import UflowModel + +__all__ = ["Uflow", "UflowLightning"] + + +class Uflow(AnomalyModule): + """PL Lightning Module for the UFLOW algorithm.""" + + def __init__( + self, + input_size: tuple[int, int], + backbone: str, + flow_steps: int = 4, + affine_clamp: float = 2.0, + affine_subnet_channels_ratio: float = 1.0, + permute_soft: bool = False, + ) -> None: + """ + Args: + input_size (tuple[int, int]): Input image size. + backbone (str): Backbone name. + flow_steps (int): Number of flow steps. + affine_clamp (float): Affine clamp. + affine_subnet_channels_ratio (float): Affine subnet channels ratio. + permute_soft (bool): Whether to use soft permutation. + """ + super().__init__() + self.model: UflowModel = UflowModel( + input_size=input_size, + backbone=backbone, + flow_steps=flow_steps, + affine_clamp=affine_clamp, + affine_subnet_channels_ratio=affine_subnet_channels_ratio, + permute_soft=permute_soft, + ) + self.loss = UFlowLoss() + + def training_step(self, batch: dict[str, str | Tensor], *args, **kwargs) -> STEP_OUTPUT: + z, ljd = self.model(batch["image"]) + loss = self.loss(z, ljd) + self.log_dict({"loss": loss}, on_step=True, on_epoch=False, prog_bar=False, logger=True) + return {"loss": loss} + + def validation_step(self, batch: dict[str, str | Tensor], *args, **kwargs) -> STEP_OUTPUT: + anomaly_maps = self.model(batch["image"]) + batch["anomaly_maps"] = anomaly_maps + return batch + + +class UflowLightning(Uflow): + """PL Lightning Module for the UFLOW algorithm. 
+
+    Args:
+        hparams (DictConfig | ListConfig): Model params
+    """
+
+    def __init__(self, hparams: DictConfig | ListConfig) -> None:
+        super().__init__(
+            input_size=hparams.model.input_size,
+            backbone=hparams.model.backbone,
+            flow_steps=hparams.model.flow_steps,
+            affine_clamp=hparams.model.affine_clamp,
+            affine_subnet_channels_ratio=hparams.model.affine_subnet_channels_ratio,
+            permute_soft=hparams.model.permute_soft,
+        )
+        self.lr = hparams.model.lr
+        self.weight_decay = hparams.model.weight_decay
+        self.hparams: DictConfig | ListConfig  # type: ignore
+        self.save_hyperparameters(hparams)
+
+    def configure_callbacks(self) -> list[EarlyStopping]:
+        """Configure model-specific callbacks.
+
+        Note:
+            This method is used for the existing CLI.
+            When PL CLI is introduced, configure callback method will be
+            deprecated, and callbacks will be configured from either
+            config.yaml file or from CLI.
+        """
+        early_stopping = EarlyStopping(
+            monitor=self.hparams.model.early_stopping.metric,
+            patience=self.hparams.model.early_stopping.patience,
+            mode=self.hparams.model.early_stopping.mode,
+        )
+        return [early_stopping]
+
+    def configure_optimizers(self):
+        # Optimizer
+        optimizer = torch.optim.Adam(
+            [{"params": self.parameters(), "initial_lr": self.lr}], lr=self.lr, weight_decay=self.weight_decay
+        )
+
+        # Scheduler for slowly reducing the learning rate over the whole training run
+        total_iterations = 25000
+        scheduler = torch.optim.lr_scheduler.LinearLR(
+            optimizer, start_factor=1.0, end_factor=0.4, total_iters=total_iterations
+        )
+        return [optimizer], [scheduler]
diff --git a/src/anomalib/models/uflow/loss.py b/src/anomalib/models/uflow/loss.py
new file mode 100644
index 0000000000..3a1686371a
--- /dev/null
+++ b/src/anomalib/models/uflow/loss.py
@@ -0,0 +1,27 @@
+"""Loss function for the UFlow Model Implementation."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor, nn
+
+
+class UFlowLoss(nn.Module):
+    """UFlow Loss."""
+
+    def forward(self, hidden_variables: list[Tensor], jacobians: Tensor) -> Tensor:
+        """Calculate the UFlow loss.
+
+        Args:
+            hidden_variables (list[Tensor]): Hidden variables from the U-Flow model. f: X -> Z
+            jacobians (Tensor): Log of the Jacobian determinants from the U-Flow model.
+
+        Returns:
+            Tensor: UFlow loss computed based on the hidden variables and the log of the Jacobians.
+        """
+        lpz = torch.sum(torch.stack([0.5 * torch.sum(z_i**2, dim=(1, 2, 3)) for z_i in hidden_variables], dim=0))
+        flow_loss = torch.mean(lpz - jacobians)
+        return flow_loss
diff --git a/src/anomalib/models/uflow/torch_model.py b/src/anomalib/models/uflow/torch_model.py
new file mode 100644
index 0000000000..7989356f4c
--- /dev/null
+++ b/src/anomalib/models/uflow/torch_model.py
@@ -0,0 +1,172 @@
+"""U-Flow torch model."""
+
+import torch.nn as nn
+from FrEIA import framework as ff
+from FrEIA import modules as fm
+
+from anomalib.models.components.flow import AllInOneBlock
+
+from .anomaly_map import AnomalyMapGenerator
+from .feature_extraction import get_feature_extractor
+
+
+class AffineCouplingSubnet:
+    """
+    Class for building the Affine Coupling subnet.
+    It is passed as an argument to the `AllInOneBlock` module.
+    """
+
+    def __init__(self, kernel_size: int, subnet_channels_ratio: float):
+        """
+        Args:
+            kernel_size (int): Kernel size.
+            subnet_channels_ratio (float): Subnet channels ratio.
+ """ + self.kernel_size = kernel_size + self.subnet_channels_ratio = subnet_channels_ratio + + def __call__(self, in_channels: int, out_channels: int) -> nn.Sequential: + """ + Args: + in_channels (int): Input channels. + out_channels (int): Output channels. + + Returns: + nn.Sequential: Affine Coupling subnet. + """ + mid_channels = int(in_channels * self.subnet_channels_ratio) + return nn.Sequential( + nn.Conv2d(in_channels, mid_channels, self.kernel_size, padding="same"), + nn.ReLU(), + nn.Conv2d(mid_channels, out_channels, self.kernel_size, padding="same"), + ) + + +class UflowModel(nn.Module): + def __init__( + self, + input_size: tuple[int, int] = (448, 448), + flow_steps: int = 4, + backbone: str = "mcait", + affine_clamp: float = 2.0, + affine_subnet_channels_ratio: float = 1.0, + permute_soft: bool = False, + ) -> None: + """ + Args: + input_size (tuple[int, int]): Input image size. + flow_steps (int): Number of flow steps. + backbone (str): Backbone name. + affine_clamp (float): Affine clamp. + affine_subnet_channels_ratio (float): Affine subnet channels ratio. + permute_soft (bool): Whether to use soft permutation. + """ + super().__init__() + + self.input_size = input_size + self.affine_clamp = affine_clamp + self.affine_subnet_channels_ratio = affine_subnet_channels_ratio + self.permute_soft = permute_soft + + self.feature_extractor = get_feature_extractor(backbone, input_size) + self.flow = self.build_flow(flow_steps) + self.anomaly_map_generator = AnomalyMapGenerator(input_size) + + def build_flow(self, flow_steps: int) -> ff.GraphINN: + """ + Build the flow model. + First we start with the input nodes, which have to match the feature extractor output. + Then, we build the U-Shaped flow. Starting from the bottom (the coarsest scale), the flow is built as follows: + 1. Pass the input through a Flow Stage (`build_flow_stage`). + 2. Split the output of the flow stage into two parts, one that goes directly to the output, + 3. and the other is up-sampled, and will be concatenated with the output of the next flow stage (next scale) + 4. Repeat steps 1-3 for the next scale. + Finally, we build the Flow graph using the input nodes, the flow stages, and the output nodes. + + Args: + flow_steps (int): Number of flow steps. + + Returns: + ff.GraphINN: Flow model. 
+ """ + input_nodes = [] + for channel, s_factor in zip(self.feature_extractor.channels, self.feature_extractor.scale_factors): + input_nodes.append( + ff.InputNode( + channel, self.input_size[0] // s_factor, self.input_size[1] // s_factor, name=f"cond_{channel}" + ) + ) + + nodes, output_nodes = [], [] + last_node = input_nodes[-1] + for i in reversed(range(1, len(input_nodes))): + flows = self.build_flow_stage(last_node, flow_steps) + volume_size = flows[-1].output_dims[0][0] + split = ff.Node( + flows[-1], + fm.Split, + {"section_sizes": (volume_size // 8 * 4, volume_size - volume_size // 8 * 4), "dim": 0}, + name=f"split_{i + 1}", + ) + output = ff.OutputNode(split.out1, name=f"output_scale_{i + 1}") + up = ff.Node(split.out0, fm.IRevNetUpsampling, {}, name=f"up_{i + 1}") + last_node = ff.Node([input_nodes[i - 1].out0, up.out0], fm.Concat, {"dim": 0}, name=f"cat_{i}") + + output_nodes.append(output) + nodes.extend([*flows, split, up, last_node]) + + flows = self.build_flow_stage(last_node, flow_steps) + output = ff.OutputNode(flows[-1], name="output_scale_1") + + output_nodes.append(output) + nodes.extend(flows) + + return ff.GraphINN(input_nodes + nodes + output_nodes[::-1]) + + def build_flow_stage(self, in_node: ff.Node, flow_steps: int, condition_node: ff.Node = None) -> list[ff.Node]: + """ + Build a flow stage, which is a sequence of flow steps. + Each flow stage is essentially a sequence of `flow_steps` Glow blocks (`AllInOneBlock`). + + Args: + in_node (ff.Node): Input node. + flow_steps (int): Number of flow steps. + condition_node (ff.Node): Condition node. + + Returns: + List[ff.Node]: List of flow steps. + """ + + flow_size = in_node.output_dims[0][-1] + nodes = [] + for step in range(flow_steps): + nodes.append( + ff.Node( + in_node, + AllInOneBlock, + module_args={ + "subnet_constructor": AffineCouplingSubnet( + 3 if step % 2 == 0 else 1, self.affine_subnet_channels_ratio + ), + "affine_clamping": self.affine_clamp, + "permute_soft": self.permute_soft, + }, + conditions=condition_node, + name=f"flow{flow_size}_step{step}", + ) + ) + in_node = nodes[-1] + return nodes + + def forward(self, image): + features = self.feature_extractor(image) + z, ljd = self.encode(features) + + if self.training: + return z, ljd + else: + return self.anomaly_map_generator(z) + + def encode(self, features): + z, ljd = self.flow(features, rev=False) + if len(self.feature_extractor.scales) == 1: + z = [z] + return z, ljd diff --git a/tests/pre_merge/deploy/test_inferencer.py b/tests/pre_merge/deploy/test_inferencer.py index 7474718610..ed0c6b0ddb 100644 --- a/tests/pre_merge/deploy/test_inferencer.py +++ b/tests/pre_merge/deploy/test_inferencer.py @@ -88,6 +88,7 @@ def make( ("patchcore", "segmentation"), ("reverse_distillation", "segmentation"), ("stfpm", "segmentation"), + ("uflow", "segmentation"), # also test different task types for a single model ("padim", "classification"), ("padim", "detection"), diff --git a/tests/pre_merge/models/test_model_premerge.py b/tests/pre_merge/models/test_model_premerge.py index eb6f899362..24f7e77fbb 100644 --- a/tests/pre_merge/models/test_model_premerge.py +++ b/tests/pre_merge/models/test_model_premerge.py @@ -31,6 +31,7 @@ class TestModel: ("reverse_distillation", False), ("rkde", False), ("stfpm", False), + ("uflow", False), ], ) @TestDataset(num_train=20, num_test=10)