Skip to content

Commit

Permalink
SSPCAB implementation (#500)
Browse files Browse the repository at this point in the history
* add sspcab module

* add sspcab to draem model

* fix style checks for sspcab module

* add custom sspcab implementation

* use short license header

* do not detach hook outputs

* add paper link

* channels -> in_channels

* explain global average pooling operation in comment

* typing
  • Loading branch information
djdameln authored Aug 15, 2022
1 parent 37bc8de commit b2b5ad4
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 17 deletions.
8 changes: 8 additions & 0 deletions anomalib/models/components/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Neural network layers."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .sspcab import SSPCAB

__all__ = ["SSPCAB"]
81 changes: 81 additions & 0 deletions anomalib/models/components/layers/sspcab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""SSPCAB: Self-Supervised Predictive Convolutional Attention Block for reconstruction-based models.
Paper https://arxiv.org/abs/2111.09099
"""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class AttentionModule(nn.Module):
"""Squeeze and excitation block that acts as the attention module in SSPCAB.
Args:
channels (int): Number of input channels.
reduction_ratio (int): Reduction ratio of the attention module.
"""

def __init__(self, in_channels: int, reduction_ratio: int = 8):
super().__init__()

out_channels = in_channels // reduction_ratio
self.fc1 = nn.Linear(in_channels, out_channels)
self.fc2 = nn.Linear(out_channels, in_channels)

def forward(self, inputs: Tensor) -> Tensor:
"""Forward pass through the attention module."""
# reduce feature map to 1d vector through global average pooling
avg_pooled = inputs.mean(dim=(2, 3))

# squeeze and excite
act = self.fc1(avg_pooled)
act = F.relu(act)
act = self.fc2(act)
act = F.sigmoid(act)

# multiply with input
se_out = inputs * act.view(act.shape[0], act.shape[1], 1, 1)

return se_out


class SSPCAB(nn.Module):
"""SSPCAB block.
Args:
in_channels (int): Number of input channels.
kernel_size (int): Size of the receptive fields of the masked convolution kernel.
dilation (int): Dilation factor of the masked convolution kernel.
reduction_ratio (int): Reduction ratio of the attention module.
"""

def __init__(self, in_channels: int, kernel_size: int = 1, dilation: int = 1, reduction_ratio: int = 8):
super().__init__()

self.pad = kernel_size + dilation
self.crop = 2 * (kernel_size + dilation)

self.masked_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
self.masked_conv2 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
self.masked_conv3 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
self.masked_conv4 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)

self.attention_module = AttentionModule(in_channels=in_channels, reduction_ratio=reduction_ratio)

def forward(self, inputs: Tensor) -> Tensor:
"""Forward pass through the SSPCAB block."""
# compute masked convolution
padded = F.pad(inputs, (self.pad,) * 4)
masked_out = torch.zeros_like(inputs)
masked_out += self.masked_conv1(padded[..., : -self.crop, : -self.crop])
masked_out += self.masked_conv2(padded[..., : -self.crop, self.crop :])
masked_out += self.masked_conv3(padded[..., self.crop :, : -self.crop])
masked_out += self.masked_conv4(padded[..., self.crop :, self.crop :])

# apply channel attention module
sspcab_out = self.attention_module(masked_out)
return sspcab_out
2 changes: 2 additions & 0 deletions anomalib/models/draem/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ model:
name: draem
anomaly_source_path: null # optional, e.g. ./datasets/dtd
lr: 0.0001
enable_sspcab: false
sspcab_lambda: 0.1
early_stopping:
patience: 20
metric: pixel_AUROC
Expand Down
46 changes: 42 additions & 4 deletions anomalib/models/draem/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Union
from typing import Callable, Dict, Optional, Union

import torch
from omegaconf import DictConfig, ListConfig
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.cli import MODEL_REGISTRY
from torch import Tensor, nn

from anomalib.models.components import AnomalyModule
from anomalib.models.draem.loss import DraemLoss
Expand All @@ -30,12 +31,40 @@ class Draem(AnomalyModule):
be used if left empty.
"""

def __init__(self, anomaly_source_path: Optional[str] = None):
def __init__(
self, enable_sspcab: bool = False, sspcab_lambda: float = 0.1, anomaly_source_path: Optional[str] = None
):
super().__init__()

self.augmenter = Augmenter(anomaly_source_path)
self.model = DraemModel()
self.model = DraemModel(sspcab=enable_sspcab)
self.loss = DraemLoss()
self.sspcab = enable_sspcab

if self.sspcab:
self.sspcab_activations: Dict = {}
self.setup_sspcab()
self.sspcab_loss = nn.MSELoss()
self.sspcab_lambda = sspcab_lambda

def setup_sspcab(self):
"""Prepare the model for the SSPCAB training step by adding forward hooks for the SSPCAB layer activations."""

def get_activation(name: str) -> Callable:
"""Retrieves the activations.
Args:
name (str): Identifier for the retrieved activations.
"""

def hook(_, __, output: Tensor):
"""Hook for retrieving the activations."""
self.sspcab_activations[name] = output

return hook

self.model.reconstructive_subnetwork.encoder.mp4.register_forward_hook(get_activation("input"))
self.model.reconstructive_subnetwork.encoder.block5.register_forward_hook(get_activation("output"))

def training_step(self, batch, _): # pylint: disable=arguments-differ
"""Training Step of DRAEM.
Expand All @@ -56,6 +85,11 @@ def training_step(self, batch, _): # pylint: disable=arguments-differ
reconstruction, prediction = self.model(augmented_image)
# Compute loss
loss = self.loss(input_image, reconstruction, anomaly_mask, prediction)

if self.sspcab:
loss += self.sspcab_lambda * self.sspcab_loss(
self.sspcab_activations["input"], self.sspcab_activations["output"]
)
return {"loss": loss}

def validation_step(self, batch, _):
Expand All @@ -80,7 +114,11 @@ class DraemLightning(Draem):
"""

def __init__(self, hparams: Union[DictConfig, ListConfig]):
super().__init__(anomaly_source_path=hparams.model.anomaly_source_path)
super().__init__(
enable_sspcab=hparams.model.enable_sspcab,
sspcab_lambda=hparams.model.sspcab_lambda,
anomaly_source_path=hparams.model.anomaly_source_path,
)
self.hparams: Union[DictConfig, ListConfig] # type: ignore
self.save_hyperparameters(hparams)

Expand Down
31 changes: 18 additions & 13 deletions anomalib/models/draem/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
import torch
from torch import Tensor, nn

from anomalib.models.components.layers import SSPCAB


class DraemModel(nn.Module):
"""DRAEM PyTorch model consisting of the reconstructive and discriminative sub networks."""

def __init__(self):
def __init__(self, sspcab: bool = False):
super().__init__()
self.reconstructive_subnetwork = ReconstructiveSubNetwork()
self.reconstructive_subnetwork = ReconstructiveSubNetwork(sspcab=sspcab)
self.discriminative_subnetwork = DiscriminativeSubNetwork(in_channels=6, out_channels=2)

def forward(self, batch: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
Expand Down Expand Up @@ -50,9 +52,9 @@ class ReconstructiveSubNetwork(nn.Module):
base_width (int): Base dimensionality of the layers of the autoencoder.
"""

def __init__(self, in_channels: int = 3, out_channels: int = 3, base_width=128):
def __init__(self, in_channels: int = 3, out_channels: int = 3, base_width=128, sspcab: bool = False):
super().__init__()
self.encoder = EncoderReconstructive(in_channels, base_width)
self.encoder = EncoderReconstructive(in_channels, base_width, sspcab=sspcab)
self.decoder = DecoderReconstructive(base_width, out_channels=out_channels)

def forward(self, batch: Tensor) -> Tensor:
Expand Down Expand Up @@ -321,7 +323,7 @@ class EncoderReconstructive(nn.Module):
base_width (int): Base dimensionality of the layers of the autoencoder.
"""

def __init__(self, in_channels: int, base_width: int):
def __init__(self, in_channels: int, base_width: int, sspcab: bool = False):
super().__init__()
self.block1 = nn.Sequential(
nn.Conv2d(in_channels, base_width, kernel_size=3, padding=1),
Expand Down Expand Up @@ -359,14 +361,17 @@ def __init__(self, in_channels: int, base_width: int):
nn.ReLU(inplace=True),
)
self.mp4 = nn.Sequential(nn.MaxPool2d(2))
self.block5 = nn.Sequential(
nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1),
nn.BatchNorm2d(base_width * 8),
nn.ReLU(inplace=True),
nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1),
nn.BatchNorm2d(base_width * 8),
nn.ReLU(inplace=True),
)
if sspcab:
self.block5 = SSPCAB(base_width * 8)
else:
self.block5 = nn.Sequential(
nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1),
nn.BatchNorm2d(base_width * 8),
nn.ReLU(inplace=True),
nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1),
nn.BatchNorm2d(base_width * 8),
nn.ReLU(inplace=True),
)

def forward(self, batch: Tensor) -> Tensor:
"""Encode a batch of input images to the salient space.
Expand Down

0 comments on commit b2b5ad4

Please sign in to comment.