From 72214c57d323f18fb8bfb28df33a837675d722bc Mon Sep 17 00:00:00 2001 From: Alexander Riedel <54716527+alexriedel1@users.noreply.github.com> Date: Mon, 2 Dec 2024 13:22:37 +0100 Subject: [PATCH] add antialiased to patchcore --- .../models/image/patchcore/torch_model.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/anomalib/models/image/patchcore/torch_model.py b/src/anomalib/models/image/patchcore/torch_model.py index a2ceb32b91..89df70c68d 100644 --- a/src/anomalib/models/image/patchcore/torch_model.py +++ b/src/anomalib/models/image/patchcore/torch_model.py @@ -13,6 +13,8 @@ from anomalib.models.components import DynamicBufferMixin, KCenterGreedy, TimmFeatureExtractor from .anomaly_map import AnomalyMapGenerator +from torchvision.models.feature_extraction import create_feature_extractor +import antialiased_cnns if TYPE_CHECKING: from anomalib.data.utils.tiler import Tiler @@ -45,11 +47,20 @@ def __init__( self.layers = layers self.num_neighbors = num_neighbors - self.feature_extractor = TimmFeatureExtractor( - backbone=self.backbone, - pre_trained=pre_trained, - layers=self.layers, + if self.backbone == "antialiased_wide_resnet50_2": + model = antialiased_cnns.wide_resnet50_2(pretrained=True) + self.feature_extractor = create_feature_extractor( + model=model, + return_nodes={layer: layer for layer in self.layers}, + tracer_kwargs={"leaf_modules": [BlurPool]}, # for models comes from antialias ).eval() + else: + self.feature_extractor = TimmFeatureExtractor( + backbone=self.backbone, + pre_trained=pre_trained, + layers=self.layers, + ).eval() + self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1) self.anomaly_map_generator = AnomalyMapGenerator()