From 97ae47a8ec7d2d09cafbdc5b53c137b71017ce38 Mon Sep 17 00:00:00 2001 From: Alexander <47296670+Marsmaennchen221@users.noreply.github.com> Date: Fri, 12 Jul 2024 09:28:14 +0200 Subject: [PATCH] feat: added MNIST, Fashion-MNIST and KMNIST datasets (#164) Closes #161 Closes #162 Closes #163 ### Summary of Changes feat: added MNIST, Fashion-MNIST and KMNIST datasets build: bump safe-ds to ^0.24.0 --------- Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com> Co-authored-by: Lars Reimann --- poetry.lock | 4 +- pyproject.toml | 2 +- src/safeds_datasets/image/__init__.py | 5 + src/safeds_datasets/image/_mnist/__init__.py | 5 + src/safeds_datasets/image/_mnist/_mnist.py | 256 ++++++++++++++++++ tests/safeds_datasets/image/__init__.py | 0 .../safeds_datasets/image/_mnist/__init__.py | 0 .../image/_mnist/test_mnist.py | 98 +++++++ 8 files changed, 366 insertions(+), 4 deletions(-) create mode 100644 src/safeds_datasets/image/__init__.py create mode 100644 src/safeds_datasets/image/_mnist/__init__.py create mode 100644 src/safeds_datasets/image/_mnist/_mnist.py create mode 100644 tests/safeds_datasets/image/__init__.py create mode 100644 tests/safeds_datasets/image/_mnist/__init__.py create mode 100644 tests/safeds_datasets/image/_mnist/test_mnist.py diff --git a/poetry.lock b/poetry.lock index e9b4c52..711a876 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2365,7 +2365,6 @@ description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"}, {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"}, {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"}, ] @@ -3038,7 +3037,6 
@@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4596,4 +4594,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.11,<3.13" -content-hash = "eb8f77d983defacb62407e70d19cf8b5403a4fd19ba489c3a4acac880940d886" +content-hash = "3d25591445f9ba7fd0ba42364e8486704f822448e5688e7991138700c1ab34ca" diff --git a/pyproject.toml b/pyproject.toml index 1602d27..35e7b3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ packages = [ [tool.poetry.dependencies] python = "^3.11,<3.13" -safe-ds = ">=0.17,<0.27" +safe-ds = ">=0.24,<0.27" [tool.poetry.group.dev.dependencies] pytest = ">=7.2.1,<9.0.0" diff --git a/src/safeds_datasets/image/__init__.py b/src/safeds_datasets/image/__init__.py new file mode 100644 index 0000000..4e9487b --- /dev/null +++ b/src/safeds_datasets/image/__init__.py @@ -0,0 +1,5 @@ +"""Image datasets.""" + +from ._mnist import load_fashion_mnist, load_kmnist, load_mnist + +__all__ = ["load_fashion_mnist", "load_kmnist", "load_mnist"] diff --git 
a/src/safeds_datasets/image/_mnist/__init__.py b/src/safeds_datasets/image/_mnist/__init__.py
new file mode 100644
index 0000000..7491a04
--- /dev/null
+++ b/src/safeds_datasets/image/_mnist/__init__.py
@@ -0,0 +1,5 @@
+"""MNIST like Datasets."""
+
+from ._mnist import load_fashion_mnist, load_kmnist, load_mnist
+
+__all__ = ["load_fashion_mnist", "load_kmnist", "load_mnist"]
diff --git a/src/safeds_datasets/image/_mnist/_mnist.py b/src/safeds_datasets/image/_mnist/_mnist.py
new file mode 100644
index 0000000..15419bd
--- /dev/null
+++ b/src/safeds_datasets/image/_mnist/_mnist.py
@@ -0,0 +1,294 @@
+import gzip
+import os
+import struct
+import sys
+import urllib.request
+from array import array
+from pathlib import Path
+from typing import TYPE_CHECKING
+from urllib.error import HTTPError, URLError
+
+import torch
+from safeds._config import _init_default_device
+from safeds.data.image.containers._single_size_image_list import _SingleSizeImageList
+from safeds.data.labeled.containers import ImageDataset
+from safeds.data.tabular.containers import Column
+
+if TYPE_CHECKING:
+    from safeds.data.image.containers import ImageList
+
+_mnist_links: list[str] = ["http://yann.lecun.com/exdb/mnist/", "https://ossci-datasets.s3.amazonaws.com/mnist/"]
+_mnist_files: dict[str, str] = {
+    "train-images-idx3": "train-images-idx3-ubyte.gz",
+    "train-labels-idx1": "train-labels-idx1-ubyte.gz",
+    "test-images-idx3": "t10k-images-idx3-ubyte.gz",
+    "test-labels-idx1": "t10k-labels-idx1-ubyte.gz",
+}
+_mnist_labels: dict[int, str] = {0: "0", 1: "1", 2: "2", 3: "3", 4: "4", 5: "5", 6: "6", 7: "7", 8: "8", 9: "9"}
+_mnist_folder: str = "mnist"
+
+_fashion_mnist_links: list[str] = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]
+_fashion_mnist_files: dict[str, str] = _mnist_files
+_fashion_mnist_labels: dict[int, str] = {
+    0: "T-shirt/top",
+    1: "Trouser",
+    2: "Pullover",
+    3: "Dress",
+    4: "Coat",
+    5: "Sandal",
+    6: "Shirt",
+    7: "Sneaker",
+    8: "Bag",
+    9: "Ankle boot",
+}
+_fashion_mnist_folder: str = "fashion-mnist"
+
+_kuzushiji_mnist_links: list[str] = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]
+_kuzushiji_mnist_files: dict[str, str] = _mnist_files
+_kuzushiji_mnist_labels: dict[int, str] = {
+    0: "\u304a",
+    1: "\u304d",
+    2: "\u3059",
+    3: "\u3064",
+    4: "\u306a",
+    5: "\u306f",
+    6: "\u307e",
+    7: "\u3084",
+    8: "\u308c",
+    9: "\u3092",
+}
+_kuzushiji_mnist_folder: str = "kmnist"
+
+
+def load_mnist(path: str | Path, download: bool = True) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
+    """
+    Load the MNIST datasets.
+
+    Files that are missing in the `mnist` subfolder of `path` are downloaded first if `download` is True.
+
+    Parameters
+    ----------
+    path:
+        the path where the files are stored or will be downloaded to
+    download:
+        whether the files should be downloaded to the given path
+
+    Returns
+    -------
+    train_dataset, test_dataset:
+        The train and test datasets.
+
+    Raises
+    ------
+    FileNotFoundError
+        if a file of the dataset cannot be found or downloaded
+    """
+    return _load_or_download(
+        path,
+        _mnist_folder,
+        _mnist_files,
+        _mnist_links,
+        _mnist_labels,
+        download,
+    )
+
+
+def load_fashion_mnist(path: str | Path, download: bool = True) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
+    """
+    Load the Fashion-MNIST datasets.
+
+    Files that are missing in the `fashion-mnist` subfolder of `path` are downloaded first if `download` is True.
+
+    Parameters
+    ----------
+    path:
+        the path where the files are stored or will be downloaded to
+    download:
+        whether the files should be downloaded to the given path
+
+    Returns
+    -------
+    train_dataset, test_dataset:
+        The train and test datasets.
+
+    Raises
+    ------
+    FileNotFoundError
+        if a file of the dataset cannot be found or downloaded
+    """
+    return _load_or_download(
+        path,
+        _fashion_mnist_folder,
+        _fashion_mnist_files,
+        _fashion_mnist_links,
+        _fashion_mnist_labels,
+        download,
+    )
+
+
+def load_kmnist(path: str | Path, download: bool = True) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
+    """
+    Load the Kuzushiji-MNIST datasets.
+
+    Files that are missing in the `kmnist` subfolder of `path` are downloaded first if `download` is True.
+
+    Parameters
+    ----------
+    path:
+        the path where the files are stored or will be downloaded to
+    download:
+        whether the files should be downloaded to the given path
+
+    Returns
+    -------
+    train_dataset, test_dataset:
+        The train and test datasets.
+
+    Raises
+    ------
+    FileNotFoundError
+        if a file of the dataset cannot be found or downloaded
+    """
+    return _load_or_download(
+        path,
+        _kuzushiji_mnist_folder,
+        _kuzushiji_mnist_files,
+        _kuzushiji_mnist_links,
+        _kuzushiji_mnist_labels,
+        download,
+    )
+
+
+def _load_or_download(
+    path: str | Path,
+    folder: str,
+    files: dict[str, str],
+    links: list[str],
+    labels: dict[int, str],
+    download: bool,
+) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
+    """Make sure the files of a MNIST-like dataset exist in `path / folder`, then load them."""
+    path = Path(path) / folder
+    path.mkdir(parents=True, exist_ok=True)
+    existing_files = set(os.listdir(path))
+    missing_files = {name: file for name, file in files.items() if file not in existing_files}
+    if missing_files:
+        if not download:
+            raise FileNotFoundError(f"Could not find files {[str(path / file) for file in missing_files.values()]}")
+        _download_mnist_like(path, missing_files, links)
+    return _load_mnist_like(path, files, labels)
+
+
+def _load_mnist_like(
+    path: str | Path,
+    files: dict[str, str],
+    labels: dict[int, str],
+) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
+    """
+    Read the gzipped IDX files of a MNIST-like dataset and build the train and test datasets.
+
+    Files whose key contains "idx1" are read as labels, all others as images. Keys that contain "train" belong to
+    the train dataset, all others to the test dataset.
+    """
+    _init_default_device()
+
+    path = Path(path)
+    test_labels: Column | None = None
+    train_labels: Column | None = None
+    test_image_list: ImageList | None = None
+    train_image_list: ImageList | None = None
+    for file_name, file_path in files.items():
+        is_train = "train" in file_name
+        if "idx1" in file_name:
+            label_column = _read_label_file(path / file_path, file_name, labels)
+            if is_train:
+                train_labels = label_column
+            else:
+                test_labels = label_column
+        else:
+            image_list = _read_image_file(path / file_path)
+            if is_train:
+                train_image_list = image_list
+            else:
+                test_image_list = image_list
+    if train_image_list is None or test_image_list is None or train_labels is None or test_labels is None:
+        raise ValueError  # pragma: no cover
+    return ImageDataset[Column](train_image_list, train_labels, 32, shuffle=True), ImageDataset[Column](
+        test_image_list,
+        test_labels,
+        32,
+    )
+
+
+def _read_label_file(file: Path, column_name: str, labels: dict[int, str]) -> Column:
+    """Read a gzipped IDX1 label file and map each label index to its name."""
+    with gzip.open(file, mode="rb") as label_file:
+        magic, _size = struct.unpack(">II", label_file.read(8))
+        if magic != 2049:
+            raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2049.")  # pragma: no cover
+        return Column(column_name, [labels[label_index] for label_index in array("B", label_file.read())])
+
+
+def _read_image_file(file: Path) -> _SingleSizeImageList:
+    """Read a gzipped IDX3 image file into an image list with one channel per image."""
+    with gzip.open(file, mode="rb") as image_file:
+        magic, size, rows, cols = struct.unpack(">IIII", image_file.read(16))
+        if magic != 2051:
+            raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2051.")  # pragma: no cover
+        image_data = array("B", image_file.read())
+    pixel_count = size * rows * cols
+    if len(image_data) < pixel_count:
+        raise ValueError(f"File {file} is truncated: {len(image_data)} < {pixel_count} bytes.")  # pragma: no cover
+    # torch.empty allocates on torch's default device; the pixels are then copied in bulk instead of per image.
+    image_tensor = torch.empty(size, 1, rows, cols, dtype=torch.uint8)
+    if pixel_count > 0:
+        pixels = torch.frombuffer(image_data, dtype=torch.uint8, count=pixel_count)
+        image_tensor.copy_(pixels.reshape(size, 1, rows, cols))
+    image_list = _SingleSizeImageList()
+    image_list._tensor = image_tensor
+    image_list._tensor_positions_to_indices = list(range(size))
+    image_list._indices_to_tensor_positions = image_list._calc_new_indices_to_tensor_positions()
+    return image_list
+
+
+def _download_mnist_like(path: str | Path, files: dict[str, str], links: list[str]) -> None:
+    """
+    Download each file from the first link that serves it.
+
+    A file is first written to a temporary `.part` file and only moved to its final name once it is complete,
+    so an aborted download is never mistaken for an existing dataset file.
+
+    Raises
+    ------
+    FileNotFoundError
+        if a file could not be downloaded from any of the links
+    """
+    path = Path(path)
+    for file_name, file_path in files.items():
+        target_file = path / file_path
+        partial_file = path / f"{file_path}.part"
+        for link in links:
+            try:
+                print(f"Trying to download file {file_name} via {link + file_path}")  # noqa: T201
+                urllib.request.urlretrieve(link + file_path, partial_file, reporthook=_report_download_progress)
+                print()  # noqa: T201
+            except (HTTPError, URLError) as e:
+                # HTTPError: the server refused the request; URLError: unreachable server or incomplete download.
+                print(f"An error occurred while downloading: {e}")  # noqa: T201 # pragma: no cover
+                partial_file.unlink(missing_ok=True)  # pragma: no cover
+            else:
+                partial_file.replace(target_file)
+                break
+        else:
+            raise FileNotFoundError(f"Could not download {file_path} from any of {links}")  # pragma: no cover
+
+
+def _report_download_progress(current_packages: int, package_size: int, file_size: int) -> None:
+    """Print the download progress; usable as `reporthook` of `urllib.request.urlretrieve`."""
+    if file_size > 0:
+        percentage = min(((current_packages * package_size) / file_size) * 100, 100)
+        sys.stdout.write(f"\rDownloading... {percentage:.1f}%")
+    else:
+        # Unknown (-1) or zero size: a percentage cannot be computed, so report the transferred bytes instead.
+        sys.stdout.write(f"\rDownloading... {current_packages * package_size} bytes")
+    sys.stdout.flush()
diff --git a/tests/safeds_datasets/image/__init__.py b/tests/safeds_datasets/image/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/safeds_datasets/image/_mnist/__init__.py b/tests/safeds_datasets/image/_mnist/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/safeds_datasets/image/_mnist/test_mnist.py b/tests/safeds_datasets/image/_mnist/test_mnist.py
new file mode 100644
index 0000000..5ada3db
--- /dev/null
+++ b/tests/safeds_datasets/image/_mnist/test_mnist.py
@@ -0,0 +1,98 @@
+import os
+import tempfile
+from pathlib import Path
+
+import pytest
+import torch
+from safeds.data.labeled.containers import ImageDataset
+from safeds_datasets.image import _mnist, load_fashion_mnist, load_kmnist, load_mnist
+
+
+class TestMNIST:
+
+    def test_should_download_and_return_mnist(self) -> None:
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            train, test = load_mnist(tmpdirname, download=True)
+            files = os.listdir(Path(tmpdirname) / _mnist._mnist._mnist_folder)
+            for mnist_file in _mnist._mnist._mnist_files.values():
+                assert mnist_file in files
+            assert isinstance(train, ImageDataset)
+            assert isinstance(test, ImageDataset)
+            assert len(train) == 60_000
+            assert len(test) == 10_000
+            assert (
+                train.get_input()._as_single_size_image_list()._tensor.dtype
+                == test.get_input()._as_single_size_image_list()._tensor.dtype
+                == torch.uint8
+            )
+            train_output = train.get_output()
+            test_output = test.get_output()
+            assert (
+                set(train_output.get_distinct_values())
+                == set(test_output.get_distinct_values())
+                == set(_mnist._mnist._mnist_labels.values())
+            )
+
+    def test_should_raise_if_file_not_found(self) -> None:
+        with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(FileNotFoundError):
+            load_mnist(tmpdirname, download=False)
+
+
+class TestFashionMNIST:
+
+    def test_should_download_and_return_mnist(self) -> None:
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            train, test = load_fashion_mnist(tmpdirname, download=True)
+            files = os.listdir(Path(tmpdirname) / _mnist._mnist._fashion_mnist_folder)
+            for mnist_file in _mnist._mnist._fashion_mnist_files.values():
+                assert mnist_file in files
+            assert isinstance(train, ImageDataset)
+            assert isinstance(test, ImageDataset)
+            assert len(train) == 60_000
+            assert len(test) == 10_000
+            assert (
+                train.get_input()._as_single_size_image_list()._tensor.dtype
+                == test.get_input()._as_single_size_image_list()._tensor.dtype
+                == torch.uint8
+            )
+            train_output = train.get_output()
+            test_output = test.get_output()
+            assert (
+                set(train_output.get_distinct_values())
+                == set(test_output.get_distinct_values())
+                == set(_mnist._mnist._fashion_mnist_labels.values())
+            )
+
+    def test_should_raise_if_file_not_found(self) -> None:
+        with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(FileNotFoundError):
+            load_fashion_mnist(tmpdirname, download=False)
+
+
+class TestKMNIST:
+
+    def test_should_download_and_return_mnist(self) -> None:
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            train, test = load_kmnist(tmpdirname, download=True)
+            files = os.listdir(Path(tmpdirname) / _mnist._mnist._kuzushiji_mnist_folder)
+            for mnist_file in _mnist._mnist._kuzushiji_mnist_files.values():
+                assert mnist_file in files
+            assert isinstance(train, ImageDataset)
+            assert isinstance(test, ImageDataset)
+            assert len(train) == 60_000
+            assert len(test) == 10_000
+            assert (
+                train.get_input()._as_single_size_image_list()._tensor.dtype
+                == test.get_input()._as_single_size_image_list()._tensor.dtype
+                == torch.uint8
+            )
+            train_output = train.get_output()
+            test_output = test.get_output()
+            assert (
+                set(train_output.get_distinct_values())
+                == set(test_output.get_distinct_values())
+                == set(_mnist._mnist._kuzushiji_mnist_labels.values())
+            )
+
+    def test_should_raise_if_file_not_found(self) -> None:
+        with tempfile.TemporaryDirectory() as tmpdirname, pytest.raises(FileNotFoundError):
+            load_kmnist(tmpdirname, download=False)