Skip to content

Commit

Permalink
Removed Imagenet normalization, added check-up for train_batch_size
Browse files Browse the repository at this point in the history
  • Loading branch information
abc-125 committed Apr 22, 2024
1 parent 3d80649 commit 2d59cb1
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions src/anomalib/models/image/efficient_ad/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
padding=padding,
pad_maps=pad_maps,
)
self.batch_size = 1 # imagenet dataloader batch size is 1 according to the paper
self.batch_size = 1 # imagenet dataloader batch_size is 1 according to the paper
self.lr = lr
self.weight_decay = weight_decay

Expand Down Expand Up @@ -234,9 +234,18 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
def on_train_start(self) -> None:
"""Called before the first training epoch.
First sets up the pretrained teacher model, then prepares the imagenette data, and finally calculates or
loads the channel-wise mean and std of the training dataset and push to the model.
First check if EfficientAd-specific parameters are set correctly (train_batch_size of 1
and no Imagenet normalization in transforms), then sets up the pretrained teacher model,
then prepares the imagenette data, and finally calculates or loads
the channel-wise mean and std of the training dataset and push to the model.
"""
if self.trainer.datamodule.train_batch_size != 1:
msg = "train_batch_size for EfficientAd should be 1."
raise ValueError(msg)
if self._transform and any(isinstance(transform, Normalize) for transform in self._transform.transforms):
msg = "Transforms for EfficientAd should not contain Normalize."
raise ValueError(msg)

sample = next(iter(self.trainer.train_dataloader))
image_size = sample["image"].shape[-2:]
self.prepare_pretrained_model()
Expand Down Expand Up @@ -311,11 +320,10 @@ def learning_type(self) -> LearningType:
return LearningType.ONE_CLASS

def configure_transforms(self, image_size: tuple[int, int] | None = None) -> Transform:
"""Default transform for Padim."""
"""Default transform for EfficientAd. Imagenet normalization applied in forward."""
image_size = image_size or (256, 256)
return Compose(
[
Resize(image_size, antialias=True),
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
],
)

0 comments on commit 2d59cb1

Please sign in to comment.