Skip to content

Commit

Permalink
Added Fashion MNIST support
Browse files Browse the repository at this point in the history
  • Loading branch information
ancestor-mithril committed May 22, 2024
1 parent 11c4f56 commit 1b4e77f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
5 changes: 4 additions & 1 deletion utils/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import partial

from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.datasets import CIFAR10, CIFAR100, FashionMNIST

from .transforms import init_transforms

Expand Down Expand Up @@ -36,6 +36,9 @@ def init_dataset(args):
elif args.dataset == 'cifar100':
dataset_fn = partial(CIFAR100, root=args.data_path, download=True)
num_classes = 100
elif args.dataset == 'FashionMNIST':
dataset_fn = partial(FashionMNIST, root=args.data_path, download=True)
num_classes = 10
else:
raise NotImplementedError(f'Dataset {args.dataset} not implemented')

Expand Down
2 changes: 1 addition & 1 deletion utils/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,6 @@ def test_runtime(self):


def init_transforms(args) -> DatasetTransforms:
if args.dataset in ('cifar10', 'cifar100'):
if args.dataset in ('cifar10', 'cifar100', 'FashionMNIST'):
return CifarTransforms(args)
raise NotImplementedError(f"Transforms not implemented for {args.dataset}")

0 comments on commit 1b4e77f

Please sign in to comment.