diff --git a/.gitignore b/.gitignore index 89998fba46..0fffb7acc7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# IDE Settings files +.vscode/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/CHANGELOG.md b/CHANGELOG.md index 40fca95982..604b513890 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `EMNISTDataModule`, `BinaryEMNISTDataModule`, and `BinaryEMNIST` dataset ([#676](https://github.com/PyTorchLightning/lightning-bolts/pull/676)) + + - Added Advantage Actor-Critic (A2C) Model [#598](https://github.com/PyTorchLightning/lightning-bolts/pull/598)) diff --git a/docs/source/datamodules_sklearn.rst b/docs/source/datamodules_sklearn.rst index feacd7ed48..1c53028269 100644 --- a/docs/source/datamodules_sklearn.rst +++ b/docs/source/datamodules_sklearn.rst @@ -45,4 +45,3 @@ Automatically generates the train, validation and test splits for a Numpy datase They are set up as dataloaders for convenience. Optionally, you can pass in your own validation and test splits. .. autoclass:: pl_bolts.datamodules.sklearn_datamodule.SklearnDataModule - :noindex: diff --git a/docs/source/datamodules_vision.rst b/docs/source/datamodules_vision.rst index 7cc4740244..c1de946472 100644 --- a/docs/source/datamodules_vision.rst +++ b/docs/source/datamodules_vision.rst @@ -5,47 +5,49 @@ The following are pre-built datamodules for computer-vision. ------------- Supervised learning --------------------- +------------------- These are standard vision datasets with the train, test, val splits pre-generated in DataLoaders with the standard transforms (and Normalization) values +BinaryEMNIST +^^^^^^^^^^^^ + +.. autoclass:: pl_bolts.datamodules.binary_emnist_datamodule.BinaryEMNISTDataModule BinaryMNIST ^^^^^^^^^^^ .. autoclass:: pl_bolts.datamodules.binary_mnist_datamodule.BinaryMNISTDataModule - :noindex: CityScapes ^^^^^^^^^^ .. autoclass:: pl_bolts.datamodules.cityscapes_datamodule.CityscapesDataModule - :noindex: CIFAR-10 ^^^^^^^^ .. autoclass:: pl_bolts.datamodules.cifar10_datamodule.CIFAR10DataModule - :noindex: + +EMNIST +^^^^^^ + +.. autoclass:: pl_bolts.datamodules.emnist_datamodule.EMNISTDataModule FashionMNIST ^^^^^^^^^^^^ .. autoclass:: pl_bolts.datamodules.fashion_mnist_datamodule.FashionMNISTDataModule - :noindex: - Imagenet ^^^^^^^^ .. autoclass:: pl_bolts.datamodules.imagenet_datamodule.ImagenetDataModule - :noindex: MNIST ^^^^^ .. autoclass:: pl_bolts.datamodules.mnist_datamodule.MNISTDataModule - :noindex: Semi-supervised learning ------------------------ @@ -56,10 +58,8 @@ Imagenet (ssl) ^^^^^^^^^^^^^^ .. autoclass:: pl_bolts.datamodules.ssl_imagenet_datamodule.SSLImagenetDataModule - :noindex: STL-10 ^^^^^^ .. 
autoclass:: pl_bolts.datamodules.stl10_datamodule.STL10DataModule - :noindex: diff --git a/pl_bolts/datamodules/__init__.py b/pl_bolts/datamodules/__init__.py index e608d71010..b345e24ec8 100644 --- a/pl_bolts/datamodules/__init__.py +++ b/pl_bolts/datamodules/__init__.py @@ -1,7 +1,9 @@ from pl_bolts.datamodules.async_dataloader import AsynchronousLoader +from pl_bolts.datamodules.binary_emnist_datamodule import BinaryEMNISTDataModule from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule, TinyCIFAR10DataModule from pl_bolts.datamodules.cityscapes_datamodule import CityscapesDataModule +from pl_bolts.datamodules.emnist_datamodule import EMNISTDataModule from pl_bolts.datamodules.experience_source import DiscountedExperienceSource, ExperienceSource, ExperienceSourceDataset from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule @@ -33,4 +35,6 @@ 'STL10DataModule', 'VOCDetectionDataModule', 'KittiDataset', + 'EMNISTDataModule', + 'BinaryEMNISTDataModule', ] diff --git a/pl_bolts/datamodules/binary_emnist_datamodule.py b/pl_bolts/datamodules/binary_emnist_datamodule.py new file mode 100644 index 0000000000..9dbd6d2040 --- /dev/null +++ b/pl_bolts/datamodules/binary_emnist_datamodule.py @@ -0,0 +1,81 @@ +from typing import Any, Optional, Union + +from pl_bolts.datamodules.emnist_datamodule import EMNISTDataModule +from pl_bolts.datasets import BinaryEMNIST +from pl_bolts.utils import _TORCHVISION_AVAILABLE + + +class BinaryEMNISTDataModule(EMNISTDataModule): + """ + .. figure:: https://user-images.githubusercontent.com/4632336/123210742-4d6b3380-d477-11eb-80da-3e9a74a18a07.png + :width: 400 + :alt: EMNIST + + Please see :class:`~pl_bolts.datamodules.emnist_datamodule.EMNISTDataModule` for more details. + + Example:: + + from pl_bolts.datamodules import BinaryEMNISTDataModule + dm = BinaryEMNISTDataModule('.') + model = LitModel() + Trainer().fit(model, datamodule=dm) + """ + name = "binary_emnist" + dataset_cls = BinaryEMNIST + dims = (1, 28, 28) + + def __init__( + self, + data_dir: Optional[str] = None, + split: str = 'mnist', + val_split: Union[int, float] = 0.2, + num_workers: int = 0, + normalize: bool = False, + batch_size: int = 32, + seed: int = 42, + shuffle: bool = True, + pin_memory: bool = True, + drop_last: bool = False, + strict_val_split: bool = False, + *args: Any, + **kwargs: Any, + ) -> None: + """ + Args: + data_dir: Where to save/load the data. + split: The dataset has 6 different splits: ``byclass``, ``bymerge``, + ``balanced``, ``letters``, ``digits`` and ``mnist``. + This argument is passed to :class:`torchvision.datasets.EMNIST`. + val_split: Percent (float) or number (int) of samples + to use for the validation split. + num_workers: How many workers to use for loading data + normalize: If ``True``, applies image normalize. + batch_size: How many samples per batch to load. + seed: Random seed to be used for train/val/test splits. + shuffle: If ``True``, shuffles the train data every epoch. + pin_memory: If ``True``, the data loader will copy Tensors into + CUDA pinned memory before returning them. + drop_last: If ``True``, drops the last incomplete batch. + strict_val_split: If ``True``, uses the validation split defined in the paper and ignores ``val_split``. + Note that it only works with ``"balanced"``, ``"digits"``, ``"letters"``, ``"mnist"`` splits. 
+ """ + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( + 'You want to use EMNIST dataset loaded from `torchvision` which is not installed yet.' + ) + + super(BinaryEMNISTDataModule, self).__init__( # type: ignore[misc] + data_dir=data_dir, + split=split, + val_split=val_split, + num_workers=num_workers, + normalize=normalize, + batch_size=batch_size, + seed=seed, + shuffle=shuffle, + pin_memory=pin_memory, + drop_last=drop_last, + strict_val_split=strict_val_split, + *args, + **kwargs, + ) diff --git a/pl_bolts/datamodules/emnist_datamodule.py b/pl_bolts/datamodules/emnist_datamodule.py new file mode 100644 index 0000000000..2288155953 --- /dev/null +++ b/pl_bolts/datamodules/emnist_datamodule.py @@ -0,0 +1,217 @@ +from typing import Any, Callable, Optional, Union + +from pl_bolts.datamodules.vision_datamodule import VisionDataModule +from pl_bolts.transforms.dataset_normalizations import emnist_normalization +from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _TORCHVISION_AVAILABLE: + from torchvision import transforms as transform_lib + from torchvision.datasets import EMNIST +else: # pragma: no cover + warn_missing_pkg('torchvision') + EMNIST = object + + +class EMNISTDataModule(VisionDataModule): + """ + .. figure:: https://user-images.githubusercontent.com/4632336/123210742-4d6b3380-d477-11eb-80da-3e9a74a18a07.png + :width: 400 + :alt: EMNIST + + .. list-table:: Dataset information (source: `EMNIST: an extension of MNIST to handwritten + letters `_ [Table-II]) + :header-rows: 1 + + * - Split Name + - No. classes + - Train set size + - Test set size + - Validation set + - Total size + * - ``"byclass"`` + - 62 + - 697,932 + - 116,323 + - No + - 814,255 + * - ``"byclass"`` + - 62 + - 697,932 + - 116,323 + - No + - 814,255 + * - ``"bymerge"`` + - 47 + - 697,932 + - 116,323 + - No + - 814,255 + * - ``"balanced"`` + - 47 + - 112,800 + - 18,800 + - Yes + - 131,600 + * - ``"digits"`` + - 10 + - 240,000 + - 40,000 + - Yes + - 280,000 + * - ``"letters"`` + - 37 + - 88,800 + - 14,800 + - Yes + - 103,600 + * - ``"mnist"`` + - 10 + - 60,000 + - 10,000 + - Yes + - 70,000 + + | + + Here is the default EMNIST, train, val, test-splits and transforms. + + Transforms:: + + emnist_transforms = transform_lib.Compose([ + transform_lib.ToTensor(), + ]) + + Example:: + + from pl_bolts.datamodules import EMNISTDataModule + dm = EMNISTDataModule('.') + model = LitModel() + Trainer().fit(model, datamodule=dm) + """ + name = "emnist" + dataset_cls = EMNIST + dims = (1, 28, 28) + + _official_val_split = { + 'balanced': 18_800, + 'digits': 40_000, + 'letters': 14_800, + 'mnist': 10_000, + } + + def __init__( + self, + data_dir: Optional[str] = None, + split: str = 'mnist', + val_split: Union[int, float] = 0.2, + num_workers: int = 0, + normalize: bool = False, + batch_size: int = 32, + seed: int = 42, + shuffle: bool = True, + pin_memory: bool = True, + drop_last: bool = False, + strict_val_split: bool = False, + *args: Any, + **kwargs: Any, + ) -> None: + """ + Args: + data_dir: Where to save/load the data. + split: The dataset has 6 different splits: ``byclass``, ``bymerge``, + ``balanced``, ``letters``, ``digits`` and ``mnist``. + This argument is passed to :class:`torchvision.datasets.EMNIST`. + val_split: Percent (float) or number (int) of samples + to use for the validation split. + num_workers: How many workers to use for loading data + normalize: If ``True``, applies image normalize. 
+ batch_size: How many samples per batch to load. + seed: Random seed to be used for train/val/test splits. + shuffle: If ``True``, shuffles the train data every epoch. + pin_memory: If ``True``, the data loader will copy Tensors into + CUDA pinned memory before returning them. + drop_last: If ``True``, drops the last incomplete batch. + strict_val_split: If ``True``, uses the validation split defined in the paper and ignores ``val_split``. + Note that it only works with ``"balanced"``, ``"digits"``, ``"letters"``, ``"mnist"`` splits. + """ + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( + 'You want to use EMNIST dataset loaded from `torchvision` which is not installed yet.' + ) + + if split not in self.dataset_cls.splits: + raise ValueError( + f"Unknown value '{split}' for argument `split`. Valid values are {self.dataset_cls.splits}." + ) + + super(EMNISTDataModule, self).__init__( # type: ignore[misc] + data_dir=data_dir, + val_split=val_split, + num_workers=num_workers, + normalize=normalize, + batch_size=batch_size, + seed=seed, + shuffle=shuffle, + pin_memory=pin_memory, + drop_last=drop_last, + *args, + **kwargs, + ) + self.split = split + + if strict_val_split: + # replaces the value in `val_split` with the one defined in the paper + if self.split in self._official_val_split: + self.val_split = self._official_val_split[self.split] + else: + raise ValueError( + f"Invalid value '{self.split}' for argument `split` with `strict_val_split=True`. " + f"Valid values are {set(self._official_val_split)}." + ) + + @property + def num_classes(self) -> int: + """Returns the number of classes. See the table above.""" + return len(self.dataset_cls.classes_split_dict[self.split]) + + def prepare_data(self, *args: Any, **kwargs: Any) -> None: + """ + Saves files to ``data_dir``. 
+ """ + + self.dataset_cls(self.data_dir, split=self.split, train=True, download=True) + self.dataset_cls(self.data_dir, split=self.split, train=False, download=True) + + def setup(self, stage: Optional[str] = None) -> None: + """ + Creates train, val, and test dataset + """ + + if stage == "fit" or stage is None: + train_transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms + val_transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms + + dataset_train = self.dataset_cls( + self.data_dir, split=self.split, train=True, transform=train_transforms, **self.EXTRA_ARGS + ) + dataset_val = self.dataset_cls( + self.data_dir, split=self.split, train=True, transform=val_transforms, **self.EXTRA_ARGS + ) + + # Split + self.dataset_train = self._split_dataset(dataset_train) + self.dataset_val = self._split_dataset(dataset_val, train=False) + + if stage == "test" or stage is None: + test_transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms + self.dataset_test = self.dataset_cls( + self.data_dir, split=self.split, train=False, transform=test_transforms, **self.EXTRA_ARGS + ) + + def default_transforms(self) -> Callable: + + return transform_lib.Compose([ + transform_lib.ToTensor(), + emnist_normalization(self.split), + ]) if self.normalize else transform_lib.Compose([transform_lib.ToTensor()]) diff --git a/pl_bolts/datasets/__init__.py b/pl_bolts/datasets/__init__.py index b7ac6c5fee..42b6fe837a 100644 --- a/pl_bolts/datasets/__init__.py +++ b/pl_bolts/datasets/__init__.py @@ -10,6 +10,7 @@ RandomDictDataset, RandomDictStringDataset, ) +from pl_bolts.datasets.emnist_dataset import BinaryEMNIST from pl_bolts.datasets.imagenet_dataset import extract_archive, parse_devkit_archive, UnlabeledImagenet from pl_bolts.datasets.kitti_dataset import KittiDataset from pl_bolts.datasets.mnist_dataset import BinaryMNIST, MNIST @@ -33,6 +34,7 @@ "BinaryMNIST", "CIFAR10Mixed", "SSLDatasetMixin", + "BinaryEMNIST", ] # TorchVision hotfix https://github.com/pytorch/vision/issues/1938 diff --git a/pl_bolts/datasets/emnist_dataset.py b/pl_bolts/datasets/emnist_dataset.py new file mode 100644 index 0000000000..2cec471cb4 --- /dev/null +++ b/pl_bolts/datasets/emnist_dataset.py @@ -0,0 +1,45 @@ +from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _TORCHVISION_AVAILABLE: + from torchvision.datasets import EMNIST +else: # pragma: no cover + warn_missing_pkg('torchvision') + EMNIST = object + +if _PIL_AVAILABLE: + from PIL import Image +else: # pragma: no cover + warn_missing_pkg('PIL', pypi_name='Pillow') + + +class BinaryEMNIST(EMNIST): + + def __getitem__(self, idx): + """ + Args: + index: Index + + Returns: + tuple: (image, target) where target is index of the target class. 
+ """ + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError('You want to use `torchvision` which is not installed yet.') + + img, target = self.data[idx], int(self.targets[idx]) + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + img = Image.fromarray(img.numpy(), mode='L') + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + # binary + img[img < 0.5] = 0.0 + img[img >= 0.5] = 1.0 + + return img, target diff --git a/pl_bolts/transforms/dataset_normalizations.py b/pl_bolts/transforms/dataset_normalizations.py index f07447c82b..00e4a0413f 100644 --- a/pl_bolts/transforms/dataset_normalizations.py +++ b/pl_bolts/transforms/dataset_normalizations.py @@ -38,3 +38,22 @@ def stl10_normalization(): normalize = transforms.Normalize(mean=(0.43, 0.42, 0.39), std=(0.27, 0.26, 0.27)) return normalize + + +def emnist_normalization(split: str): + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( + 'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.' + ) + + # `stats` contains mean and std for each `split`. + stats = { + 'balanced': (0.175, 0.333), + 'byclass': (0.174, 0.332), + 'bymerge': (0.174, 0.332), + 'digits': (0.173, 0.332), + 'letters': (0.172, 0.331), + 'mnist': (0.173, 0.332), + } + + return transforms.Normalize(mean=stats[split][0], std=stats[split][1]) diff --git a/tests/datamodules/test_datamodules.py b/tests/datamodules/test_datamodules.py index 8d18b77b3e..6cc7a20b1a 100644 --- a/tests/datamodules/test_datamodules.py +++ b/tests/datamodules/test_datamodules.py @@ -6,9 +6,11 @@ from PIL import Image from pl_bolts.datamodules import ( + BinaryEMNISTDataModule, BinaryMNISTDataModule, CIFAR10DataModule, CityscapesDataModule, + EMNISTDataModule, FashionMNISTDataModule, MNISTDataModule, ) @@ -80,8 +82,54 @@ def test_data_modules(datadir, dm_cls): assert img.size() == torch.Size([2, *dm.size()]) -def _create_dm(dm_cls, datadir, val_split=0.2): - dm = dm_cls(data_dir=datadir, val_split=val_split, num_workers=1, batch_size=2) +def _create_dm(dm_cls, datadir, **kwargs): + dm = dm_cls(data_dir=datadir, num_workers=1, batch_size=2, **kwargs) dm.prepare_data() dm.setup() return dm + + +@pytest.mark.parametrize("split", ["byclass", "bymerge", "balanced", "letters", "digits", "mnist"]) +@pytest.mark.parametrize("dm_cls", [BinaryEMNISTDataModule, EMNISTDataModule]) +def test_emnist_datamodules(datadir, dm_cls, split): + """Test EMNIST datamodules download data and have the correct shape.""" + + dm = _create_dm(dm_cls, datadir, split=split) + loader = dm.train_dataloader() + img, _ = next(iter(loader)) + assert img.size() == torch.Size([2, 1, 28, 28]) + + +@pytest.mark.parametrize("dm_cls", [BinaryEMNISTDataModule, EMNISTDataModule]) +def test_emnist_datamodules_with_invalid_split(datadir, dm_cls): + """Test EMNIST datamodules raise an exception if the provided `split` doesn't exist.""" + + with pytest.raises(ValueError, match="Unknown value"): + dm_cls(data_dir=datadir, split="this_split_doesnt_exist") + + +@pytest.mark.parametrize("dm_cls", [BinaryEMNISTDataModule, EMNISTDataModule]) +@pytest.mark.parametrize( + "split, expected_val_split", [ + ("byclass", None), + ("bymerge", None), + ("balanced", 18_800), + ("digits", 40_000), + ("letters", 14_800), + ("mnist", 10_000), + ] +) +def test_emnist_datamodules_with_strict_val_split(datadir, dm_cls, split, 
expected_val_split): + """ + Test that EMNIST datamodules use the validation set defined in the paper when `strict_val_split` is specified. + Refer to https://arxiv.org/abs/1702.05373 for `expected_val_split` values. + """ + + if expected_val_split is None: + with pytest.raises(ValueError, match="Invalid value"): + _create_dm(dm_cls, datadir, split=split, strict_val_split=True) + + else: + dm = _create_dm(dm_cls, datadir, split=split, strict_val_split=True) + assert dm.val_split == expected_val_split + assert len(dm.dataset_val) == expected_val_split
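
Usage sketch (not part of the diff above): a minimal example of how the new datamodules could be exercised once this change lands, assuming the constructor signatures and `VisionDataModule` helpers shown in `emnist_datamodule.py` and `binary_emnist_datamodule.py`; the `data_dir` value and batch size are arbitrary placeholders.

    from pl_bolts.datamodules import BinaryEMNISTDataModule, EMNISTDataModule

    # EMNIST "letters" split with the per-split normalization added in
    # pl_bolts/transforms/dataset_normalizations.py.
    dm = EMNISTDataModule(data_dir=".", split="letters", normalize=True, batch_size=32)
    dm.prepare_data()  # downloads the EMNIST files into data_dir
    dm.setup()         # builds the train/val/test datasets and splits

    imgs, targets = next(iter(dm.train_dataloader()))
    print(imgs.shape)      # torch.Size([32, 1, 28, 28]), matching dims = (1, 28, 28)
    print(dm.num_classes)  # number of classes for the chosen split

    # Binarized EMNIST, pinned to the official validation split from the paper
    # instead of the default val_split fraction.
    binary_dm = BinaryEMNISTDataModule(data_dir=".", split="mnist", strict_val_split=True)
    binary_dm.prepare_data()
    binary_dm.setup()
    assert binary_dm.val_split == 10_000  # value taken from _official_val_split["mnist"]

As in the docstring examples, either datamodule can then be passed straight to `Trainer().fit(model, datamodule=dm)`.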