diff --git a/utils/transforms.py b/utils/transforms.py index 86638ac..16b528a 100644 --- a/utils/transforms.py +++ b/utils/transforms.py @@ -104,9 +104,18 @@ def test_runtime(self): return None +class FashionMNISTTransforms(CifarTransforms): + def __init__(self, args): + super().__init__(args) + self.normalize = v2.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + # FashionMNIST must not have normalize inplace, otherwise it fails because it is a single channel image + + def init_transforms(args) -> DatasetTransforms: - if args.dataset in ("cifar10", "cifar100", "FashionMNIST", "cifar100noisy"): + if args.dataset in ("cifar10", "cifar100", "cifar100noisy"): return CifarTransforms(args) + if args.dataset in ("FashionMNIST", ): + return FashionMNISTTransforms(args) if args.dataset in ("MNIST", "DirtyMNIST"): return MNISTTransforms(args) raise NotImplementedError(f"Transforms not implemented for {args.dataset}")