diff --git a/docs/index.rst b/docs/index.rst index e91379d6f3..b9693ec905 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -31,6 +31,7 @@ torchgeo tutorials/getting_started tutorials/pytorch tutorials/geospatial + tutorials/contribute_non_geo_dataset tutorials/custom_raster_dataset tutorials/transforms tutorials/indices diff --git a/docs/tutorials/contribute_non_geo_dataset.ipynb b/docs/tutorials/contribute_non_geo_dataset.ipynb new file mode 100644 index 0000000000..655deb1c8d --- /dev/null +++ b/docs/tutorials/contribute_non_geo_dataset.ipynb @@ -0,0 +1,485 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Microsoft Corporation. All rights reserved.\n", + "# Licensed under the MIT License." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Contribute a New Non-Geospatial Dataset\n", + "\n", + "_Written by: Nils Lehmann_\n", + "\n", + "Open-source datasets have significantly accelerated machine learning research. Geospatial machine learning datasets can be particularly complex to work with compared to more standard RGB-based vision datasets. To spare the community from having to repeatly implement data loading logic over and over, TorchGeo provides dozens of built-in datasets such that they can be downloaded and ready for use in a PyTorch framework with a single line of code. This tutorial will show how you can add a new non-geospatial dataset to this growing collection. \n", + "\n", + "As a reminder, TorchGeo differentiates between two types of datasets: geospatial and non-geospatial datasets. Non-geospatial datasets are integer indexed, like the datasets one might be familar with from torchvision, while geospatial datasets are indexed via spatiotemporal bounding boxes. Non-geospatial datasets can still return geospatial and other metadata and should be specific to the remote sensing domain. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "First, we install TorchGeo and its dependencies." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install torchgeo" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Where to start\n", + "\n", + "There are many types of remote sensing datasets. [Satellite-Image-Deep-Learning](https://github.com/satellite-image-deep-learning/datasets) maintains a list of many of these datasets, as well as links to other similar curated lists.\n", + "\n", + "Two aspects that will make it a lot easier to add the dataset are whether or not the dataset can be easily downloaded and whether or the dataset comes with a Github repository and publication that outlines how the authors intend the dataset to be used. These are not necessariy criteria, and sometimes it might be even more worthwhile to add a dataset without an existing code base, precisely because the marginal contribution to the community might be greater since a use of the dataset does not necessitate writing the loading implementation from scratch." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adding the dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once you have identified a dataset that you would like to add to TorchGeo, you could identify in what application category it might roughly fall in. For example, a segmentation dataset based on a collection of *.png* files, versus a classification dataset based on pre-defined image chips in *.tif* files. In the later case, if you find that the dataset contains *.tif* files that have very large pixel sizes, such that loading a single file might be costly, consider adding the dataset as a geospatial dataset for easier indexing. Once, you have identified the \"task\" such as segmentation vs classification and the dataset format, see whether a dataset of the same or similar category exists in TorchGeo already. All datasets inherit from a `NonGeoDataset` or `GeoDataset` base class that provides an outline for the implementation logic as well as additional utility functions that should be reused. This reduces code duplication and makes it easier to unit test datasets.\n", + "\n", + "Adding a dataset to TorchGeo consists of roughly four steps:\n", + "\n", + "1. a `dataset_name.py` file itself that implements the logic of the dataset\n", + "2. a `data.py` file that creates dummy data in the same structure and format as the original dataset for unit tests\n", + "3. a `test_dataset_name.py` file that implements unit tests for the dataset\n", + "4. an entry to the documentation page files: `non_geo_datasets.csv` and `datasets.rst`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The `dataset_name.py` file" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This file implements the logic to load a sample from the dataset as well as downloading the dataset automatically if possible. The new dataset inherits from a base class and the documentation string (docstring) of the class should contain:\n", + "\n", + "* a short summary of the dataset\n", + "* outline the features, such as the task the dataset is designed to solve\n", + "* outline the format the dataset comes in, e.g., file types, pixel dimensions, etc.\n", + "* a proper reference to the dataset such as a link to the paper so users can adequately cite the dataset when using it\n", + "* if required, a note about additional dependencies that are not part of TorchGeo's required dependencies\n", + "\n", + "The dataset implementation itself should contain:\n", + "\n", + "* a method to create an index structure the dataset can iterate over to load samples. This index structure also defines the length (`__len__`) of the dataset, i.e. how many individual samples can be loaded from the dataset\n", + "* a `__getitem__` method that takes an integer index argument, loads a sample of the dataset, and returns its components in a dictionary\n", + "* a `_verify` method that checks whether the dataset can be found on the filesystem, has already been downloaded and only needs to be extracted, or downloads and extracts the dataset from the web\n", + "* a `plot` method that can visually display a single sample of the dataset\n", + "\n", + "The code below attempts to roughly outline the parts required for a new `NonGeoDataset`. Specifics are of course very dependent on the type of dataset you want to add, but this template and other existing datasets should give you a decent starting point." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from collections.abc import Callable\n", + "\n", + "from matplotlib.pyplot import Figure\n", + "from torch import Tensor\n", + "\n", + "from torchgeo.datasets import NonGeoDataset\n", + "from torchgeo.datasets.utils import Path\n", + "\n", + "\n", + "class MyNewDataset(NonGeoDataset):\n", + " \"\"\"MyNewDataset.\n", + "\n", + " Short summary of the dataset and link to its homepage.\n", + "\n", + " Dataset features:\n", + "\n", + " * number of classes\n", + " * sensors\n", + " * area covered\n", + " * etc.\n", + "\n", + " Dataset format:\n", + "\n", + " * what file format and shape the input data comes in\n", + " * what file format and shape the target data comes in\n", + " * possible metadata files\n", + "\n", + " If you use this dataset in your research, please cite the following paper:\n", + "\n", + " * URL of publication or citation information\n", + "\n", + " .. versionadded: next TorchGeo minor release version, e.g., 1.0\n", + " \"\"\"\n", + "\n", + " # In this part of the code you can define class attributes such as a list of\n", + " # class names, color maps, url and checksums for data download, and other\n", + " # attributes that one might require repeatedly in the subsequent class methods.\n", + "\n", + " def __init__(\n", + " self,\n", + " root: Path = 'data',\n", + " split: str = 'train',\n", + " transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,\n", + " download: bool = False,\n", + " ) -> None:\n", + " \"\"\"Initialize the dataset.\n", + "\n", + " The init parameters can include additional arguments, such as an option to\n", + " select specific image bands, data modalities, or other arguments that give\n", + " greater control over data loading. They should all have reasonable defaults.\n", + "\n", + " Args:\n", + " root: root directory where dataset can be found\n", + " split: one of \"train\", \"val\", or \"test\"\n", + " transforms: a function/transform that takes input sample and its target as\n", + " entry and returns a transformed version\n", + " download: if True, download dataset and store it in the root directory\n", + " \"\"\"\n", + "\n", + " def __len__(self) -> int:\n", + " \"\"\"The length of the dataset.\n", + "\n", + " This is the total number of samples per epoch, and is used to define the\n", + " maximum allow index that can be passed to `__getitem__`.\n", + " \"\"\"\n", + "\n", + " def __getitem__(self, index: int) -> dict[str, Tensor]:\n", + " \"\"\"A single sample from the dataset.\n", + "\n", + " Load a single input image and target label or mask, and return it in a\n", + " dictionary.\n", + " \"\"\"\n", + "\n", + " def plot(self) -> Figure:\n", + " \"\"\"Plot a sample of the dataset for visualization purposes.\n", + "\n", + " This might involve selecting the RGB bands, using a colormap to display a mask,\n", + " adding a legend with class labels, etc.\n", + " \"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The `data.py` file" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `data.py` file is placed under `tests/data/dataset_name/` directory and creates a smaller dummy dataset that replicates the features and formats of the actual full datasets for unit tests. This is needed to keep the tests fast (we don't have time or storage space to download the real dataset) and to comply with the dataset license. \n", + "\n", + "The script should:\n", + "\n", + "* replicate the directory structure and file names\n", + "* replicate the file format, data type, and range of values\n", + "* use the same compression scheme to simulate downloading the dataset\n", + "\n", + "This is usually highly dependent on the dataset format and structure the new dataset comes in. You should always look for a similar dataset first and use that as a reference. However, below is an outline of the usual building blocks of a `data.py` script, for example an image segmentation dataset with 10 classes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import shutil\n", + "\n", + "import numpy as np\n", + "from PIL import Image\n", + "\n", + "# Define the root directory and subdirectories\n", + "root_dir = 'my_new_dataset'\n", + "sub_dirs = ['sub_dir_1', 'sub_dir_2', 'sub_dir_3']\n", + "splits = ['train', 'val', 'test']\n", + "\n", + "image_file_names = ['sample_1.png', 'sample_2.png', 'sample_3.png']\n", + "\n", + "IMG_SIZE = 32\n", + "\n", + "\n", + "# Function to create dummy input images\n", + "def create_input_image(path: str, shape: tuple[int], pixel_values: list[int]) -> None:\n", + " data = np.random.choice(pixel_values, size=shape, replace=True).astype(np.uint8)\n", + " img = Image.fromarray(data)\n", + " img.save(path)\n", + "\n", + "\n", + "# Function to create dummy targets\n", + "def create_target_images(split: str, filename: str) -> None:\n", + " target_pixel_values = range(10)\n", + " path = os.path.join(root_dir, 'target', split, filename)\n", + " create_input_image(path, (IMG_SIZE, IMG_SIZE), target_pixel_values)\n", + "\n", + "\n", + "# Create a new clean version when re-running the script\n", + "if os.path.exists(root_dir):\n", + " shutil.rmtree(root_dir)\n", + "\n", + "# Create the directory structure\n", + "for sub_dir in sub_dirs:\n", + " for split in splits:\n", + " os.makedirs(os.path.join(root_dir, sub_dir, split), exist_ok=True)\n", + "\n", + "# Create dummy data for all splits and filenames\n", + "for split in splits:\n", + " for filename in image_file_names:\n", + " create_input_image(split, filename)\n", + " create_target_images(split, filename.replace('_', '_target_'))\n", + "\n", + "# Zip directory\n", + "shutil.make_archive(root_dir, 'zip', '.', root_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The `test_dataset_name.py` file" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `test_dataset_name.py` file is placed under the `tests/datasets/` directory. This file implements the unit tests for the dataset, such that every line of code in `dataset_name.py` is tested. The logic of the individual test cases will likely be very similar to existing test files so you can look at those to to see how you can test the individual parts of the dataset logic." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "from pathlib import Path\n", + "\n", + "import pytest\n", + "import torch\n", + "import torch.nn as nn\n", + "from _pytest.fixtures import SubRequest\n", + "from matplotlib import pyplot as plt\n", + "from pytest import MonkeyPatch\n", + "\n", + "from torchgeo.datasets import DatasetNotFoundError\n", + "\n", + "\n", + "def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:\n", + " shutil.copy(url, root)\n", + "\n", + "\n", + "class TestMyNewDataset:\n", + " # pytest fixtures can be used to define variables to test different argument\n", + " # configurations to test, for example the different splits of the dataset\n", + " # or subselection of modalities/bands\n", + " @pytest.fixture(params=['train', 'val', 'test'])\n", + " def dataset(\n", + " self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest\n", + " ) -> MyNewDataset:\n", + " # monkeypatch can overwrite the class attributes defined above the __init__\n", + " # method and use the specific unit tests settings to mock behavior\n", + "\n", + " split: str = request.param\n", + " transforms = nn.Identity()\n", + " return MyNewDataset(tmp_path, split=split, transforms=transforms, download=True)\n", + "\n", + " def test_getitem(self, dataset: MyNewDataset) -> None:\n", + " # Retrieve a sample and check some of the desired properties\n", + " x = dataset[0]\n", + " assert isinstance(x, dict)\n", + " assert isinstance(x['image'], torch.Tensor)\n", + " assert isinstance(x['label'], torch.Tensor)\n", + "\n", + " # For all additional class arguments, check behavior for invalid parameters\n", + " def test_invalid_split(self) -> None:\n", + " with pytest.raises(AssertionError):\n", + " MyNewDataset(foo='bar')\n", + "\n", + " # Test the length of the dataset, this should coincide with the dummy data\n", + " def test_len(self, dataset: MyNewDataset) -> None:\n", + " assert len(dataset) == 2\n", + "\n", + " # Test the logic when the dataset is already downloaded\n", + " def test_already_downloaded(self, dataset: MyNewDataset, tmp_path: Path) -> None:\n", + " MyNewDataset(root=tmp_path, download=True)\n", + "\n", + " # Test the logic when the dataset is already downloaded but not extracted\n", + " def test_already_downloaded_not_extracted(\n", + " self, dataset: MyNewDataset, tmp_path: Path\n", + " ) -> None:\n", + " shutil.rmtree(dataset.root)\n", + " download_url(dataset.url, root=tmp_path)\n", + " MyNewDataset(root=tmp_path, download=False)\n", + "\n", + " # Test the logic when the dataset is not downloaded\n", + " def test_not_downloaded(self, tmp_path: Path) -> None:\n", + " with pytest.raises(DatasetNotFoundError, match='Dataset not found'):\n", + " MyNewDataset(tmp_path)\n", + "\n", + " # Test the plotting method through something like the following\n", + " def test_plot(self, dataset: MyNewDataset) -> None:\n", + " x = dataset[0].copy()\n", + " x['prediction'] = x['label'].clone()\n", + " dataset.plot(x, suptitle='Test')\n", + " plt.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Documentation Entries" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The entry point for new and experienced users of domain libraries is often the dedicated documentation page that accompanies a Github repository. TorchGeo uses the popular [Sphinx](https://www.sphinx-doc.org/en/master/) framework to build its documentation. To display the documentation strings you have written in `dataset_name.py` on the actual documentation page, you need to create an entry in `docs/api/datasets.rst` in alphabetical order:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```rst\n", + "Dataset Name\n", + "^^^^^^^^^^^^\n", + "\n", + ".. autoclass:: MyNewDataset\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Additionally, add a row in the `non_geo_datasets.csv` file under `docs/api/datasets` to include the dataset in the overview table." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Linters\n", + "\n", + "See the [linter docs](https://torchgeo.readthedocs.io/en/stable/user/contributing.html#linters) for an overview of linters that TorchGeo employs and how to apply them during commits." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test Coverage" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "TorchGeo maintains a test coverage of 100%. This means, that every line of code written within the torchgeo directory is being run by some unit test. The [testing docs](https://torchgeo.readthedocs.io/en/stable/user/contributing.html#tests) provide instructions on how you can test the coverage locally for the `dataset_new.py` file that you are adding." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Final Checklist" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This final checklist might provide a useful overview of the individual parts discussed in this tutorial. You definitely do not need to check all boxes, before submitting a PR. If you have any questions feel free to ask on Slack or open a PR already such that maintainers or other community members can answer specific questions or give pointers. If you want to run your PR as a work of progress, such that the CI tests are run against your code while you work on ticking more boxes you can also convert the PR to a draft on the right side menu." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- Dataset implementation in `dataset_name.py`\n", + " - Class docstring containing:\n", + " - Summary intro\n", + " - Dataset features\n", + " - Dataset format\n", + " - Link to publication\n", + " - `versionadded` tag\n", + " - if applicable a note on additional dependencies\n", + " - all class methods have docstrings\n", + " - all class methods have argument and return type hints, mypy (the tool that checks type hints) can be confusing at the beginning so don't hesitate to ask for help\n", + " - if dataset is on GitHub or Huggingface, url link should contain the commit hash\n", + " - checksum added\n", + " - plot method that can display a single sample from the dataset (you can add the resulting figure in your PR description)\n", + " - add the dataset to `torchgeo/datastes/__init__.py`\n", + " - Add the copyright at the top of the file\n", + "- Dummy data script `data.py`\n", + " - replicate directory structure\n", + " - replicate naming of directory and files\n", + " - for image based datasets, use a small size, like 32x32\n", + "- Unit tests `test_dataset_name.py`\n", + " - 100% test coverage \n", + "- Documentation with `non_geo_datasets.csv` and `datasets.rst`\n", + " - entry in `datasets.rst`\n", + " - entry in `non_geo_datasets.csv`\n", + " - documentation displays properly, this can be checked locally or via the GitHub CI tests under `docs/readthedocs.org:torchgeo`" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}