Skip to content

Commit

Permalink
Move DataModules to torchgeo.datamodules (microsoft#321)
Browse files Browse the repository at this point in the history
* Move DataModules to torchgeo.datamodules

* Clean up local imports
  • Loading branch information
adamjstewart authored Dec 24, 2021
1 parent 79c60af commit 5973b9b
Show file tree
Hide file tree
Showing 100 changed files with 3,978 additions and 3,500 deletions.
105 changes: 105 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
torchgeo.datamodules
====================

.. module:: torchgeo.datamodules

Geospatial DataModules
----------------------

Chesapeake Bay High-Resolution Land Cover Project
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: ChesapeakeCVPRDataModule

National Agriculture Imagery Program (NAIP)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: NAIPChesapeakeDataModule

Non-geospatial DataModules
--------------------------

BigEarthNet
^^^^^^^^^^^

.. autoclass:: BigEarthNetDataModule

Cars Overhead With Context (COWC)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: COWCCountingDataModule

ETCI2021 Flood Detection
^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: ETCI2021DataModule

EuroSAT
^^^^^^^

.. autoclass:: EuroSATDataModule

FAIR1M (Fine-grAined object recognItion in high-Resolution imagery)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: FAIR1MDataModule

LandCover.ai (Land Cover from Aerial Imagery)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: LandCoverAIDataModule

LoveDA (Land-cOVEr Domain Adaptive semantic segmentation)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: LoveDADataModule

NASA Marine Debris
^^^^^^^^^^^^^^^^^^

.. autoclass:: NASAMarineDebrisDataModule

OSCD (Onera Satellite Change Detection)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: OSCDDataModule

Potsdam
^^^^^^^

.. autoclass:: Potsdam2DDataModule

RESISC45 (Remote Sensing Image Scene Classification)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: RESISC45DataModule

SEN12MS
^^^^^^^

.. autoclass:: SEN12MSDataModule

So2Sat
^^^^^^

.. autoclass:: So2SatDataModule

Tropical Cyclone Wind Estimation Competition
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: CycloneDataModule

UC Merced
^^^^^^^^^

.. autoclass:: UCMercedDataModule

Vaihingen
^^^^^^^^^

.. autoclass:: Vaihingen2DDataModule

xView2
^^^^^^

.. autoclass:: XView2DataModule
31 changes: 6 additions & 25 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ Chesapeake Bay High-Resolution Land Cover Project
.. autoclass:: ChesapeakeVA
.. autoclass:: ChesapeakeWV
.. autoclass:: ChesapeakeCVPR
.. autoclass:: ChesapeakeCVPRDataModule

Cropland Data Layer (CDL)
^^^^^^^^^^^^^^^^^^^^^^^^^
Expand All @@ -57,7 +56,6 @@ National Agriculture Imagery Program (NAIP)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: NAIP
.. autoclass:: NAIPChesapeakeDataModule

Sentinel
^^^^^^^^
Expand Down Expand Up @@ -86,15 +84,13 @@ BigEarthNet
^^^^^^^^^^^

.. autoclass:: BigEarthNet
.. autoclass:: BigEarthNetDataModule

Cars Overhead With Context (COWC)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: COWC
.. autoclass:: COWCCounting
.. autoclass:: COWCDetection
.. autoclass:: COWCCountingDataModule

CV4A Kenya Crop Type Competition
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand All @@ -105,19 +101,16 @@ ETCI2021 Flood Detection
^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: ETCI2021
.. autoclass:: ETCI2021DataModule

EuroSAT
^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^

.. autoclass:: EuroSAT
.. autoclass:: EuroSATDataModule

FAIR1M (Fine-grAined object recognItion in high-Resolution imagery)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: FAIR1M
.. autoclass:: FAIR1MDataModule

GID-15 (Gaofen Image Dataset)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand All @@ -133,7 +126,6 @@ LandCover.ai (Land Cover from Aerial Imagery)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: LandCoverAI
.. autoclass:: LandCoverAIDataModule

LEVIR-CD+ (LEVIR Change Detection +)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand All @@ -144,19 +136,16 @@ LoveDA (Land-cOVEr Domain Adaptive semantic segmentation)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: LoveDA
.. autoclass:: LoveDADataModule

NASA Marine Debris
^^^^^^^^^^^^^^^^^^

.. autoclass:: NASAMarineDebris
.. autoclass:: NASAMarineDebrisDataModule

OSCD (Onera Satellite Change Detection)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: OSCD
.. autoclass:: OSCDDataModule

PatternNet
^^^^^^^^^^
Expand All @@ -167,13 +156,11 @@ Potsdam
^^^^^^^

.. autoclass:: Potsdam2D
.. autoclass:: Potsdam2DDataModule

RESISC45 (Remote Sensing Image Scene Classification)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: RESISC45
.. autoclass:: RESISC45DataModule

Seasonal Contrast
^^^^^^^^^^^^^^^^^
Expand All @@ -184,13 +171,11 @@ SEN12MS
^^^^^^^

.. autoclass:: SEN12MS
.. autoclass:: SEN12MSDataModule

So2Sat
^^^^^^

.. autoclass:: So2Sat
.. autoclass:: So2SatDataModule

SpaceNet
^^^^^^^^
Expand All @@ -206,30 +191,26 @@ Tropical Cyclone Wind Estimation Competition
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: TropicalCycloneWindEstimation
.. autoclass:: CycloneDataModule

UC Merced
^^^^^^^^^

.. autoclass:: UCMerced

Vaihingen
^^^^^^^^^

.. autoclass:: Vaihingen2D
.. autoclass:: Vaihingen2DDataModule

NWPU VHR-10
^^^^^^^^^^^

.. autoclass:: VHR10

UC Merced
^^^^^^^^^

.. autoclass:: UCMerced
.. autoclass:: UCMercedDataModule

xView2
^^^^^^

.. autoclass:: XView2
.. autoclass:: XView2DataModule

ZueriCrop
^^^^^^^^^
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ torchgeo
:maxdepth: 2
:caption: Package Reference

api/datamodules
api/datasets
api/models
api/samplers
Expand Down
2 changes: 1 addition & 1 deletion experiments/test_chesapeakecvpr_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytorch_lightning as pl
import torch

from torchgeo.datasets import ChesapeakeCVPRDataModule
from torchgeo.datamodules import ChesapeakeCVPRDataModule
from torchgeo.trainers.chesapeake import ChesapeakeCVPRSegmentationTask

ALL_TEST_SPLITS = [["de-val"], ["pa-test"], ["ny-test"], ["pa-test", "ny-test"]]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ strict_equality = true

[tool.pydocstyle]
convention = "google"
match_dir = "(datasets|models|samplers|torchgeo|trainers|transforms)"
match_dir = "(datamodules|datasets|models|samplers|torchgeo|trainers|transforms)"

[tool.pytest.ini_options]
# Skip slow tests by default
Expand Down
2 changes: 2 additions & 0 deletions tests/datamodules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
32 changes: 32 additions & 0 deletions tests/datamodules/test_bigearthnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import pytest
from _pytest.fixtures import SubRequest

from torchgeo.datamodules import BigEarthNetDataModule


class TestBigEarthNetDataModule:
@pytest.fixture(scope="class", params=["s1", "s2", "all"])
def datamodule(self, request: SubRequest) -> BigEarthNetDataModule:
bands = request.param
root = os.path.join("tests", "data", "bigearthnet")
num_classes = 19
batch_size = 1
num_workers = 0
dm = BigEarthNetDataModule(root, bands, num_classes, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm

def test_train_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
next(iter(datamodule.train_dataloader()))

def test_val_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
next(iter(datamodule.val_dataloader()))

def test_test_dataloader(self, datamodule: BigEarthNetDataModule) -> None:
next(iter(datamodule.test_dataloader()))
52 changes: 52 additions & 0 deletions tests/datamodules/test_chesapeake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import pytest
import torch
from _pytest.fixtures import SubRequest

from torchgeo.datamodules import ChesapeakeCVPRDataModule


class TestChesapeakeCVPRDataModule:
@pytest.fixture(scope="class", params=[5, 7])
def datamodule(self, request: SubRequest) -> ChesapeakeCVPRDataModule:
dm = ChesapeakeCVPRDataModule(
os.path.join("tests", "data", "chesapeake", "cvpr"),
["de-test"],
["de-test"],
["de-test"],
patch_size=32,
patches_per_tile=2,
batch_size=2,
num_workers=0,
class_set=request.param,
)
dm.prepare_data()
dm.setup()
return dm

def test_train_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None:
next(iter(datamodule.train_dataloader()))

def test_val_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None:
next(iter(datamodule.val_dataloader()))

def test_test_dataloader(self, datamodule: ChesapeakeCVPRDataModule) -> None:
next(iter(datamodule.test_dataloader()))

def test_nodata_check(self, datamodule: ChesapeakeCVPRDataModule) -> None:
nodata_check = datamodule.nodata_check(4)
sample = {
"image": torch.ones(1, 2, 2), # type: ignore[attr-defined]
"mask": torch.ones(2, 2), # type: ignore[attr-defined]
}
out = nodata_check(sample)
assert torch.equal( # type: ignore[attr-defined]
out["image"], torch.zeros(1, 4, 4) # type: ignore[attr-defined]
)
assert torch.equal( # type: ignore[attr-defined]
out["mask"], torch.zeros(4, 4) # type: ignore[attr-defined]
)
30 changes: 30 additions & 0 deletions tests/datamodules/test_cowc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import pytest

from torchgeo.datamodules import COWCCountingDataModule


class TestCOWCCountingDataModule:
@pytest.fixture(scope="class")
def datamodule(self) -> COWCCountingDataModule:
root = os.path.join("tests", "data", "cowc_counting")
seed = 0
batch_size = 1
num_workers = 0
dm = COWCCountingDataModule(root, seed, batch_size, num_workers)
dm.prepare_data()
dm.setup()
return dm

def test_train_dataloader(self, datamodule: COWCCountingDataModule) -> None:
next(iter(datamodule.train_dataloader()))

def test_val_dataloader(self, datamodule: COWCCountingDataModule) -> None:
next(iter(datamodule.val_dataloader()))

def test_test_dataloader(self, datamodule: COWCCountingDataModule) -> None:
next(iter(datamodule.test_dataloader()))
Loading

0 comments on commit 5973b9b

Please sign in to comment.