This repository has been archived by the owner on Nov 23, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: added MNIST, Fashion-MNIST and KMNIST datasets (#164)
Closes #161 Closes #162 Closes #163 ### Summary of Changes feat: added MNIST, Fashion-MNIST and KMNIST datasets build: bump safe-ds to ^0.24.0 --------- Co-authored-by: megalinter-bot <[email protected]> Co-authored-by: Lars Reimann <[email protected]>
- Loading branch information
1 parent
90de957
commit 97ae47a
Showing
8 changed files
with
366 additions
and
4 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""Image datasets.""" | ||
|
||
from ._mnist import load_fashion_mnist, load_kmnist, load_mnist | ||
|
||
__all__ = ["load_fashion_mnist", "load_kmnist", "load_mnist"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""MNIST like Datasets.""" | ||
|
||
from ._mnist import load_fashion_mnist, load_kmnist, load_mnist | ||
|
||
__all__ = ["load_fashion_mnist", "load_kmnist", "load_mnist"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,256 @@ | ||
import gzip | ||
import os | ||
import struct | ||
import sys | ||
import urllib.request | ||
from array import array | ||
from pathlib import Path | ||
from typing import TYPE_CHECKING | ||
from urllib.error import HTTPError | ||
|
||
import torch | ||
from safeds._config import _init_default_device | ||
from safeds.data.image.containers._single_size_image_list import _SingleSizeImageList | ||
from safeds.data.labeled.containers import ImageDataset | ||
from safeds.data.tabular.containers import Column | ||
|
||
if TYPE_CHECKING: | ||
from safeds.data.image.containers import ImageList | ||
|
||
# --- MNIST -----------------------------------------------------------------------------------------------------------
# Sub-folder (below the user supplied path) that holds the MNIST archives.
_mnist_folder: str = "mnist"
# Mirrors are tried in the order given here.
_mnist_links: list[str] = [
    "http://yann.lecun.com/exdb/mnist/",
    "https://ossci-datasets.s3.amazonaws.com/mnist/",
]
# Logical name -> archive name. The logical name encodes the split ("train"/"test") and the IDX kind ("idx1"/"idx3").
_mnist_files: dict[str, str] = {
    "train-images-idx3": "train-images-idx3-ubyte.gz",
    "train-labels-idx1": "train-labels-idx1-ubyte.gz",
    "test-images-idx3": "t10k-images-idx3-ubyte.gz",
    "test-labels-idx1": "t10k-labels-idx1-ubyte.gz",
}
# Class index -> readable label; for MNIST the label is simply the digit as a string.
_mnist_labels: dict[int, str] = {digit: str(digit) for digit in range(10)}

# --- Fashion-MNIST ---------------------------------------------------------------------------------------------------
_fashion_mnist_folder: str = "fashion-mnist"
_fashion_mnist_links: list[str] = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"]
# Same archive names as MNIST; this is deliberately the very same dict object, not a copy.
_fashion_mnist_files: dict[str, str] = _mnist_files
_fashion_mnist_labels: dict[int, str] = dict(
    enumerate(
        (
            "T-shirt/top",
            "Trouser",
            "Pullover",
            "Dress",
            "Coat",
            "Sandal",
            "Shirt",
            "Sneaker",
            "Bag",
            "Ankle boot",
        ),
    ),
)

# --- Kuzushiji-MNIST -------------------------------------------------------------------------------------------------
_kuzushiji_mnist_folder: str = "kmnist"
_kuzushiji_mnist_links: list[str] = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"]
# Same archive names as MNIST; this is deliberately the very same dict object, not a copy.
_kuzushiji_mnist_files: dict[str, str] = _mnist_files
# Labels are single hiragana characters, written as escapes to keep the source ASCII-only.
_kuzushiji_mnist_labels: dict[int, str] = dict(
    enumerate(
        (
            "\u304a",
            "\u304d",
            "\u3059",
            "\u3064",
            "\u306a",
            "\u306f",
            "\u307e",
            "\u3084",
            "\u308c",
            "\u3092",
        ),
    ),
)
|
||
|
||
def load_mnist(path: str | Path, download: bool = True) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
    """
    Load the `MNIST <http://yann.lecun.com/exdb/mnist/>`_ datasets.

    The files are looked up in (and, if requested, downloaded to) the `_mnist_folder` subdirectory of `path`. Only
    files that are not present yet are downloaded.

    Parameters
    ----------
    path:
        The path where the files are stored or will be downloaded to.
    download:
        Whether missing files should be downloaded to the given path.

    Returns
    -------
    train_dataset, test_dataset:
        The train and test datasets.

    Raises
    ------
    FileNotFoundError
        If a file of the dataset cannot be found.
    """
    path = Path(path) / _mnist_folder
    # Keep the logical name -> archive name mapping for exactly those archives that are not on disk yet.
    missing_files = {name: f_path for name, f_path in _mnist_files.items() if not (path / f_path).exists()}
    if missing_files:
        if not download:
            raise FileNotFoundError(f"Could not find files {[str(path / file) for file in missing_files.values()]}")
        # Create the target directory only when something is actually written, so that a pure lookup
        # (download=False) leaves the file system untouched.
        path.mkdir(parents=True, exist_ok=True)
        _download_mnist_like(path, missing_files, _mnist_links)
    return _load_mnist_like(path, _mnist_files, _mnist_labels)
|
||
|
||
def load_fashion_mnist(path: str | Path, download: bool = True) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
    """
    Load the `Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ datasets.

    The files are looked up in (and, if requested, downloaded to) the `_fashion_mnist_folder` subdirectory of `path`.
    Only files that are not present yet are downloaded.

    Parameters
    ----------
    path:
        The path where the files are stored or will be downloaded to.
    download:
        Whether missing files should be downloaded to the given path.

    Returns
    -------
    train_dataset, test_dataset:
        The train and test datasets.

    Raises
    ------
    FileNotFoundError
        If a file of the dataset cannot be found.
    """
    path = Path(path) / _fashion_mnist_folder
    # Keep the logical name -> archive name mapping for exactly those archives that are not on disk yet.
    missing_files = {name: f_path for name, f_path in _fashion_mnist_files.items() if not (path / f_path).exists()}
    if missing_files:
        if not download:
            raise FileNotFoundError(f"Could not find files {[str(path / file) for file in missing_files.values()]}")
        # Create the target directory only when something is actually written, so that a pure lookup
        # (download=False) leaves the file system untouched.
        path.mkdir(parents=True, exist_ok=True)
        _download_mnist_like(path, missing_files, _fashion_mnist_links)
    return _load_mnist_like(path, _fashion_mnist_files, _fashion_mnist_labels)
|
||
|
||
def load_kmnist(path: str | Path, download: bool = True) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
    """
    Load the `Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ datasets.

    The files are looked up in (and, if requested, downloaded to) the `_kuzushiji_mnist_folder` subdirectory of
    `path`. Only files that are not present yet are downloaded.

    Parameters
    ----------
    path:
        The path where the files are stored or will be downloaded to.
    download:
        Whether missing files should be downloaded to the given path.

    Returns
    -------
    train_dataset, test_dataset:
        The train and test datasets.

    Raises
    ------
    FileNotFoundError
        If a file of the dataset cannot be found.
    """
    path = Path(path) / _kuzushiji_mnist_folder
    # Keep the logical name -> archive name mapping for exactly those archives that are not on disk yet.
    missing_files = {name: f_path for name, f_path in _kuzushiji_mnist_files.items() if not (path / f_path).exists()}
    if missing_files:
        if not download:
            raise FileNotFoundError(f"Could not find files {[str(path / file) for file in missing_files.values()]}")
        # Create the target directory only when something is actually written, so that a pure lookup
        # (download=False) leaves the file system untouched.
        path.mkdir(parents=True, exist_ok=True)
        _download_mnist_like(path, missing_files, _kuzushiji_mnist_links)
    return _load_mnist_like(path, _kuzushiji_mnist_files, _kuzushiji_mnist_labels)
|
||
|
||
def _load_mnist_like(
    path: str | Path,
    files: dict[str, str],
    labels: dict[int, str],
) -> tuple[ImageDataset[Column], ImageDataset[Column]]:
    """
    Parse the gzipped IDX files of an MNIST-like dataset into a train and a test dataset.

    A file is treated as a label file if its logical name contains "idx1" and as an image file otherwise. It belongs
    to the train split if its logical name contains "train" and to the test split otherwise.

    Parameters
    ----------
    path:
        The directory that contains the archives.
    files:
        Mapping from the logical file name to the archive name inside `path`.
    labels:
        Mapping from the numeric class index stored in the label files to the label value used in the dataset.

    Returns
    -------
    train_dataset, test_dataset:
        The train and test datasets. Only the train dataset is created with `shuffle=True`.

    Raises
    ------
    ValueError
        If a file does not start with the expected magic number (2049 for label files, 2051 for image files), if an
        image file holds fewer pixels than its header announces, or if `files` does not provide images and labels
        for both splits.
    """
    _init_default_device()

    path = Path(path)
    test_labels: Column | None = None
    train_labels: Column | None = None
    test_image_list: ImageList | None = None
    train_image_list: ImageList | None = None
    for file_name, file_path in files.items():
        is_train = "train" in file_name
        if "idx1" in file_name:
            with gzip.open(path / file_path, mode="rb") as label_file:
                # Header: magic number and number of labels, both big-endian unsigned 32 bit integers.
                magic, _label_count = struct.unpack(">II", label_file.read(8))
                if magic != 2049:
                    raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2049.")  # pragma: no cover
                # The payload holds one unsigned byte per label; iterating over bytes yields these as ints.
                label_column = Column(file_name, [labels[label_index] for label_index in label_file.read()])
            if is_train:
                train_labels = label_column
            else:
                test_labels = label_column
        else:
            with gzip.open(path / file_path, mode="rb") as image_file:
                # Header: magic number, number of images, rows and columns per image (big-endian uint32 each).
                magic, size, rows, cols = struct.unpack(">IIII", image_file.read(16))
                if magic != 2051:
                    raise ValueError(f"Magic number mismatch. Actual {magic} != Expected 2051.")  # pragma: no cover
                # `array` hands torch a writable buffer; a read-only `bytes` object makes torch.frombuffer warn.
                image_data = array("B", image_file.read())
            pixel_count = size * rows * cols
            if len(image_data) < pixel_count:
                raise ValueError(
                    f"Image file {file_path} is truncated. Actual {len(image_data)} < Expected {pixel_count} bytes.",
                )  # pragma: no cover
            # torch.empty decides where the tensor lives; copy_ then fills it from the buffer in a single bulk copy
            # instead of converting the images one by one.
            image_tensor = torch.empty(size, 1, rows, cols, dtype=torch.uint8)
            if pixel_count > 0:  # torch.frombuffer rejects empty buffers
                image_tensor.copy_(
                    torch.frombuffer(image_data, dtype=torch.uint8, count=pixel_count).reshape(size, 1, rows, cols),
                )
            # Build the image list directly from the tensor by populating its internal state.
            image_list = _SingleSizeImageList()
            image_list._tensor = image_tensor
            image_list._tensor_positions_to_indices = list(range(size))
            image_list._indices_to_tensor_positions = image_list._calc_new_indices_to_tensor_positions()
            if is_train:
                train_image_list = image_list
            else:
                test_image_list = image_list
    if train_image_list is None or test_image_list is None or train_labels is None or test_labels is None:
        raise ValueError(
            "The given files must contain images and labels for both the train and the test split.",
        )  # pragma: no cover
    # NOTE(review): 32 is passed as the third positional argument, presumably the batch size - confirm against
    # the ImageDataset constructor.
    return ImageDataset[Column](train_image_list, train_labels, 32, shuffle=True), ImageDataset[Column](
        test_image_list,
        test_labels,
        32,
    )
|
||
|
||
def _download_mnist_like(path: str | Path, files: dict[str, str], links: list[str]) -> None: | ||
path = Path(path) | ||
for file_name, file_path in files.items(): | ||
for link in links: | ||
try: | ||
print(f"Trying to download file {file_name} via {link + file_path}") # noqa: T201 | ||
urllib.request.urlretrieve(link + file_path, path / file_path, reporthook=_report_download_progress) | ||
print() # noqa: T201 | ||
break | ||
except HTTPError as e: | ||
print(f"An error occurred while downloading: {e}") # noqa: T201 # pragma: no cover | ||
|
||
|
||
def _report_download_progress(current_packages: int, package_size: int, file_size: int) -> None: | ||
percentage = min(((current_packages * package_size) / file_size) * 100, 100) | ||
sys.stdout.write(f"\rDownloading... {percentage:.1f}%") | ||
sys.stdout.flush() |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import os | ||
import tempfile | ||
from pathlib import Path | ||
|
||
import pytest | ||
import torch | ||
from safeds.data.labeled.containers import ImageDataset | ||
from safeds_datasets.image import _mnist, load_fashion_mnist, load_kmnist, load_mnist | ||
|
||
|
||
class TestMNIST:
    """Tests for `load_mnist`."""

    def test_should_download_and_return_mnist(self) -> None:
        """Download MNIST into a scratch directory and sanity-check both returned datasets."""
        mnist_module = _mnist._mnist
        with tempfile.TemporaryDirectory() as scratch_dir:
            train, test = load_mnist(scratch_dir, download=True)

            # Every archive must have been stored in the dataset specific sub-folder.
            stored_files = os.listdir(Path(scratch_dir) / mnist_module._mnist_folder)
            for expected_file in mnist_module._mnist_files.values():
                assert expected_file in stored_files

            assert isinstance(train, ImageDataset)
            assert isinstance(test, ImageDataset)
            assert len(train) == 60_000
            assert len(test) == 10_000

            # The pixel data must be kept as raw bytes.
            train_dtype = train.get_input()._as_single_size_image_list()._tensor.dtype
            test_dtype = test.get_input()._as_single_size_image_list()._tensor.dtype
            assert train_dtype == test_dtype == torch.uint8

            # Both splits must contain exactly the known labels.
            train_label_set = set(train.get_output().get_distinct_values())
            test_label_set = set(test.get_output().get_distinct_values())
            assert train_label_set == test_label_set == set(mnist_module._mnist_labels.values())

    def test_should_raise_if_file_not_found(self) -> None:
        """Without downloading, an empty directory must lead to a FileNotFoundError."""
        with tempfile.TemporaryDirectory() as scratch_dir:
            with pytest.raises(FileNotFoundError):
                load_mnist(scratch_dir, download=False)
|
||
|
||
class TestFashionMNIST:
    """Tests for `load_fashion_mnist`."""

    def test_should_download_and_return_mnist(self) -> None:
        """Download Fashion-MNIST into a scratch directory and sanity-check both returned datasets."""
        mnist_module = _mnist._mnist
        with tempfile.TemporaryDirectory() as scratch_dir:
            train, test = load_fashion_mnist(scratch_dir, download=True)

            # Every archive must have been stored in the dataset specific sub-folder.
            stored_files = os.listdir(Path(scratch_dir) / mnist_module._fashion_mnist_folder)
            for expected_file in mnist_module._fashion_mnist_files.values():
                assert expected_file in stored_files

            assert isinstance(train, ImageDataset)
            assert isinstance(test, ImageDataset)
            assert len(train) == 60_000
            assert len(test) == 10_000

            # The pixel data must be kept as raw bytes.
            train_dtype = train.get_input()._as_single_size_image_list()._tensor.dtype
            test_dtype = test.get_input()._as_single_size_image_list()._tensor.dtype
            assert train_dtype == test_dtype == torch.uint8

            # Both splits must contain exactly the known labels.
            train_label_set = set(train.get_output().get_distinct_values())
            test_label_set = set(test.get_output().get_distinct_values())
            assert train_label_set == test_label_set == set(mnist_module._fashion_mnist_labels.values())

    def test_should_raise_if_file_not_found(self) -> None:
        """Without downloading, an empty directory must lead to a FileNotFoundError."""
        with tempfile.TemporaryDirectory() as scratch_dir:
            with pytest.raises(FileNotFoundError):
                load_fashion_mnist(scratch_dir, download=False)
|
||
|
||
class TestKMNIST:
    """Tests for `load_kmnist`."""

    def test_should_download_and_return_mnist(self) -> None:
        """Download Kuzushiji-MNIST into a scratch directory and sanity-check both returned datasets."""
        mnist_module = _mnist._mnist
        with tempfile.TemporaryDirectory() as scratch_dir:
            train, test = load_kmnist(scratch_dir, download=True)

            # Every archive must have been stored in the dataset specific sub-folder.
            stored_files = os.listdir(Path(scratch_dir) / mnist_module._kuzushiji_mnist_folder)
            for expected_file in mnist_module._kuzushiji_mnist_files.values():
                assert expected_file in stored_files

            assert isinstance(train, ImageDataset)
            assert isinstance(test, ImageDataset)
            assert len(train) == 60_000
            assert len(test) == 10_000

            # The pixel data must be kept as raw bytes.
            train_dtype = train.get_input()._as_single_size_image_list()._tensor.dtype
            test_dtype = test.get_input()._as_single_size_image_list()._tensor.dtype
            assert train_dtype == test_dtype == torch.uint8

            # Both splits must contain exactly the known labels.
            train_label_set = set(train.get_output().get_distinct_values())
            test_label_set = set(test.get_output().get_distinct_values())
            assert train_label_set == test_label_set == set(mnist_module._kuzushiji_mnist_labels.values())

    def test_should_raise_if_file_not_found(self) -> None:
        """Without downloading, an empty directory must lead to a FileNotFoundError."""
        with tempfile.TemporaryDirectory() as scratch_dir:
            with pytest.raises(FileNotFoundError):
                load_kmnist(scratch_dir, download=False)