Skip to content

Commit

Permalink
add antialiased to patchcore
Browse files Browse the repository at this point in the history
  • Loading branch information
alexriedel1 authored Dec 2, 2024
1 parent bb8bc2f commit 72214c5
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/anomalib/models/image/patchcore/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from anomalib.models.components import DynamicBufferMixin, KCenterGreedy, TimmFeatureExtractor

from .anomaly_map import AnomalyMapGenerator
from torchvision.models.feature_extraction import create_feature_extractor
import antialiased_cnns

if TYPE_CHECKING:
from anomalib.data.utils.tiler import Tiler
Expand Down Expand Up @@ -45,11 +47,20 @@ def __init__(
self.layers = layers
self.num_neighbors = num_neighbors

self.feature_extractor = TimmFeatureExtractor(
backbone=self.backbone,
pre_trained=pre_trained,
layers=self.layers,
if self.backbone == "antialiased_wide_resnet50_2":
model = antialiased_cnns.wide_resnet50_2(pretrained=True)
self.feature_extractor = create_feature_extractor(
model=model,
return_nodes={layer: layer for layer in self.layers},
tracer_kwargs={"leaf_modules": [BlurPool]}, # for models comes from antialias
).eval()
else:
self.feature_extractor = TimmFeatureExtractor(
backbone=self.backbone,
pre_trained=pre_trained,
layers=self.layers,
).eval()

self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1)
self.anomaly_map_generator = AnomalyMapGenerator()

Expand Down

0 comments on commit 72214c5

Please sign in to comment.