diff --git a/src/anomalib/models/image/patchcore/torch_model.py b/src/anomalib/models/image/patchcore/torch_model.py index a2ceb32b91..89df70c68d 100644 --- a/src/anomalib/models/image/patchcore/torch_model.py +++ b/src/anomalib/models/image/patchcore/torch_model.py @@ -13,6 +13,8 @@ from anomalib.models.components import DynamicBufferMixin, KCenterGreedy, TimmFeatureExtractor from .anomaly_map import AnomalyMapGenerator +from torchvision.models.feature_extraction import create_feature_extractor +import antialiased_cnns if TYPE_CHECKING: from anomalib.data.utils.tiler import Tiler @@ -45,11 +47,20 @@ def __init__( self.layers = layers self.num_neighbors = num_neighbors - self.feature_extractor = TimmFeatureExtractor( - backbone=self.backbone, - pre_trained=pre_trained, - layers=self.layers, + if self.backbone == "antialiased_wide_resnet50_2": + model = antialiased_cnns.wide_resnet50_2(pretrained=True) + self.feature_extractor = create_feature_extractor( + model=model, + return_nodes={layer: layer for layer in self.layers}, + tracer_kwargs={"leaf_modules": [BlurPool]}, # for models comes from antialias ).eval() + else: + self.feature_extractor = TimmFeatureExtractor( + backbone=self.backbone, + pre_trained=pre_trained, + layers=self.layers, + ).eval() + self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1) self.anomaly_map_generator = AnomalyMapGenerator()