Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed shape error, allowing arbitary image sizes for EfficientAD #1537

Merged
merged 6 commits into from
Jan 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions src/anomalib/models/efficient_ad/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import annotations

import logging
import math
import random
from enum import Enum

Expand Down Expand Up @@ -147,9 +148,10 @@ class Decoder(nn.Module):
def __init__(self, out_channels, padding, img_size, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.img_size = img_size
# use ceil to match output shape of PDN
self.last_upsample = (
int(img_size[0] / 4) if padding else int(img_size[0] / 4) - 8,
int(img_size[1] / 4) if padding else int(img_size[1] / 4) - 8,
math.ceil(img_size[0] / 4) if padding else math.ceil(img_size[0] / 4) - 8,
math.ceil(img_size[1] / 4) if padding else math.ceil(img_size[1] / 4) - 8,
holzweber marked this conversation as resolved.
Show resolved Hide resolved
)
self.deconv1 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2)
self.deconv2 = nn.Conv2d(64, 64, kernel_size=4, stride=1, padding=2)
Expand All @@ -167,22 +169,22 @@ def __init__(self, out_channels, padding, img_size, *args, **kwargs) -> None:
self.dropout6 = nn.Dropout(p=0.2)

def forward(self, x):
x = F.interpolate(x, size=(int(self.img_size[0] / 64) - 1, int(self.img_size[1] / 64) - 1), mode="bilinear")
x = F.interpolate(x, size=(self.img_size[0] // 64 - 1, self.img_size[1] // 64 - 1), mode="bilinear")
x = F.relu(self.deconv1(x))
x = self.dropout1(x)
x = F.interpolate(x, size=(int(self.img_size[0] / 32), int(self.img_size[1] / 32)), mode="bilinear")
x = F.interpolate(x, size=(self.img_size[0] // 32, self.img_size[1] // 32), mode="bilinear")
x = F.relu(self.deconv2(x))
x = self.dropout2(x)
x = F.interpolate(x, size=(int(self.img_size[0] / 16) - 1, int(self.img_size[1] / 16) - 1), mode="bilinear")
x = F.interpolate(x, size=(self.img_size[0] // 16 - 1, self.img_size[1] // 16 - 1), mode="bilinear")
x = F.relu(self.deconv3(x))
x = self.dropout3(x)
x = F.interpolate(x, size=(int(self.img_size[0] / 8), int(self.img_size[1] / 8)), mode="bilinear")
x = F.interpolate(x, size=(self.img_size[0] // 8, self.img_size[1] // 8), mode="bilinear")
x = F.relu(self.deconv4(x))
x = self.dropout4(x)
x = F.interpolate(x, size=(int(self.img_size[0] / 4) - 1, int(self.img_size[1] / 4) - 1), mode="bilinear")
x = F.interpolate(x, size=(self.img_size[0] // 4 - 1, self.img_size[1] // 4 - 1), mode="bilinear")
x = F.relu(self.deconv5(x))
x = self.dropout5(x)
x = F.interpolate(x, size=(int(self.img_size[0] / 2) - 1, int(self.img_size[1] / 2) - 1), mode="bilinear")
x = F.interpolate(x, size=(self.img_size[0] // 2 - 1, self.img_size[1] // 2 - 1), mode="bilinear")
x = F.relu(self.deconv6(x))
x = self.dropout6(x)
x = F.interpolate(x, size=self.last_upsample, mode="bilinear")
Expand Down
Loading